diff --git a/.env.example b/.env.example
index d87021ea4b..096903299e 100644
--- a/.env.example
+++ b/.env.example
@@ -20,6 +20,11 @@ DOMAIN_CLIENT=http://localhost:3080
 DOMAIN_SERVER=http://localhost:3080
 NO_INDEX=true
 
+# Use the address that is at most n hops away from the Express application.
+# req.socket.remoteAddress is the first hop, and the rest are looked for in the X-Forwarded-For header from right to left.
+# A value of 0 means that the first untrusted address would be req.socket.remoteAddress, i.e. there is no reverse proxy.
+# Defaults to 1.
+TRUST_PROXY=1
 
 #===============#
 # JSON Logging  #
 #===============#
@@ -83,7 +88,7 @@ PROXY=
 #============#
 # Anthropic  #
 #============#
 
 ANTHROPIC_API_KEY=user_provided
-# ANTHROPIC_MODELS=claude-3-5-haiku-20241022,claude-3-5-sonnet-20241022,claude-3-5-sonnet-latest,claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
+# ANTHROPIC_MODELS=claude-3-7-sonnet-latest,claude-3-7-sonnet-20250219,claude-3-5-haiku-20241022,claude-3-5-sonnet-20241022,claude-3-5-sonnet-latest,claude-3-5-sonnet-20240620,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307,claude-2.1,claude-2,claude-1.2,claude-1,claude-1-100k,claude-instant-1,claude-instant-1-100k
 # ANTHROPIC_REVERSE_PROXY=
 
 #============#
@@ -137,12 +142,12 @@ GOOGLE_KEY=user_provided
 # GOOGLE_AUTH_HEADER=true
 
 # Gemini API (AI Studio)
-# GOOGLE_MODELS=gemini-2.0-flash-exp,gemini-2.0-flash-thinking-exp-1219,gemini-exp-1121,gemini-exp-1114,gemini-1.5-flash-latest,gemini-1.0-pro,gemini-1.0-pro-001,gemini-1.0-pro-latest,gemini-1.0-pro-vision-latest,gemini-1.5-pro-latest,gemini-pro,gemini-pro-vision
+# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
 
 # Vertex AI
-# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro
+# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
 
-# GOOGLE_TITLE_MODEL=gemini-pro
+# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
 
 # GOOGLE_LOC=us-central1
 
@@ -170,7 +175,7 @@
 #============#
 
 OPENAI_API_KEY=user_provided
-# OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
+# OPENAI_MODELS=o1,o1-mini,o1-preview,gpt-4o,gpt-4.5-preview,chatgpt-4o-latest,gpt-4o-mini,gpt-3.5-turbo-0125,gpt-3.5-turbo-0301,gpt-3.5-turbo,gpt-4,gpt-4-0613,gpt-4-vision-preview,gpt-3.5-turbo-0613,gpt-3.5-turbo-16k-0613,gpt-4-0125-preview,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-instruct-0914,gpt-3.5-turbo-16k
 
 DEBUG_OPENAI=false
 
@@ -204,12 +209,6 @@ ASSISTANTS_API_KEY=user_provided
 # More info, including how to enable use of Assistants with Azure here:
 # https://www.librechat.ai/docs/configuration/librechat_yaml/ai_endpoints/azure#using-assistants-with-azure
 
-#============#
-# OpenRouter #
-#============#
-# !!!Warning: Use the variable above instead of this one. Using this one will override the OpenAI endpoint
-# OPENROUTER_API_KEY=
-
 #============#
 # Plugins    #
 #============#
@@ -232,6 +231,14 @@ AZURE_AI_SEARCH_SEARCH_OPTION_QUERY_TYPE=
 AZURE_AI_SEARCH_SEARCH_OPTION_TOP=
 AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
 
+# OpenAI Image Tools Customization
+#----------------
+# IMAGE_GEN_OAI_DESCRIPTION_WITH_FILES=Custom description for image generation tool when files are present
+# IMAGE_GEN_OAI_DESCRIPTION_NO_FILES=Custom description for image generation tool when no files are present
+# IMAGE_EDIT_OAI_DESCRIPTION=Custom description for image editing tool
+# IMAGE_GEN_OAI_PROMPT_DESCRIPTION=Custom prompt description for image generation tool
+# IMAGE_EDIT_OAI_PROMPT_DESCRIPTION=Custom prompt description for image editing tool
+
 # DALL·E
 #----------------
 # DALLE_API_KEY=
@@ -249,6 +256,13 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT=
 # DALLE3_AZURE_API_VERSION=
 # DALLE2_AZURE_API_VERSION=
 
+# Flux
+#-----------------
+FLUX_API_BASE_URL=https://api.us1.bfl.ai
+# FLUX_API_BASE_URL=https://api.bfl.ml
+
+# Get your API key at https://api.us1.bfl.ai/auth/profile
+# FLUX_API_KEY=
 
 # Google
 #-----------------
@@ -292,6 +306,10 @@ MEILI_NO_ANALYTICS=true
 MEILI_HOST=http://0.0.0.0:7700
 MEILI_MASTER_KEY=DrhYf7zENyR6AlUCKmnz0eYASOQdl6zxH7s7MKFSfFCt
 
+# Optional: Disable indexing, useful in a multi-node setup
+# where only one instance should perform an index sync.
+# MEILI_NO_SYNC=true
+
 #==================================================#
 #       Speech to Text & Text to Speech            #
 #==================================================#
@@ -354,7 +372,7 @@ ILLEGAL_MODEL_REQ_SCORE=5
 #========================#
 #        Balance         #
 #========================#
 
-CHECK_BALANCE=false
+# CHECK_BALANCE=false
 # START_BALANCE=20000 # note: the number of tokens that will be credited after registration.
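For context on the `TRUST_PROXY` variable introduced at the top of this `.env.example` diff: below is a minimal sketch of how such a value typically maps onto Express's standard `trust proxy` setting. The wiring here is illustrative only (the route and port are made up), not LibreChat's actual server code:

```js
const express = require('express');

const app = express();

// TRUST_PROXY=1 tells Express to trust exactly one reverse-proxy hop:
// req.socket.remoteAddress is the first hop, and one more address is taken
// from the right end of the X-Forwarded-For header. With 0, no proxy is
// trusted and req.ip is simply req.socket.remoteAddress.
app.set('trust proxy', Number(process.env.TRUST_PROXY ?? 1));

app.get('/whoami', (req, res) => {
  // req.ip reflects the trust setting above.
  res.json({ ip: req.ip, socket: req.socket.remoteAddress });
});

app.listen(3080);
```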
 #========================#
@@ -389,7 +407,7 @@ FACEBOOK_CALLBACK_URL=/oauth/facebook/callback
 GITHUB_CLIENT_ID=
 GITHUB_CLIENT_SECRET=
 GITHUB_CALLBACK_URL=/oauth/github/callback
-# GitHub Eenterprise
+# GitHub Enterprise
 # GITHUB_ENTERPRISE_BASE_URL=
 # GITHUB_ENTERPRISE_USER_AGENT=
@@ -422,15 +440,19 @@ OPENID_NAME_CLAIM=
 OPENID_BUTTON_LABEL=
 OPENID_IMAGE_URL=
 
+# Set to true to automatically redirect to the OpenID provider when a user visits the login page.
+# This will bypass the login form completely for users; only use this if OpenID is your only authentication method.
+OPENID_AUTO_REDIRECT=false
 
 # LDAP
 LDAP_URL=
 LDAP_BIND_DN=
 LDAP_BIND_CREDENTIALS=
 LDAP_USER_SEARCH_BASE=
-LDAP_SEARCH_FILTER=mail={{username}}
+#LDAP_SEARCH_FILTER="mail={{username}}"
 LDAP_CA_CERT_PATH=
 # LDAP_TLS_REJECT_UNAUTHORIZED=
+# LDAP_STARTTLS=
 # LDAP_LOGIN_USES_USERNAME=true
 # LDAP_ID=
 # LDAP_USERNAME=
@@ -463,6 +485,24 @@ FIREBASE_STORAGE_BUCKET=
 FIREBASE_MESSAGING_SENDER_ID=
 FIREBASE_APP_ID=
 
+#========================#
+#     S3 AWS Bucket      #
+#========================#
+
+AWS_ENDPOINT_URL=
+AWS_ACCESS_KEY_ID=
+AWS_SECRET_ACCESS_KEY=
+AWS_REGION=
+AWS_BUCKET_NAME=
+
+#========================#
+#  Azure Blob Storage    #
+#========================#
+
+AZURE_STORAGE_CONNECTION_STRING=
+AZURE_STORAGE_PUBLIC_ACCESS=false
+AZURE_CONTAINER_NAME=files
+
 #========================#
 #      Shared Links      #
 #========================#
@@ -495,6 +535,16 @@ HELP_AND_FAQ_URL=https://librechat.ai
 # Google tag manager id
 #ANALYTICS_GTM_ID=user provided google tag manager id
 
+#===============#
+# REDIS Options #
+#===============#
+
+# REDIS_URI=10.10.10.10:6379
+# USE_REDIS=true
+
+# USE_REDIS_CLUSTER=true
+# REDIS_CA=/path/to/ca.crt
+
 #==================================================#
 #                      Others                      #
 #==================================================#
@@ -502,9 +552,6 @@ HELP_AND_FAQ_URL=https://librechat.ai
 
 # NODE_ENV=
 
-# REDIS_URI=
-# USE_REDIS=
-
 # E2E_USER_EMAIL=
 # E2E_USER_PASSWORD=
@@ -527,4 +574,4 @@ HELP_AND_FAQ_URL=https://librechat.ai
 #=====================================================#
 #                    OpenWeather                      #
 #=====================================================#
-OPENWEATHER_API_KEY=
\ No newline at end of file
+OPENWEATHER_API_KEY=
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
index 5951ed694e..09444a1b44 100644
--- a/.github/CONTRIBUTING.md
+++ b/.github/CONTRIBUTING.md
@@ -24,22 +24,40 @@ Project maintainers have the right and responsibility to remove, edit, or reject
 
 ## To contribute to this project, please adhere to the following guidelines:
 
-## 1. Development notes
+## 1. Development Setup
 
-1. Before starting work, make sure your main branch has the latest commits with `npm run update`
-2. Run linting command to find errors: `npm run lint`. Alternatively, ensure husky pre-commit checks are functioning.
+1. Use Node.js 20.x.
+2. Install TypeScript globally: `npm i -g typescript`.
+3. Run `npm ci` to install dependencies.
+4. Build the data provider: `npm run build:data-provider`.
+5. Build MCP: `npm run build:mcp`.
+6. Build data schemas: `npm run build:data-schemas`.
+7. Set up and run unit tests:
+   - Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`.
+   - Run backend unit tests: `npm run test:api`.
+   - Run frontend unit tests: `npm run test:client`.
+8. Set up and run integration tests:
+   - Build client: `cd client && npm run build`.
+   - Create `.env`: `cp .env.example .env`.
+   - Install [MongoDB Community Edition](https://www.mongodb.com/docs/manual/administration/install-community/), and ensure that `mongosh` connects to your local instance.
+   - Run `npm install playwright`, then `npx playwright install`.
+   - Copy `config.local`: `cp e2e/config.local.example.ts e2e/config.local.ts`.
+   - Copy `librechat.yaml`: `cp librechat.example.yaml librechat.yaml`.
+   - Run: `npm run e2e`.
+
+## 2. Development Notes
+
+1. Before starting work, make sure your main branch has the latest commits with `npm run update`.
+2. Run the linting command to find errors: `npm run lint`. Alternatively, ensure husky pre-commit checks are functioning.
 3. After your changes, reinstall packages in your current branch using `npm run reinstall` and ensure everything still works.
     - Restart the ESLint server ("ESLint: Restart ESLint Server" in VS Code command bar) and your IDE after reinstalling or updating.
 4. Clear web app localStorage and cookies before and after changes.
-5. For frontend changes:
-   - Install typescript globally: `npm i -g typescript`.
-   - Compile typescript before and after changes to check for introduced errors: `cd client && tsc --noEmit`.
-6. Run tests locally:
-   - Backend unit tests: `npm run test:api`
-   - Frontend unit tests: `npm run test:client`
-   - Integration tests: `npm run e2e` (requires playwright installed, `npx install playwright`)
+5. For frontend changes, compile TypeScript before and after changes to check for introduced errors: `cd client && npm run build`.
+6. Run backend unit tests: `npm run test:api`.
+7. Run frontend unit tests: `npm run test:client`.
+8. Run integration tests: `npm run e2e`.
 
-## 2. Git Workflow
+## 3. Git Workflow
 
 We utilize a GitFlow workflow to manage changes to this project's codebase. Follow these general steps when contributing code:
 
@@ -49,7 +67,7 @@ We utilize a GitFlow workflow to manage changes to this project's codebase. Foll
 4. Submit a pull request with a clear and concise description of your changes and the reasons behind them.
 5. We will review your pull request, provide feedback as needed, and eventually merge the approved changes into the main branch.
 
-## 3. Commit Message Format
+## 4. Commit Message Format
 
 We follow the [semantic format](https://gist.github.com/joshbuchea/6f47e86d2510bce28f8e7f42ae84c716) for commit messages.
 
@@ -76,7 +94,7 @@
 feat: add hat wobble
 ```
 
-## 4. Pull Request Process
+## 5. Pull Request Process
 
 When submitting a pull request, please follow these guidelines:
 
@@ -91,7 +109,7 @@ Ensure that your changes meet the following criteria:
 - The commit history is clean and easy to follow. You can use `git rebase` or `git merge --squash` to clean your commit history before submitting the pull request.
 - The pull request description clearly outlines the changes and the reasons behind them. Be sure to include the steps to test the pull request.
 
-## 5. Naming Conventions
+## 6. Naming Conventions
 
 Apply the following naming conventions to branches, labels, and other Git-related entities:
 
@@ -100,7 +118,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate
 - **JS/TS:** Directories and file names: Descriptive and camelCase. First letter uppercased for React files (e.g., `helperFunction.ts, ReactComponent.tsx`).
 - **Docs:** Directories and file names: Descriptive and snake_case (e.g., `config_files.md`).
 
-## 6. TypeScript Conversion
+## 7. TypeScript Conversion
 
 1. **Original State**: The project was initially developed entirely in JavaScript (JS).
@@ -126,7 +144,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate
 
    - **Current Stance**: At present, this backend transition is of lower priority and might not be pursued.
 
-## 7. Module Import Conventions
+## 8. Module Import Conventions
 
 - `npm` packages first,
 - from shortest line (top) to longest (bottom)
diff --git a/.github/ISSUE_TEMPLATE/BUG-REPORT.yml b/.github/ISSUE_TEMPLATE/BUG-REPORT.yml
index 3a3b828ee1..610396959f 100644
--- a/.github/ISSUE_TEMPLATE/BUG-REPORT.yml
+++ b/.github/ISSUE_TEMPLATE/BUG-REPORT.yml
@@ -79,6 +79,8 @@ body:
         For UI-related issues, browser console logs can be very helpful. You can provide these as screenshots or paste the text here.
       render: shell
+    validations:
+      required: true
   - type: textarea
     id: screenshots
     attributes:
diff --git a/.github/ISSUE_TEMPLATE/LOCIZE_TRANSLATION_ACCESS_REQUEST.yml b/.github/ISSUE_TEMPLATE/LOCIZE_TRANSLATION_ACCESS_REQUEST.yml
new file mode 100644
index 0000000000..49b01a814d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/LOCIZE_TRANSLATION_ACCESS_REQUEST.yml
@@ -0,0 +1,42 @@
+name: Locize Translation Access Request
+description: Request access to an additional language in Locize for LibreChat translations.
+title: "Locize Access Request: "
+labels: ["🌍 i18n", "🔑 access request"]
+body:
+  - type: markdown
+    attributes:
+      value: |
+        Thank you for your interest in contributing to LibreChat translations!
+        Please fill out the form below to request access to an additional language in **Locize**.
+
+        **🔗 Available Languages:** [View the list here](https://www.librechat.ai/docs/translation)
+
+        **📌 Note:** Ensure that the requested language is supported before submitting your request.
+  - type: input
+    id: account_name
+    attributes:
+      label: Locize Account Name
+      description: Please provide your Locize account name (e.g., John Doe).
+      placeholder: e.g., John Doe
+    validations:
+      required: true
+  - type: input
+    id: language_requested
+    attributes:
+      label: Language Code (ISO 639-1)
+      description: |
+        Enter the **ISO 639-1** language code for the language you want to translate into.
+        Example: `es` for Spanish, `zh-Hant` for Traditional Chinese.
+
+        **🔗 Reference:** [Available Languages](https://www.librechat.ai/docs/translation)
+      placeholder: e.g., es
+    validations:
+      required: true
+  - type: checkboxes
+    id: agreement
+    attributes:
+      label: Agreement
+      description: By submitting this request, you confirm that you will contribute responsibly and adhere to the project guidelines.
+      options:
+        - label: I agree to use my access solely for contributing to LibreChat translations.
+          required: true
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/QUESTION.yml b/.github/ISSUE_TEMPLATE/QUESTION.yml
deleted file mode 100644
index c66e6baa3b..0000000000
--- a/.github/ISSUE_TEMPLATE/QUESTION.yml
+++ /dev/null
@@ -1,50 +0,0 @@
-name: Question
-description: Ask your question
-title: "[Question]: "
-labels: ["❓ question"]
-body:
-  - type: markdown
-    attributes:
-      value: |
-        Thanks for taking the time to fill this!
-  - type: textarea
-    id: what-is-your-question
-    attributes:
-      label: What is your question?
-      description: Please give as many details as possible
-      placeholder: Please give as many details as possible
-    validations:
-      required: true
-  - type: textarea
-    id: more-details
-    attributes:
-      label: More Details
-      description: Please provide more details if needed.
-      placeholder: Please provide more details if needed.
- validations: - required: true - - type: dropdown - id: browsers - attributes: - label: What is the main subject of your question? - multiple: true - options: - - Documentation - - Installation - - UI - - Endpoints - - User System/OAuth - - Other - - type: textarea - id: screenshots - attributes: - label: Screenshots - description: If applicable, add screenshots to help explain your problem. You can drag and drop, paste images directly here or link to them. - - type: checkboxes - id: terms - attributes: - label: Code of Conduct - description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/danny-avila/LibreChat/blob/main/.github/CODE_OF_CONDUCT.md) - options: - - label: I agree to follow this project's Code of Conduct - required: true diff --git a/.github/configuration-release.json b/.github/configuration-release.json new file mode 100644 index 0000000000..68fe80ed8f --- /dev/null +++ b/.github/configuration-release.json @@ -0,0 +1,60 @@ +{ + "categories": [ + { + "title": "### ✨ New Features", + "labels": ["feat"] + }, + { + "title": "### 🌍 Internationalization", + "labels": ["i18n"] + }, + { + "title": "### 👐 Accessibility", + "labels": ["a11y"] + }, + { + "title": "### 🔧 Fixes", + "labels": ["Fix", "fix"] + }, + { + "title": "### ⚙️ Other Changes", + "labels": ["ci", "style", "docs", "refactor", "chore"] + } + ], + "ignore_labels": [ + "🔁 duplicate", + "📊 analytics", + "🌱 good first issue", + "🔍 investigation", + "🙏 help wanted", + "❌ invalid", + "❓ question", + "🚫 wontfix", + "🚀 release", + "version" + ], + "base_branches": ["main"], + "sort": { + "order": "ASC", + "on_property": "mergedAt" + }, + "label_extractor": [ + { + "pattern": "^(?:[^A-Za-z0-9]*)(feat|fix|chore|docs|refactor|ci|style|a11y|i18n)\\s*:", + "target": "$1", + "flags": "i", + "on_property": "title", + "method": "match" + }, + { + "pattern": "^(?:[^A-Za-z0-9]*)(v\\d+\\.\\d+\\.\\d+(?:-rc\\d+)?).*", + "target": "version", + "flags": "i", + "on_property": "title", + "method": "match" + } + ], + "template": "## [#{{TO_TAG}}] - #{{TO_TAG_DATE}}\n\nChanges from #{{FROM_TAG}} to #{{TO_TAG}}.\n\n#{{CHANGELOG}}\n\n[See full release details][release-#{{TO_TAG}}]\n\n[release-#{{TO_TAG}}]: https://github.com/#{{OWNER}}/#{{REPO}}/releases/tag/#{{TO_TAG}}\n\n---", + "pr_template": "- #{{TITLE}} by **@#{{AUTHOR}}** in [##{{NUMBER}}](#{{URL}})", + "empty_template": "- no changes" +} \ No newline at end of file diff --git a/.github/configuration-unreleased.json b/.github/configuration-unreleased.json new file mode 100644 index 0000000000..29eaf5e13b --- /dev/null +++ b/.github/configuration-unreleased.json @@ -0,0 +1,68 @@ +{ + "categories": [ + { + "title": "### ✨ New Features", + "labels": ["feat"] + }, + { + "title": "### 🌍 Internationalization", + "labels": ["i18n"] + }, + { + "title": "### 👐 Accessibility", + "labels": ["a11y"] + }, + { + "title": "### 🔧 Fixes", + "labels": ["Fix", "fix"] + }, + { + "title": "### ⚙️ Other Changes", + "labels": ["ci", "style", "docs", "refactor", "chore"] + } + ], + "ignore_labels": [ + "🔁 duplicate", + "📊 analytics", + "🌱 good first issue", + "🔍 investigation", + "🙏 help wanted", + "❌ invalid", + "❓ question", + "🚫 wontfix", + "🚀 release", + "version", + "action" + ], + "base_branches": ["main"], + "sort": { + "order": "ASC", + "on_property": "mergedAt" + }, + "label_extractor": [ + { + "pattern": "^(?:[^A-Za-z0-9]*)(feat|fix|chore|docs|refactor|ci|style|a11y|i18n)\\s*:", + "target": "$1", + "flags": "i", + "on_property": "title", + "method": "match" + }, + 
{ + "pattern": "^(?:[^A-Za-z0-9]*)(v\\d+\\.\\d+\\.\\d+(?:-rc\\d+)?).*", + "target": "version", + "flags": "i", + "on_property": "title", + "method": "match" + }, + { + "pattern": "^(?:[^A-Za-z0-9]*)(action)\\b.*", + "target": "action", + "flags": "i", + "on_property": "title", + "method": "match" + } + ], + "template": "## [Unreleased]\n\n#{{CHANGELOG}}\n\n---", + "pr_template": "- #{{TITLE}} by **@#{{AUTHOR}}** in [##{{NUMBER}}](#{{URL}})", + "empty_template": "- no changes" +} \ No newline at end of file diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index 5bc3d3b2db..b7bccecae8 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -39,6 +39,9 @@ jobs: - name: Install MCP Package run: npm run build:mcp + - name: Install Data Schemas Package + run: npm run build:data-schemas + - name: Create empty auth.json file run: | mkdir -p api/data @@ -61,4 +64,7 @@ jobs: run: cd api && npm run test:ci - name: Run librechat-data-provider unit tests - run: cd packages/data-provider && npm run test:ci \ No newline at end of file + run: cd packages/data-provider && npm run test:ci + + - name: Run librechat-mcp unit tests + run: cd packages/mcp && npm run test:ci \ No newline at end of file diff --git a/.github/workflows/data-schemas.yml b/.github/workflows/data-schemas.yml new file mode 100644 index 0000000000..fee72fbe02 --- /dev/null +++ b/.github/workflows/data-schemas.yml @@ -0,0 +1,58 @@ +name: Publish `@librechat/data-schemas` to NPM + +on: + push: + branches: + - main + paths: + - 'packages/data-schemas/package.json' + workflow_dispatch: + inputs: + reason: + description: 'Reason for manual trigger' + required: false + default: 'Manual publish requested' + +jobs: + build-and-publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Use Node.js + uses: actions/setup-node@v4 + with: + node-version: '18.x' + + - name: Install dependencies + run: cd packages/data-schemas && npm ci + + - name: Build + run: cd packages/data-schemas && npm run build + + - name: Set up npm authentication + run: echo "//registry.npmjs.org/:_authToken=${{ secrets.PUBLISH_NPM_TOKEN }}" > ~/.npmrc + + - name: Check version change + id: check + working-directory: packages/data-schemas + run: | + PACKAGE_VERSION=$(node -p "require('./package.json').version") + PUBLISHED_VERSION=$(npm view @librechat/data-schemas version 2>/dev/null || echo "0.0.0") + if [ "$PACKAGE_VERSION" = "$PUBLISHED_VERSION" ]; then + echo "No version change, skipping publish" + echo "skip=true" >> $GITHUB_OUTPUT + else + echo "Version changed, proceeding with publish" + echo "skip=false" >> $GITHUB_OUTPUT + fi + + - name: Pack package + if: steps.check.outputs.skip != 'true' + working-directory: packages/data-schemas + run: npm pack + + - name: Publish + if: steps.check.outputs.skip != 'true' + working-directory: packages/data-schemas + run: npm publish *.tgz --access public \ No newline at end of file diff --git a/.github/workflows/generate-release-changelog-pr.yml b/.github/workflows/generate-release-changelog-pr.yml new file mode 100644 index 0000000000..405f0ca6dc --- /dev/null +++ b/.github/workflows/generate-release-changelog-pr.yml @@ -0,0 +1,95 @@ +name: Generate Release Changelog PR + +on: + push: + tags: + - 'v*.*.*' + workflow_dispatch: + +jobs: + generate-release-changelog-pr: + permissions: + contents: write # Needed for pushing commits and creating branches. + pull-requests: write + runs-on: ubuntu-latest + steps: + # 1. 
Checkout the repository (with full history). + - name: Checkout Repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + # 2. Generate the release changelog using our custom configuration. + - name: Generate Release Changelog + id: generate_release + uses: mikepenz/release-changelog-builder-action@v5.1.0 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + configuration: ".github/configuration-release.json" + owner: ${{ github.repository_owner }} + repo: ${{ github.event.repository.name }} + outputFile: CHANGELOG-release.md + + # 3. Update the main CHANGELOG.md: + # - If it doesn't exist, create it with a basic header. + # - Remove the "Unreleased" section (if present). + # - Prepend the new release changelog above previous releases. + # - Remove all temporary files before committing. + - name: Update CHANGELOG.md + run: | + # Determine the release tag, e.g. "v1.2.3" + TAG=${GITHUB_REF##*/} + echo "Using release tag: $TAG" + + # Ensure CHANGELOG.md exists; if not, create a basic header. + if [ ! -f CHANGELOG.md ]; then + echo "# Changelog" > CHANGELOG.md + echo "" >> CHANGELOG.md + echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md + echo "" >> CHANGELOG.md + fi + + echo "Updating CHANGELOG.md…" + + # Remove the "Unreleased" section (from "## [Unreleased]" until the first occurrence of '---') if it exists. + if grep -q "^## \[Unreleased\]" CHANGELOG.md; then + awk '/^## \[Unreleased\]/{flag=1} flag && /^---/{flag=0; next} !flag' CHANGELOG.md > CHANGELOG.cleaned + else + cp CHANGELOG.md CHANGELOG.cleaned + fi + + # Split the cleaned file into: + # - header.md: content before the first release header ("## [v..."). + # - tail.md: content from the first release header onward. + awk '/^## \[v/{exit} {print}' CHANGELOG.cleaned > header.md + awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.cleaned > tail.md + + # Combine header, the new release changelog, and the tail. + echo "Combining updated changelog parts..." + cat header.md CHANGELOG-release.md > CHANGELOG.md.new + echo "" >> CHANGELOG.md.new + cat tail.md >> CHANGELOG.md.new + + mv CHANGELOG.md.new CHANGELOG.md + + # Remove temporary files. + rm -f CHANGELOG.cleaned header.md tail.md CHANGELOG-release.md + + echo "Final CHANGELOG.md content:" + cat CHANGELOG.md + + # 4. Create (or update) the Pull Request with the updated CHANGELOG.md. + - name: Create Pull Request + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.GITHUB_TOKEN }} + sign-commits: true + commit-message: "chore: update CHANGELOG for release ${{ github.ref_name }}" + base: main + branch: "changelog/${{ github.ref_name }}" + reviewers: danny-avila + title: "📜 docs: Changelog for release ${{ github.ref_name }}" + body: | + **Description**: + - This PR updates the CHANGELOG.md by removing the "Unreleased" section and adding new release notes for release ${{ github.ref_name }} above previous releases. diff --git a/.github/workflows/generate-unreleased-changelog-pr.yml b/.github/workflows/generate-unreleased-changelog-pr.yml new file mode 100644 index 0000000000..133e19f1e2 --- /dev/null +++ b/.github/workflows/generate-unreleased-changelog-pr.yml @@ -0,0 +1,107 @@ +name: Generate Unreleased Changelog PR + +on: + schedule: + - cron: "0 0 * * 1" # Runs every Monday at 00:00 UTC + workflow_dispatch: + +jobs: + generate-unreleased-changelog-pr: + permissions: + contents: write # Needed for pushing commits and creating branches. + pull-requests: write + runs-on: ubuntu-latest + steps: + # 1. 
Checkout the repository on main. + - name: Checkout Repository on Main + uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + + # 4. Get the latest version tag. + - name: Get Latest Tag + id: get_latest_tag + run: | + LATEST_TAG=$(git describe --tags $(git rev-list --tags --max-count=1) || echo "none") + echo "Latest tag: $LATEST_TAG" + echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT + + # 5. Generate the Unreleased changelog. + - name: Generate Unreleased Changelog + id: generate_unreleased + uses: mikepenz/release-changelog-builder-action@v5.1.0 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + configuration: ".github/configuration-unreleased.json" + owner: ${{ github.repository_owner }} + repo: ${{ github.event.repository.name }} + outputFile: CHANGELOG-unreleased.md + fromTag: ${{ steps.get_latest_tag.outputs.tag }} + toTag: main + + # 7. Update CHANGELOG.md with the new Unreleased section. + - name: Update CHANGELOG.md + id: update_changelog + run: | + # Create CHANGELOG.md if it doesn't exist. + if [ ! -f CHANGELOG.md ]; then + echo "# Changelog" > CHANGELOG.md + echo "" >> CHANGELOG.md + echo "All notable changes to this project will be documented in this file." >> CHANGELOG.md + echo "" >> CHANGELOG.md + fi + + echo "Updating CHANGELOG.md…" + + # Extract content before the "## [Unreleased]" (or first version header if missing). + if grep -q "^## \[Unreleased\]" CHANGELOG.md; then + awk '/^## \[Unreleased\]/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md + else + awk '/^## \[v/{exit} {print}' CHANGELOG.md > CHANGELOG_TMP.md + fi + + # Append the generated Unreleased changelog. + echo "" >> CHANGELOG_TMP.md + cat CHANGELOG-unreleased.md >> CHANGELOG_TMP.md + echo "" >> CHANGELOG_TMP.md + + # Append the remainder of the original changelog (starting from the first version header). + awk 'f{print} /^## \[v/{f=1; print}' CHANGELOG.md >> CHANGELOG_TMP.md + + # Replace the old file with the updated file. + mv CHANGELOG_TMP.md CHANGELOG.md + + # Remove the temporary generated file. + rm -f CHANGELOG-unreleased.md + + echo "Final CHANGELOG.md:" + cat CHANGELOG.md + + # 8. Check if CHANGELOG.md has any updates. + - name: Check for CHANGELOG.md changes + id: changelog_changes + run: | + if git diff --quiet CHANGELOG.md; then + echo "has_changes=false" >> $GITHUB_OUTPUT + else + echo "has_changes=true" >> $GITHUB_OUTPUT + fi + + # 9. Create (or update) the Pull Request only if there are changes. + - name: Create Pull Request + if: steps.changelog_changes.outputs.has_changes == 'true' + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.GITHUB_TOKEN }} + base: main + branch: "changelog/unreleased-update" + sign-commits: true + commit-message: "action: update Unreleased changelog" + title: "📜 docs: Unreleased Changelog" + body: | + **Description**: + - This PR updates the Unreleased section in CHANGELOG.md. + - It compares the current main branch with the latest version tag (determined as ${{ steps.get_latest_tag.outputs.tag }}), + regenerates the Unreleased changelog, removes any old Unreleased block, and inserts the new content. 
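The two workflows above share the same splice: cut any existing "Unreleased" block out of CHANGELOG.md and drop the freshly generated section in above the first release header. A minimal JavaScript sketch of that awk/cat pipeline, purely for clarity (the function name and file paths are illustrative, not part of the workflows):

```js
const fs = require('fs');

// Rebuild the changelog: drop the old "## [Unreleased]" block (everything up
// to the first release header) and insert the newly generated section there.
function spliceChangelog(changelog, generatedSection) {
  const lines = changelog.split('\n');
  const unreleased = lines.findIndex((l) => l.startsWith('## [Unreleased]'));
  const firstRelease = lines.findIndex((l) => l.startsWith('## [v'));
  const headEnd =
    unreleased !== -1 ? unreleased : firstRelease !== -1 ? firstRelease : lines.length;
  const head = lines.slice(0, headEnd);
  const tail = firstRelease !== -1 ? lines.slice(firstRelease) : [];
  return [...head, generatedSection.trim(), '', ...tail].join('\n');
}

const updated = spliceChangelog(
  fs.readFileSync('CHANGELOG.md', 'utf8'),
  fs.readFileSync('CHANGELOG-unreleased.md', 'utf8'),
);
fs.writeFileSync('CHANGELOG.md', updated);
```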
diff --git a/.github/workflows/i18n-unused-keys.yml b/.github/workflows/i18n-unused-keys.yml
index 79f95d3b27..f720a61783 100644
--- a/.github/workflows/i18n-unused-keys.yml
+++ b/.github/workflows/i18n-unused-keys.yml
@@ -4,6 +4,7 @@ on:
   pull_request:
     paths:
       - "client/src/**"
+      - "api/**"
 
 jobs:
   detect-unused-i18n-keys:
@@ -21,7 +22,7 @@
           # Define paths
           I18N_FILE="client/src/locales/en/translation.json"
-          SOURCE_DIR="client/src"
+          SOURCE_DIRS=("client/src" "api")
 
           # Check if translation file exists
           if [[ ! -f "$I18N_FILE" ]]; then
@@ -37,7 +38,38 @@
           # Check if each key is used in the source code
           for KEY in $KEYS; do
-            if ! grep -r --include=\*.{js,jsx,ts,tsx} -q "$KEY" "$SOURCE_DIR"; then
+            FOUND=false
+
+            # Special case for dynamically constructed special variable keys
+            if [[ "$KEY" == com_ui_special_var_* ]]; then
+              # Check if TSpecialVarLabel is used in the codebase
+              for DIR in "${SOURCE_DIRS[@]}"; do
+                if grep -r --include=\*.{js,jsx,ts,tsx} -q "TSpecialVarLabel" "$DIR"; then
+                  FOUND=true
+                  break
+                fi
+              done
+
+              # Also check if the key is directly used somewhere
+              if [[ "$FOUND" == false ]]; then
+                for DIR in "${SOURCE_DIRS[@]}"; do
+                  if grep -r --include=\*.{js,jsx,ts,tsx} -q "$KEY" "$DIR"; then
+                    FOUND=true
+                    break
+                  fi
+                done
+              fi
+            else
+              # Regular check for other keys
+              for DIR in "${SOURCE_DIRS[@]}"; do
+                if grep -r --include=\*.{js,jsx,ts,tsx} -q "$KEY" "$DIR"; then
+                  FOUND=true
+                  break
+                fi
+              done
+            fi
+
+            if [[ "$FOUND" == false ]]; then
               UNUSED_KEYS+=("$KEY")
             fi
           done
@@ -59,8 +91,8 @@
         run: |
           PR_NUMBER=$(jq --raw-output .pull_request.number "$GITHUB_EVENT_PATH")
 
-          # Format the unused keys list correctly, filtering out empty entries
-          FILTERED_KEYS=$(echo "$unused_keys" | jq -r '.[]' | grep -v '^\s*$' | sed 's/^/- `/;s/$/`/' )
+          # Format the unused keys list as checkboxes for easy manual checking.
+          FILTERED_KEYS=$(echo "$unused_keys" | jq -r '.[]' | grep -v '^\s*$' | sed 's/^/- [ ] `/;s/$/`/' )
 
           COMMENT_BODY=$(cat <
- Locize Logo
+ Locize Logo
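The unused-key scan added to the workflow above loops `grep` over every key and every source directory. A rough JavaScript equivalent of the core scan, minus the `com_ui_special_var_*` special case (paths mirror the workflow; everything else is an illustrative sketch):

```js
const fs = require('fs');
const path = require('path');

// Recursively collect source files with the same extensions the workflow greps.
function collectSources(dir, exts = new Set(['.js', '.jsx', '.ts', '.tsx'])) {
  return fs.readdirSync(dir, { withFileTypes: true }).flatMap((entry) => {
    const full = path.join(dir, entry.name);
    if (entry.isDirectory()) {
      return collectSources(full, exts);
    }
    return exts.has(path.extname(entry.name)) ? [full] : [];
  });
}

const keys = Object.keys(
  JSON.parse(fs.readFileSync('client/src/locales/en/translation.json', 'utf8')),
);

// Concatenate all sources once, then check each key for a literal occurrence,
// mirroring `grep -r --include=\*.{js,jsx,ts,tsx} -q "$KEY"`.
const haystack = ['client/src', 'api']
  .flatMap((dir) => collectSources(dir))
  .map((file) => fs.readFileSync(file, 'utf8'))
  .join('\n');

const unusedKeys = keys.filter((key) => !haystack.includes(key));
console.log(unusedKeys);
```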

diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js
index 522b6beb4f..91939975c4 100644
--- a/api/app/clients/AnthropicClient.js
+++ b/api/app/clients/AnthropicClient.js
@@ -2,12 +2,14 @@ const Anthropic = require('@anthropic-ai/sdk');
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const {
   Constants,
+  ErrorTypes,
   EModelEndpoint,
+  parseTextParts,
   anthropicSettings,
   getResponseSender,
   validateVisionModel,
 } = require('librechat-data-provider');
-const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const { SplitStreamHandler: _Handler } = require('@librechat/agents');
 const {
   truncateText,
   formatMessage,
@@ -16,8 +18,15 @@ const {
   parseParamFromPrompt,
   createContextHandlers,
 } = require('./prompts');
+const {
+  getClaudeHeaders,
+  configureReasoning,
+  checkPromptCacheSupport,
+} = require('~/server/services/Endpoints/anthropic/helpers');
 const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
 const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
+const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const { createFetch, createStreamEventHandlers } = require('./generators');
 const Tokenizer = require('~/server/services/Tokenizer');
 const { sleep } = require('~/server/utils');
 const BaseClient = require('./BaseClient');
@@ -26,6 +35,15 @@ const { logger } = require('~/config');
 const HUMAN_PROMPT = '\n\nHuman:';
 const AI_PROMPT = '\n\nAssistant:';
 
+class SplitStreamHandler extends _Handler {
+  getDeltaContent(chunk) {
+    return (chunk?.delta?.text ?? chunk?.completion) || '';
+  }
+  getReasoningDelta(chunk) {
+    return chunk?.delta?.thinking || '';
+  }
+}
+
 /** Helper function to introduce a delay before retrying */
 function delayBeforeRetry(attempts, baseDelay = 1000) {
   return new Promise((resolve) => setTimeout(resolve, baseDelay * attempts));
@@ -68,6 +86,8 @@ class AnthropicClient extends BaseClient {
     /** The key for the usage object's output tokens
      * @type {string} */
     this.outputTokensKey = 'output_tokens';
+    /** @type {SplitStreamHandler | undefined} */
+    this.streamHandler;
   }
 
   setOptions(options) {
@@ -97,9 +117,10 @@ class AnthropicClient extends BaseClient {
     const modelMatch = matchModelName(this.modelOptions.model, EModelEndpoint.anthropic);
     this.isClaude3 = modelMatch.includes('claude-3');
-    this.isLegacyOutput = !modelMatch.includes('claude-3-5-sonnet');
-    this.supportsCacheControl =
-      this.options.promptCache && this.checkPromptCacheSupport(modelMatch);
+    this.isLegacyOutput = !(
+      /claude-3[-.]5-sonnet/.test(modelMatch) || /claude-3[-.]7/.test(modelMatch)
+    );
+    this.supportsCacheControl = this.options.promptCache && checkPromptCacheSupport(modelMatch);
 
     if (
       this.isLegacyOutput &&
@@ -125,16 +146,21 @@ class AnthropicClient extends BaseClient {
         this.options.endpointType ?? this.options.endpoint,
         this.options.endpointTokenConfig,
       ) ??
-      1500;
+      anthropicSettings.maxOutputTokens.reset(this.modelOptions.model);
 
     this.maxPromptTokens =
       this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
 
-    if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
-      throw new Error(
-        `maxPromptTokens + maxOutputTokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
-          this.maxPromptTokens + this.maxResponseTokens
-        }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
-      );
+    const reservedTokens = this.maxPromptTokens + this.maxResponseTokens;
+    if (reservedTokens > this.maxContextTokens) {
+      const info = `Total Possible Tokens + Max Output Tokens must be less than or equal to Max Context Tokens: ${this.maxPromptTokens} (total possible output) + ${this.maxResponseTokens} (max output) = ${reservedTokens}/${this.maxContextTokens} (max context)`;
+      const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
+      logger.warn(info);
+      throw new Error(errorMessage);
+    } else if (this.maxResponseTokens === this.maxContextTokens) {
+      const info = `Max Output Tokens must be less than Max Context Tokens: ${this.maxResponseTokens} (max output) = ${this.maxContextTokens} (max context)`;
+      const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`;
+      logger.warn(info);
+      throw new Error(errorMessage);
     }
 
     this.sender =
@@ -159,7 +185,10 @@ class AnthropicClient extends BaseClient {
   getClient(requestOptions) {
     /** @type {Anthropic.ClientOptions} */
     const options = {
-      fetch: this.fetch,
+      fetch: createFetch({
+        directEndpoint: this.options.directEndpoint,
+        reverseProxyUrl: this.options.reverseProxyUrl,
+      }),
       apiKey: this.apiKey,
     };
 
@@ -171,18 +200,9 @@
       options.baseURL = this.options.reverseProxyUrl;
     }
 
-    if (
-      this.supportsCacheControl &&
-      requestOptions?.model &&
-      requestOptions.model.includes('claude-3-5-sonnet')
-    ) {
-      options.defaultHeaders = {
-        'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
-      };
-    } else if (this.supportsCacheControl) {
-      options.defaultHeaders = {
-        'anthropic-beta': 'prompt-caching-2024-07-31',
-      };
+    const headers = getClaudeHeaders(requestOptions?.model, this.supportsCacheControl);
+    if (headers) {
+      options.defaultHeaders = headers;
     }
 
     return new Anthropic(options);
@@ -376,13 +396,13 @@
     const formattedMessages = orderedMessages.map((message, i) => {
       const formattedMessage = this.useMessages
         ? formatMessage({
-          message,
-          endpoint: EModelEndpoint.anthropic,
-        })
+            message,
+            endpoint: EModelEndpoint.anthropic,
+          })
         : {
-          author: message.isCreatedByUser ? this.userLabel : this.assistantLabel,
-          content: message?.content ?? message.text,
-        };
+            author: message.isCreatedByUser ? this.userLabel : this.assistantLabel,
+            content: message?.content ?? message.text,
+          };
 
       const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount;
 
       /* If tokens were never counted, or, is a Vision request and the message has files, count again */
@@ -398,6 +418,9 @@
             this.contextHandlers?.processFile(file);
             continue;
           }
+          if (file.metadata?.fileIdentifier) {
+            continue;
+          }
 
           orderedMessages[i].tokenCount += this.calculateImageTokenCost({
             width: file.width,
@@ -657,7 +680,7 @@
   }
 
   getCompletion() {
-    logger.debug('AnthropicClient doesn\'t use getCompletion (all handled in sendCompletion)');
+    logger.debug("AnthropicClient doesn't use getCompletion (all handled in sendCompletion)");
   }
 
   /**
@@ -668,29 +691,41 @@
    * @returns {Promise} The response from the Anthropic client.
    */
   async createResponse(client, options, useMessages) {
-    return useMessages ?? this.useMessages
+    return (useMessages ?? this.useMessages)
      ? await client.messages.create(options)
      : await client.completions.create(options);
   }
 
+  getMessageMapMethod() {
+    /**
+     * @param {TMessage} msg
+     */
+    return (msg) => {
+      if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) {
+        msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
+      } else if (msg.content != null) {
+        msg.text = parseTextParts(msg.content, true);
+        delete msg.content;
+      }
+
+      return msg;
+    };
+  }
+
   /**
-   * @param {string} modelName
-   * @returns {boolean}
+   * @param {string[]} [intermediateReply]
+   * @returns {string}
    */
-  checkPromptCacheSupport(modelName) {
-    const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
-    if (modelMatch.includes('claude-3-5-sonnet-latest')) {
-      return false;
+  getStreamText(intermediateReply) {
+    if (!this.streamHandler) {
+      return intermediateReply?.join('') ?? '';
     }
-    if (
-      modelMatch === 'claude-3-5-sonnet' ||
-      modelMatch === 'claude-3-5-haiku' ||
-      modelMatch === 'claude-3-haiku' ||
-      modelMatch === 'claude-3-opus'
-    ) {
-      return true;
-    }
-    return false;
+
+    const reasoningText = this.streamHandler.reasoningTokens.join('');
+
+    const reasoningBlock = reasoningText.length > 0 ? `:::thinking\n${reasoningText}\n:::\n` : '';
+
+    return `${reasoningBlock}${this.streamHandler.tokens.join('')}`;
   }
 
   async sendCompletion(payload, { onProgress, abortController }) {
@@ -710,7 +745,6 @@
       user_id: this.user,
     };
 
-    let text = '';
     const {
       stream,
      model,
@@ -721,22 +755,34 @@
       topK: top_k,
     } = this.modelOptions;
 
-    const requestOptions = {
+    let requestOptions = {
       model,
       stream: stream || true,
       stop_sequences,
       temperature,
       metadata,
-      top_p,
-      top_k,
     };
 
     if (this.useMessages) {
       requestOptions.messages = payload;
-      requestOptions.max_tokens = maxOutputTokens || legacy.maxOutputTokens.default;
+      requestOptions.max_tokens =
+        maxOutputTokens || anthropicSettings.maxOutputTokens.reset(requestOptions.model);
     } else {
       requestOptions.prompt = payload;
-      requestOptions.max_tokens_to_sample = maxOutputTokens || 1500;
+      requestOptions.max_tokens_to_sample = maxOutputTokens || legacy.maxOutputTokens.default;
    }
 
+    requestOptions = configureReasoning(requestOptions, {
+      thinking: this.options.thinking,
+      thinkingBudget: this.options.thinkingBudget,
+    });
+
+    if (!/claude-3[-.]7/.test(model)) {
+      requestOptions.top_p = top_p;
+      requestOptions.top_k = top_k;
+    } else if (requestOptions.thinking == null) {
+      requestOptions.topP = top_p;
+      requestOptions.topK = top_k;
+    }
 
     if (this.systemMessage && this.supportsCacheControl === true) {
@@ -756,13 +802,14 @@
     }
 
     logger.debug('[AnthropicClient]', { ...requestOptions });
+    const handlers = createStreamEventHandlers(this.options.res);
+    this.streamHandler = new SplitStreamHandler({
+      accumulate: true,
+      runId: this.responseMessageId,
+      handlers,
+    });
 
-    const handleChunk = (currentChunk) => {
-      if (currentChunk) {
-        text += currentChunk;
-        onProgress(currentChunk);
-      }
-    };
+    let intermediateReply = this.streamHandler.tokens;
 
     const maxRetries = 3;
     const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE;
@@ -783,22 +830,15 @@
        });
 
        for await (const completion of response) {
-          // Handle each completion as before
          const type = completion?.type ?? '';
          if (tokenEventTypes.has(type)) {
            logger.debug(`[AnthropicClient] ${type}`, completion);
            this[type] = completion;
          }
 
-          if (completion?.delta?.text) {
-            handleChunk(completion.delta.text);
-          } else if (completion.completion) {
-            handleChunk(completion.completion);
-          }
-
+          this.streamHandler.handle(completion);
          await sleep(streamRate);
        }
 
-        // Successful processing, exit loop
        break;
      } catch (error) {
        attempts += 1;
@@ -808,6 +848,10 @@
 
        if (attempts < maxRetries) {
          await delayBeforeRetry(attempts, 350);
+        } else if (this.streamHandler && this.streamHandler.reasoningTokens.length) {
+          return this.getStreamText();
+        } else if (intermediateReply.length > 0) {
+          return this.getStreamText(intermediateReply);
        } else {
          throw new Error(`Operation failed after ${maxRetries} attempts: ${error.message}`);
        }
@@ -823,8 +867,7 @@
     }
 
     await processResponse.bind(this)();
-
-    return text.trim();
+    return this.getStreamText(intermediateReply);
   }
 
   getSaveOptions() {
@@ -834,6 +877,8 @@
       promptPrefix: this.options.promptPrefix,
       modelLabel: this.options.modelLabel,
       promptCache: this.options.promptCache,
+      thinking: this.options.thinking,
+      thinkingBudget: this.options.thinkingBudget,
       resendFiles: this.options.resendFiles,
       iconURL: this.options.iconURL,
       greeting: this.options.greeting,
@@ -843,7 +888,7 @@
   }
 
   getBuildMessagesOptions() {
-    logger.debug('AnthropicClient doesn\'t use getBuildMessagesOptions');
+    logger.debug("AnthropicClient doesn't use getBuildMessagesOptions");
   }
 
   getEncoding() {
diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js
index ebf3ca12d9..55b8780180 100644
--- a/api/app/clients/BaseClient.js
+++ b/api/app/clients/BaseClient.js
@@ -5,13 +5,15 @@ const {
   isAgentsEndpoint,
   isParamEndpoint,
   EModelEndpoint,
+  ContentTypes,
+  excludedKeys,
   ErrorTypes,
   Constants,
 } = require('librechat-data-provider');
-const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
-const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
+const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models');
+const { checkBalance } = require('~/models/balanceMethods');
 const { truncateToolCallOutputs } = require('./prompts');
-const checkBalance = require('~/models/checkBalance');
+const { addSpaceIfNeeded } = require('~/server/utils');
 const { getFiles } = require('~/models/File');
 const TextStream = require('./TextStream');
 const { logger } = require('~/config');
@@ -26,15 +28,10 @@ class BaseClient {
       month: 'long',
       day: 'numeric',
     });
-    this.fetch = this.fetch.bind(this);
     /** @type {boolean} */
     this.skipSaveConvo = false;
     /** @type {boolean} */
     this.skipSaveUserMessage = false;
-    /** @type {ClientDatabaseSavePromise} */
-    this.userMessagePromise;
-    /** @type {ClientDatabaseSavePromise} */
-    this.responsePromise;
     /** @type {string} */
     this.user;
     /** @type {string} */
@@ -55,6 +52,10 @@ class BaseClient {
      * Flag to determine if the client re-submitted the latest assistant message.
      * @type {boolean | undefined} */
     this.continued;
+    /**
+     * Flag to determine if the client has already fetched the conversation while saving new messages.
+ * @type {boolean | undefined} */ + this.fetchedConvo; /** @type {TMessage[]} */ this.currentMessages = []; /** @type {import('librechat-data-provider').VisionModes | undefined} */ @@ -62,15 +63,15 @@ class BaseClient { } setOptions() { - throw new Error('Method \'setOptions\' must be implemented.'); + throw new Error("Method 'setOptions' must be implemented."); } async getCompletion() { - throw new Error('Method \'getCompletion\' must be implemented.'); + throw new Error("Method 'getCompletion' must be implemented."); } async sendCompletion() { - throw new Error('Method \'sendCompletion\' must be implemented.'); + throw new Error("Method 'sendCompletion' must be implemented."); } getSaveOptions() { @@ -236,11 +237,11 @@ class BaseClient { const userMessage = opts.isEdited ? this.currentMessages[this.currentMessages.length - 2] : this.createUserMessage({ - messageId: userMessageId, - parentMessageId, - conversationId, - text: message, - }); + messageId: userMessageId, + parentMessageId, + conversationId, + text: message, + }); if (typeof opts?.getReqData === 'function') { opts.getReqData({ @@ -360,17 +361,14 @@ class BaseClient { * context: TMessage[], * remainingContextTokens: number, * messagesToRefine: TMessage[], - * summaryIndex: number, - * }>} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`. + * }>} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. * `context` is an array of messages that fit within the token limit. - * `summaryIndex` is the index of the first message in the `messagesToRefine` array. * `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. * `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit. */ async getMessagesWithinTokenLimit({ messages: _messages, maxContextTokens, instructions }) { // Every reply is primed with <|start|>assistant<|message|>, so we // start with 3 tokens for the label after all messages have been counted. - let summaryIndex = -1; let currentTokenCount = 3; const instructionsTokenCount = instructions?.tokenCount ?? 0; let remainingContextTokens = @@ -403,14 +401,12 @@ class BaseClient { } const prunedMemory = messages; - summaryIndex = prunedMemory.length - 1; remainingContextTokens -= currentTokenCount; return { context: context.reverse(), remainingContextTokens, messagesToRefine: prunedMemory, - summaryIndex, }; } @@ -453,7 +449,7 @@ class BaseClient { let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); - let { context, remainingContextTokens, messagesToRefine, summaryIndex } = + let { context, remainingContextTokens, messagesToRefine } = await this.getMessagesWithinTokenLimit({ messages: orderedWithInstructions, instructions, @@ -523,7 +519,7 @@ class BaseClient { } // Make sure to only continue summarization logic if the summary message was generated - shouldSummarize = summaryMessage && shouldSummarize; + shouldSummarize = summaryMessage != null && shouldSummarize === true; logger.debug('[BaseClient] Context Count (2/2)', { remainingContextTokens, @@ -533,17 +529,18 @@ class BaseClient { /** @type {Record | undefined} */ let tokenCountMap; if (buildTokenMap) { - tokenCountMap = orderedWithInstructions.reduce((map, message, index) => { + const currentPayload = shouldSummarize ? 
orderedWithInstructions : context; + tokenCountMap = currentPayload.reduce((map, message, index) => { const { messageId } = message; if (!messageId) { return map; } - if (shouldSummarize && index === summaryIndex && !usePrevSummary) { + if (shouldSummarize && index === messagesToRefine.length - 1 && !usePrevSummary) { map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount }; } - map[messageId] = orderedWithInstructions[index].tokenCount; + map[messageId] = currentPayload[index].tokenCount; return map; }, {}); } @@ -562,6 +559,8 @@ class BaseClient { } async sendMessage(message, opts = {}) { + /** @type {Promise} */ + let userMessagePromise; const { user, head, isEdited, conversationId, responseMessageId, saveOptions, userMessage } = await this.handleStartMethods(message, opts); @@ -623,17 +622,18 @@ class BaseClient { } if (!isEdited && !this.skipSaveUserMessage) { - this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); + userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); this.savedMessageIds.add(userMessage.messageId); if (typeof opts?.getReqData === 'function') { opts.getReqData({ - userMessagePromise: this.userMessagePromise, + userMessagePromise, }); } } + const balance = this.options.req?.app?.locals?.balance; if ( - isEnabled(process.env.CHECK_BALANCE) && + balance?.enabled && supportsBalanceCheck[this.options.endpointType ?? this.options.endpoint] ) { await checkBalance({ @@ -652,7 +652,9 @@ class BaseClient { /** @type {string|string[]|undefined} */ const completion = await this.sendCompletion(payload, opts); - this.abortController.requestCompleted = true; + if (this.abortController) { + this.abortController.requestCompleted = true; + } /** @type {TMessage} */ const responseMessage = { @@ -673,7 +675,8 @@ class BaseClient { responseMessage.text = addSpaceIfNeeded(generation) + completion; } else if ( Array.isArray(completion) && - isParamEndpoint(this.options.endpoint, this.options.endpointType) + (this.clientName === EModelEndpoint.agents || + isParamEndpoint(this.options.endpoint, this.options.endpointType)) ) { responseMessage.text = ''; responseMessage.content = completion; @@ -699,7 +702,13 @@ class BaseClient { if (usage != null && Number(usage[this.outputTokensKey]) > 0) { responseMessage.tokenCount = usage[this.outputTokensKey]; completionTokens = responseMessage.tokenCount; - await this.updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }); + await this.updateUserMessageTokenCount({ + usage, + tokenCountMap, + userMessage, + userMessagePromise, + opts, + }); } else { responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); completionTokens = responseMessage.tokenCount; @@ -708,8 +717,8 @@ class BaseClient { await this.recordTokenUsage({ promptTokens, completionTokens, usage }); } - if (this.userMessagePromise) { - await this.userMessagePromise; + if (userMessagePromise) { + await userMessagePromise; } if (this.artifactPromises) { @@ -724,7 +733,11 @@ class BaseClient { } } - this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); + responseMessage.databasePromise = this.saveMessageToDatabase( + responseMessage, + saveOptions, + user, + ); this.savedMessageIds.add(responseMessage.messageId); delete responseMessage.tokenCount; return responseMessage; @@ -745,9 +758,16 @@ class BaseClient { * @param {StreamUsage} params.usage * @param {Record} params.tokenCountMap * @param {TMessage} params.userMessage + * @param 
{Promise} params.userMessagePromise * @param {object} params.opts */ - async updateUserMessageTokenCount({ usage, tokenCountMap, userMessage, opts }) { + async updateUserMessageTokenCount({ + usage, + tokenCountMap, + userMessage, + userMessagePromise, + opts, + }) { /** @type {boolean} */ const shouldUpdateCount = this.calculateCurrentTokenCount != null && @@ -783,7 +803,7 @@ class BaseClient { Note: we update the user message to be sure it gets the calculated token count; though `AskController` saves the user message, EditController does not */ - await this.userMessagePromise; + await userMessagePromise; await this.updateMessageInDatabase({ messageId: userMessage.messageId, tokenCount: userMessageTokenCount, @@ -849,7 +869,7 @@ class BaseClient { } const savedMessage = await saveMessage( - this.options.req, + this.options?.req, { ...message, endpoint: this.options.endpoint, @@ -863,16 +883,40 @@ class BaseClient { return { message: savedMessage }; } - const conversation = await saveConvo( - this.options.req, - { - conversationId: message.conversationId, - endpoint: this.options.endpoint, - endpointType: this.options.endpointType, - ...endpointOptions, - }, - { context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo' }, - ); + const fieldsToKeep = { + conversationId: message.conversationId, + endpoint: this.options.endpoint, + endpointType: this.options.endpointType, + ...endpointOptions, + }; + + const existingConvo = + this.fetchedConvo === true + ? null + : await getConvo(this.options?.req?.user?.id, message.conversationId); + + const unsetFields = {}; + const exceptions = new Set(['spec', 'iconURL']); + if (existingConvo != null) { + this.fetchedConvo = true; + for (const key in existingConvo) { + if (!key) { + continue; + } + if (excludedKeys.has(key) && !exceptions.has(key)) { + continue; + } + + if (endpointOptions?.[key] === undefined) { + unsetFields[key] = 1; + } + } + } + + const conversation = await saveConvo(this.options?.req, fieldsToKeep, { + context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo', + unsetFields, + }); return { message: savedMessage, conversation }; } @@ -993,11 +1037,17 @@ class BaseClient { const processValue = (value) => { if (Array.isArray(value)) { for (let item of value) { - if (!item || !item.type || item.type === 'image_url') { + if ( + !item || + !item.type || + item.type === ContentTypes.THINK || + item.type === ContentTypes.ERROR || + item.type === ContentTypes.IMAGE_URL + ) { continue; } - if (item.type === 'tool_call' && item.tool_call != null) { + if (item.type === ContentTypes.TOOL_CALL && item.tool_call != null) { const toolName = item.tool_call?.name || ''; if (toolName != null && toolName && typeof toolName === 'string') { numTokens += this.getTokenCount(toolName); @@ -1093,9 +1143,13 @@ class BaseClient { return message; } - const files = await getFiles({ - file_id: { $in: fileIds }, - }); + const files = await getFiles( + { + file_id: { $in: fileIds }, + }, + {}, + {}, + ); await this.addImageURLs(message, files, this.visionMode); diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js index 5450300a17..07b2fa97bb 100644 --- a/api/app/clients/ChatGPTClient.js +++ b/api/app/clients/ChatGPTClient.js @@ -1,4 +1,4 @@ -const Keyv = require('keyv'); +const { Keyv } = require('keyv'); const crypto = require('crypto'); const { CohereClient } = require('cohere-ai'); const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); @@ -339,7 +339,7 @@ class ChatGPTClient 
extends BaseClient { opts.body = JSON.stringify(modelOptions); if (modelOptions.stream) { - // eslint-disable-next-line no-async-promise-executor + return new Promise(async (resolve, reject) => { try { let done = false; diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 03461a6796..c9102e9ae2 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -9,6 +9,7 @@ const { validateVisionModel, getResponseSender, endpointSettings, + parseTextParts, EModelEndpoint, ContentTypes, VisionModes, @@ -51,7 +52,7 @@ class GoogleClient extends BaseClient { const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {}; this.serviceKey = - serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? {}; + serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : (serviceKey ?? {}); /** @type {string | null | undefined} */ this.project_id = this.serviceKey.project_id; this.client_email = this.serviceKey.client_email; @@ -73,6 +74,8 @@ class GoogleClient extends BaseClient { * @type {string} */ this.outputTokensKey = 'output_tokens'; this.visionMode = VisionModes.generative; + /** @type {string} */ + this.systemMessage; if (options.skipSetOptions) { return; } @@ -137,8 +140,7 @@ class GoogleClient extends BaseClient { this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments)); /** @type {boolean} Whether using a "GenerativeAI" Model */ - this.isGenerativeModel = - this.modelOptions.model.includes('gemini') || this.modelOptions.model.includes('learnlm'); + this.isGenerativeModel = /gemini|learnlm|gemma/.test(this.modelOptions.model); this.maxContextTokens = this.options.maxContextTokens ?? @@ -184,7 +186,7 @@ class GoogleClient extends BaseClient { if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim(); } - this.options.promptPrefix = promptPrefix; + this.systemMessage = promptPrefix; this.initializeClient(); return this; } @@ -196,7 +198,11 @@ class GoogleClient extends BaseClient { */ checkVisionRequest(attachments) { /* Validation vision request */ - this.defaultVisionModel = this.options.visionModel ?? 'gemini-pro-vision'; + this.defaultVisionModel = + this.options.visionModel ?? + (!EXCLUDED_GENAI_MODELS.test(this.modelOptions.model) + ? 
this.modelOptions.model + : 'gemini-pro-vision'); const availableModels = this.options.modelsConfig?.[EModelEndpoint.google]; this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); @@ -311,10 +317,13 @@ class GoogleClient extends BaseClient { this.contextHandlers?.processFile(file); continue; } + if (file.metadata?.fileIdentifier) { + continue; + } } this.augmentedPrompt = await this.contextHandlers.createContext(); - this.options.promptPrefix = this.augmentedPrompt + this.options.promptPrefix; + this.systemMessage = this.augmentedPrompt + this.systemMessage; } } @@ -361,8 +370,8 @@ class GoogleClient extends BaseClient { throw new Error('[GoogleClient] PaLM 2 and Codey models are no longer supported.'); } - if (this.options.promptPrefix) { - const instructionsTokenCount = this.getTokenCount(this.options.promptPrefix); + if (this.systemMessage) { + const instructionsTokenCount = this.getTokenCount(this.systemMessage); this.maxContextTokens = this.maxContextTokens - instructionsTokenCount; if (this.maxContextTokens < 0) { @@ -417,8 +426,8 @@ class GoogleClient extends BaseClient { ], }; - if (this.options.promptPrefix) { - payload.instances[0].context = this.options.promptPrefix; + if (this.systemMessage) { + payload.instances[0].context = this.systemMessage; } logger.debug('[GoogleClient] buildMessages', payload); @@ -464,7 +473,7 @@ class GoogleClient extends BaseClient { identityPrefix = `${identityPrefix}\nYou are ${this.options.modelLabel}`; } - let promptPrefix = (this.options.promptPrefix ?? '').trim(); + let promptPrefix = (this.systemMessage ?? '').trim(); if (identityPrefix) { promptPrefix = `${identityPrefix}${promptPrefix}`; @@ -639,7 +648,7 @@ class GoogleClient extends BaseClient { let error; try { if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) { - /** @type {GenAI} */ + /** @type {GenerativeModel} */ const client = this.client; /** @type {GenerateContentRequest} */ const requestOptions = { @@ -648,7 +657,7 @@ class GoogleClient extends BaseClient { generationConfig: googleGenConfigSchema.parse(this.modelOptions), }; - const promptPrefix = (this.options.promptPrefix ?? '').trim(); + const promptPrefix = (this.systemMessage ?? '').trim(); if (promptPrefix.length) { requestOptions.systemInstruction = { parts: [ @@ -663,7 +672,17 @@ class GoogleClient extends BaseClient { /** @type {GenAIUsageMetadata} */ let usageMetadata; - const result = await client.generateContentStream(requestOptions); + abortController.signal.addEventListener( + 'abort', + () => { + logger.warn('[GoogleClient] Request was aborted', abortController.signal.reason); + }, + { once: true }, + ); + + const result = await client.generateContentStream(requestOptions, { + signal: abortController.signal, + }); for await (const chunk of result.stream) { usageMetadata = !usageMetadata ? chunk?.usageMetadata @@ -758,6 +777,22 @@ class GoogleClient extends BaseClient { return this.usage; } + getMessageMapMethod() { + /** + * @param {TMessage} msg + */ + return (msg) => { + if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) { + msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim(); + } else if (msg.content != null) { + msg.text = parseTextParts(msg.content, true); + delete msg.content; + } + + return msg; + }; + } + /** * Calculates the correct token count for the current user message based on the token count map and API usage. * Edge case: If the calculation results in a negative value, it returns the original estimate. 
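A minimal standalone sketch of the message-map transform added in getMessageMapMethod above (assumptions: a simplified { text?: string, content?: array } message shape, and a filter/map stand-in for the parseTextParts helper from librechat-data-provider):

    // Strip `:::thinking` reasoning fences from prior turns, or flatten structured
    // content parts to plain text, before messages are resent to the provider.
    const stripReasoning = (msg) => {
      if (typeof msg.text === 'string' && msg.text.startsWith(':::thinking')) {
        msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim();
      } else if (msg.content != null) {
        // Stand-in for parseTextParts(msg.content, true): keep only text parts.
        msg.text = msg.content
          .filter((part) => part.type === 'text')
          .map((part) => part.text)
          .join('\n');
        delete msg.content;
      }
      return msg;
    };

    // Example: stripReasoning({ text: ':::thinking plan the reply::: Hi!' }).text === 'Hi!'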
@@ -815,7 +850,8 @@ class GoogleClient extends BaseClient { let reply = ''; const { abortController } = options; - const model = this.modelOptions.modelName ?? this.modelOptions.model ?? ''; + const model = + this.options.titleModel ?? this.modelOptions.modelName ?? this.modelOptions.model ?? ''; const safetySettings = getSafetySettings(model); if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) { logger.debug('Identified titling model as GenAI version'); diff --git a/api/app/clients/OllamaClient.js b/api/app/clients/OllamaClient.js index d86e120f43..77d007580c 100644 --- a/api/app/clients/OllamaClient.js +++ b/api/app/clients/OllamaClient.js @@ -2,7 +2,7 @@ const { z } = require('zod'); const axios = require('axios'); const { Ollama } = require('ollama'); const { Constants } = require('librechat-data-provider'); -const { deriveBaseURL } = require('~/utils'); +const { deriveBaseURL, logAxiosError } = require('~/utils'); const { sleep } = require('~/server/utils'); const { logger } = require('~/config'); @@ -68,7 +68,7 @@ class OllamaClient { } catch (error) { const logMessage = 'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).'; - logger.error(logMessage, error); + logAxiosError({ message: logMessage, error }); return []; } } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 368e7d6e84..280db89284 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,12 +1,14 @@ -const OpenAI = require('openai'); const { OllamaClient } = require('./OllamaClient'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { SplitStreamHandler, GraphEvents } = require('@librechat/agents'); +const { SplitStreamHandler, CustomOpenAIClient: OpenAI } = require('@librechat/agents'); const { Constants, ImageDetail, + ContentTypes, + parseTextParts, EModelEndpoint, resolveHeaders, + KnownEndpoints, openAISettings, ImageDetailCost, CohereConstants, @@ -29,17 +31,18 @@ const { createContextHandlers, } = require('./prompts'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); +const { createFetch, createStreamEventHandlers } = require('./generators'); const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils'); const Tokenizer = require('~/server/services/Tokenizer'); const { spendTokens } = require('~/models/spendTokens'); const { handleOpenAIErrors } = require('./tools/util'); const { createLLM, RunManager } = require('./llm'); -const { logger, sendEvent } = require('~/config'); const ChatGPTClient = require('./ChatGPTClient'); const { summaryBuffer } = require('./memory'); const { runTitleChain } = require('./chains'); const { tokenSplit } = require('./document'); const BaseClient = require('./BaseClient'); +const { logger } = require('~/config'); class OpenAIClient extends BaseClient { constructor(apiKey, options = {}) { @@ -105,21 +108,17 @@ class OpenAIClient extends BaseClient { this.checkVisionRequest(this.options.attachments); } - const omniPattern = /\b(o1|o3)\b/i; + const omniPattern = /\b(o\d)\b/i; this.isOmni = omniPattern.test(this.modelOptions.model); - const { OPENROUTER_API_KEY, OPENAI_FORCE_PROMPT } = process.env ?? 
{}; - if (OPENROUTER_API_KEY && !this.azure) { - this.apiKey = OPENROUTER_API_KEY; - this.useOpenRouter = true; - } - + const { OPENAI_FORCE_PROMPT } = process.env ?? {}; const { reverseProxyUrl: reverseProxy } = this.options; if ( !this.useOpenRouter && - reverseProxy && - reverseProxy.includes('https://openrouter.ai/api/v1') + ((reverseProxy && reverseProxy.includes(KnownEndpoints.openrouter)) || + (this.options.endpoint && + this.options.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))) ) { this.useOpenRouter = true; } @@ -228,10 +227,6 @@ class OpenAIClient extends BaseClient { logger.debug('Using Azure endpoint'); } - if (this.useOpenRouter) { - this.completionsUrl = 'https://openrouter.ai/api/v1/chat/completions'; - } - return this; } @@ -306,7 +301,9 @@ class OpenAIClient extends BaseClient { } getEncoding() { - return this.model?.includes('gpt-4o') ? 'o200k_base' : 'cl100k_base'; + return this.modelOptions?.model && /gpt-4[^-\s]/.test(this.modelOptions.model) + ? 'o200k_base' + : 'cl100k_base'; } /** @@ -458,6 +455,9 @@ class OpenAIClient extends BaseClient { this.contextHandlers?.processFile(file); continue; } + if (file.metadata?.fileIdentifier) { + continue; + } orderedMessages[i].tokenCount += this.calculateImageTokenCost({ width: file.width, @@ -475,7 +475,9 @@ class OpenAIClient extends BaseClient { promptPrefix = this.augmentedPrompt + promptPrefix; } - if (promptPrefix && this.isOmni !== true) { + const noSystemModelRegex = /\b(o1-preview|o1-mini)\b/i.test(this.modelOptions.model); + + if (promptPrefix && !noSystemModelRegex) { promptPrefix = `Instructions:\n${promptPrefix.trim()}`; instructions = { role: 'system', @@ -503,11 +505,27 @@ class OpenAIClient extends BaseClient { }; /** EXPERIMENTAL */ - if (promptPrefix && this.isOmni === true) { + if (promptPrefix && noSystemModelRegex) { const lastUserMessageIndex = payload.findLastIndex((message) => message.role === 'user'); if (lastUserMessageIndex !== -1) { - payload[lastUserMessageIndex].content = - `${promptPrefix}\n${payload[lastUserMessageIndex].content}`; + if (Array.isArray(payload[lastUserMessageIndex].content)) { + const firstTextPartIndex = payload[lastUserMessageIndex].content.findIndex( + (part) => part.type === ContentTypes.TEXT, + ); + if (firstTextPartIndex !== -1) { + const firstTextPart = payload[lastUserMessageIndex].content[firstTextPartIndex]; + payload[lastUserMessageIndex].content[firstTextPartIndex].text = + `${promptPrefix}\n${firstTextPart.text}`; + } else { + payload[lastUserMessageIndex].content.unshift({ + type: ContentTypes.TEXT, + text: promptPrefix, + }); + } + } else { + payload[lastUserMessageIndex].content = + `${promptPrefix}\n${payload[lastUserMessageIndex].content}`; + } } } @@ -596,7 +614,7 @@ class OpenAIClient extends BaseClient { return result.trim(); } - logger.debug('[OpenAIClient] sendCompletion: result', result); + logger.debug('[OpenAIClient] sendCompletion: result', { ...result }); if (this.isChatCompletion) { reply = result.choices[0].message.content; @@ -613,7 +631,7 @@ class OpenAIClient extends BaseClient { } initializeLLM({ - model = 'gpt-4o-mini', + model = openAISettings.model.default, modelName, temperature = 0.2, max_tokens, @@ -714,7 +732,7 @@ class OpenAIClient extends BaseClient { const { OPENAI_TITLE_MODEL } = process.env ?? {}; - let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? 'gpt-4o-mini'; + let model = this.options.titleModel ?? OPENAI_TITLE_MODEL ?? 
openAISettings.model.default; if (model === Constants.CURRENT_MODEL) { model = this.modelOptions.model; } @@ -805,7 +823,7 @@ ${convo} const completionTokens = this.getTokenCount(title); - this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' }); + await this.recordTokenUsage({ promptTokens, completionTokens, context: 'title' }); } catch (e) { logger.error( '[OpenAIClient] There was an issue generating the title with the completion method', e, ); @@ -907,7 +925,7 @@ ${convo} let prompt; // TODO: remove the gpt fallback and make it specific to endpoint - const { OPENAI_SUMMARY_MODEL = 'gpt-4o-mini' } = process.env ?? {}; + const { OPENAI_SUMMARY_MODEL = openAISettings.model.default } = process.env ?? {}; let model = this.options.summaryModel ?? OPENAI_SUMMARY_MODEL; if (model === Constants.CURRENT_MODEL) { model = this.modelOptions.model; @@ -1108,6 +1126,9 @@ ${convo} return (msg) => { if (msg.text != null && msg.text && msg.text.startsWith(':::thinking')) { msg.text = msg.text.replace(/:::thinking.*?:::/gs, '').trim(); + } else if (msg.content != null) { + msg.text = parseTextParts(msg.content, true); + delete msg.content; + } return msg; }; @@ -1159,10 +1180,6 @@ ${convo} opts.httpAgent = new HttpsProxyAgent(this.options.proxy); } - if (this.isVisionModel) { - modelOptions.max_tokens = 4000; - } - /** @type {TAzureConfig | undefined} */ const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; @@ -1212,9 +1229,9 @@ ${convo} opts.baseURL = this.langchainProxy ? constructAzureURL({ - baseURL: this.langchainProxy, - azureOptions: this.azure, - }) + baseURL: this.langchainProxy, + azureOptions: this.azure, + }) : this.azureEndpoint.split(/(?<!\/)\/(?:completions|chat)/)[0]; if (this.options.dropParams && Array.isArray(this.options.dropParams)) { - this.options.dropParams.forEach((param) => { + const dropParams = [...this.options.dropParams]; + dropParams.forEach((param) => { delete modelOptions[param]; }); logger.debug('[OpenAIClient] chatCompletion: dropped params', { - dropParams: this.options.dropParams, + dropParams: dropParams, modelOptions, }); } @@ -1301,15 +1357,11 @@ ${convo} let streamResolve; if ( - this.isOmni === true && - (this.azure || /o1(?!-(?:mini|preview)).*$/.test(modelOptions.model)) && - !/o3-.*$/.test(this.modelOptions.model) && - modelOptions.stream + (!this.isOmni || /^o1-(mini|preview)/i.test(modelOptions.model)) && + modelOptions.reasoning_effort != null ) { - delete modelOptions.stream; - delete modelOptions.stop; - } else if (!this.isOmni && modelOptions.reasoning_effort != null) { delete modelOptions.reasoning_effort; + delete modelOptions.temperature; } let reasoningKey = 'reasoning_content'; if (this.useOpenRouter) { modelOptions.include_reasoning = true; reasoningKey = 'reasoning'; } + if (this.useOpenRouter && modelOptions.reasoning_effort != null) { + modelOptions.reasoning = { + effort: modelOptions.reasoning_effort, + }; + delete modelOptions.reasoning_effort; + } + const handlers = createStreamEventHandlers(this.options.res); this.streamHandler = new SplitStreamHandler({ reasoningKey, accumulate: true, runId: this.responseMessageId, - handlers: { - [GraphEvents.ON_RUN_STEP]: (event) => sendEvent(this.options.res, event), - [GraphEvents.ON_MESSAGE_DELTA]: (event) => sendEvent(this.options.res, event), - [GraphEvents.ON_REASONING_DELTA]: (event) => sendEvent(this.options.res, event), - }, + handlers, }); intermediateReply = this.streamHandler.tokens; @@ -1340,12 +1395,6 @@ ${convo} ...modelOptions, stream: true, }; - if ( - this.options.endpoint === EModelEndpoint.openAI || - this.options.endpoint === EModelEndpoint.azureOpenAI - ) { - params.stream_options = { 
include_usage: true }; - } const stream = await openai.beta.chat.completions .stream(params) .on('abort', () => { @@ -1430,6 +1479,11 @@ ${convo} }); } + if (openai.abortHandler && abortController.signal) { + abortController.signal.removeEventListener('abort', openai.abortHandler); + openai.abortHandler = undefined; + } + if (!chatCompletion && UnexpectedRoleError) { throw new Error( 'OpenAI error: Invalid final message: OpenAI expects final message to include role=assistant', diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index bfe222e248..d0ffe2ef75 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -5,9 +5,8 @@ const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_pars const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents'); const { processFileURL } = require('~/server/services/Files/process'); const { EModelEndpoint } = require('librechat-data-provider'); +const { checkBalance } = require('~/models/balanceMethods'); const { formatLangChainMessages } = require('./prompts'); -const checkBalance = require('~/models/checkBalance'); -const { isEnabled } = require('~/server/utils'); const { extractBaseURL } = require('~/utils'); const { loadTools } = require('./tools/util'); const { logger } = require('~/config'); @@ -253,12 +252,14 @@ class PluginsClient extends OpenAIClient { await this.recordTokenUsage(responseMessage); } - this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); + const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); delete responseMessage.tokenCount; - return { ...responseMessage, ...result }; + return { ...responseMessage, ...result, databasePromise }; } async sendMessage(message, opts = {}) { + /** @type {Promise} */ + let userMessagePromise; /** @type {{ filteredTools: string[], includedTools: string[] }} */ const { filteredTools = [], includedTools = [] } = this.options.req.app.locals; @@ -328,15 +329,16 @@ class PluginsClient extends OpenAIClient { } if (!this.skipSaveUserMessage) { - this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); + userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); if (typeof opts?.getReqData === 'function') { opts.getReqData({ - userMessagePromise: this.userMessagePromise, + userMessagePromise, }); } } - if (isEnabled(process.env.CHECK_BALANCE)) { + const balance = this.options.req?.app?.locals?.balance; + if (balance?.enabled) { await checkBalance({ req: this.options.req, res: this.options.res, diff --git a/api/app/clients/callbacks/createStartHandler.js b/api/app/clients/callbacks/createStartHandler.js index 4bc32bc0c2..b7292aaf17 100644 --- a/api/app/clients/callbacks/createStartHandler.js +++ b/api/app/clients/callbacks/createStartHandler.js @@ -1,8 +1,8 @@ const { promptTokensEstimate } = require('openai-chat-tokens'); const { EModelEndpoint, supportsBalanceCheck } = require('librechat-data-provider'); const { formatFromLangChain } = require('~/app/clients/prompts'); -const checkBalance = require('~/models/checkBalance'); -const { isEnabled } = require('~/server/utils'); +const { getBalanceConfig } = require('~/server/services/Config'); +const { checkBalance } = require('~/models/balanceMethods'); const { logger } = require('~/config'); const createStartHandler = ({ @@ -49,8 +49,8 @@ const createStartHandler = ({ prelimPromptTokens += tokenBuffer; try { - // TODO: if plugins extends to non-OpenAI 
models, this will need to be updated - if (isEnabled(process.env.CHECK_BALANCE) && supportsBalanceCheck[EModelEndpoint.openAI]) { + const balance = await getBalanceConfig(); + if (balance?.enabled && supportsBalanceCheck[EModelEndpoint.openAI]) { const generations = initialMessageCount && messages.length > initialMessageCount ? messages.slice(initialMessageCount) diff --git a/api/app/clients/generators.js b/api/app/clients/generators.js new file mode 100644 index 0000000000..9814cac7a5 --- /dev/null +++ b/api/app/clients/generators.js @@ -0,0 +1,71 @@ +const fetch = require('node-fetch'); +const { GraphEvents } = require('@librechat/agents'); +const { logger, sendEvent } = require('~/config'); +const { sleep } = require('~/server/utils'); + +/** + * Creates a fetch function that makes HTTP requests and logs the process. + * @param {Object} params + * @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint. + * @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request. + * @returns {function} - A fetch-compatible function returning a promise that resolves to the response. + */ +function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) { + /** + * Makes an HTTP request and logs the process. + * @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object. + * @param {RequestInit} [init] - Optional init options for the request. + * @returns {Promise<Response>} - A promise that resolves to the response of the fetch request. + */ + return async (_url, init) => { + let url = _url; + if (directEndpoint) { + url = reverseProxyUrl; + } + logger.debug(`Making request to ${url}`); + if (typeof Bun !== 'undefined') { + // `fetch` here is node-fetch in both branches (see the require above) + return await fetch(url, init); + } + return await fetch(url, init); + }; +} + +/** + * Creates event handlers for stream events that don't capture client references + * @param {Object} res - The response object to send events to + * @returns {Object} Object containing handler functions + */ +function createStreamEventHandlers(res) { + return { + [GraphEvents.ON_RUN_STEP]: (event) => { + if (res) { + sendEvent(res, event); + } + }, + [GraphEvents.ON_MESSAGE_DELTA]: (event) => { + if (res) { + sendEvent(res, event); + } + }, + [GraphEvents.ON_REASONING_DELTA]: (event) => { + if (res) { + sendEvent(res, event); + } + }, + }; +} + +function createHandleLLMNewToken(streamRate) { + return async () => { + if (streamRate) { + await sleep(streamRate); + } + }; +} + +module.exports = { + createFetch, + createHandleLLMNewToken, + createStreamEventHandlers, +}; diff --git a/api/app/clients/llm/createLLM.js b/api/app/clients/llm/createLLM.js index 7dc0d40ceb..c8d6666bce 100644 --- a/api/app/clients/llm/createLLM.js +++ b/api/app/clients/llm/createLLM.js @@ -34,6 +34,7 @@ function createLLM({ let credentials = { openAIApiKey }; let configuration = { apiKey: openAIApiKey, + ...(configOptions.basePath && { baseURL: configOptions.basePath }), }; /** @type {AzureOptions} */ diff --git a/api/app/clients/prompts/addCacheControl.js b/api/app/clients/prompts/addCacheControl.js index eed5910dc9..6bfd901a65 100644 --- a/api/app/clients/prompts/addCacheControl.js +++ b/api/app/clients/prompts/addCacheControl.js @@ -1,7 +1,7 @@ /** * Anthropic API: Adds cache control to the appropriate user messages in the payload. - * @param {Array} messages - The array of message objects. - * @returns {Array} - The updated array of message objects with cache control added. 
+ * @param {Array} messages - The array of message objects. + * @returns {Array} - The updated array of message objects with cache control added. */ function addCacheControl(messages) { if (!Array.isArray(messages) || messages.length < 2) { @@ -13,7 +13,9 @@ function addCacheControl(messages) { for (let i = updatedMessages.length - 1; i >= 0 && userMessagesModified < 2; i--) { const message = updatedMessages[i]; - if (message.role !== 'user') { + if (message.getType != null && message.getType() !== 'human') { + continue; + } else if (message.getType == null && message.role !== 'user') { continue; } diff --git a/api/app/clients/prompts/formatAgentMessages.spec.js b/api/app/clients/prompts/formatAgentMessages.spec.js index 20731f6984..360fa00a34 100644 --- a/api/app/clients/prompts/formatAgentMessages.spec.js +++ b/api/app/clients/prompts/formatAgentMessages.spec.js @@ -282,4 +282,80 @@ describe('formatAgentMessages', () => { // Additional check to ensure the consecutive assistant messages were combined expect(result[1].content).toHaveLength(2); }); + + it('should skip THINK type content parts', () => { + const payload = [ + { + role: 'assistant', + content: [ + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Initial response' }, + { type: ContentTypes.THINK, [ContentTypes.THINK]: 'Reasoning about the problem...' }, + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final answer' }, + ], + }, + ]; + + const result = formatAgentMessages(payload); + + expect(result).toHaveLength(1); + expect(result[0]).toBeInstanceOf(AIMessage); + expect(result[0].content).toEqual('Initial response\nFinal answer'); + }); + + it('should join TEXT content as string when THINK content type is present', () => { + const payload = [ + { + role: 'assistant', + content: [ + { type: ContentTypes.THINK, [ContentTypes.THINK]: 'Analyzing the problem...' 
}, + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'First part of response' }, + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Second part of response' }, + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final part of response' }, + ], + }, + ]; + + const result = formatAgentMessages(payload); + + expect(result).toHaveLength(1); + expect(result[0]).toBeInstanceOf(AIMessage); + expect(typeof result[0].content).toBe('string'); + expect(result[0].content).toBe( + 'First part of response\nSecond part of response\nFinal part of response', + ); + expect(result[0].content).not.toContain('Analyzing the problem...'); + }); + + it('should exclude ERROR type content parts', () => { + const payload = [ + { + role: 'assistant', + content: [ + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello there' }, + { + type: ContentTypes.ERROR, + [ContentTypes.ERROR]: + 'An error occurred while processing the request: Something went wrong', + }, + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final answer' }, + ], + }, + ]; + + const result = formatAgentMessages(payload); + + expect(result).toHaveLength(1); + expect(result[0]).toBeInstanceOf(AIMessage); + expect(result[0].content).toEqual([ + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello there' }, + { type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Final answer' }, + ]); + + // Make sure no error content exists in the result + const hasErrorContent = result[0].content.some( + (item) => + item.type === ContentTypes.ERROR || JSON.stringify(item).includes('An error occurred'), + ); + expect(hasErrorContent).toBe(false); + }); }); diff --git a/api/app/clients/prompts/formatMessages.js b/api/app/clients/prompts/formatMessages.js index d84e62cca8..9fa0d40497 100644 --- a/api/app/clients/prompts/formatMessages.js +++ b/api/app/clients/prompts/formatMessages.js @@ -153,6 +153,7 @@ const formatAgentMessages = (payload) => { let currentContent = []; let lastAIMessage = null; + let hasReasoning = false; for (const part of message.content) { if (part.type === ContentTypes.TEXT && part.tool_call_ids) { /* @@ -207,11 +208,27 @@ const formatAgentMessages = (payload) => { content: output || '', }), ); + } else if (part.type === ContentTypes.THINK) { + hasReasoning = true; + continue; + } else if (part.type === ContentTypes.ERROR || part.type === ContentTypes.AGENT_UPDATE) { + continue; } else { currentContent.push(part); } } + if (hasReasoning) { + currentContent = currentContent + .reduce((acc, curr) => { + if (curr.type === ContentTypes.TEXT) { + return `${acc}${curr[ContentTypes.TEXT]}\n`; + } + return acc; + }, '') + .trim(); + } + if (currentContent.length > 0) { messages.push(new AIMessage({ content: currentContent })); } diff --git a/api/app/clients/specs/AnthropicClient.test.js b/api/app/clients/specs/AnthropicClient.test.js index eef6bb6748..223f3038c0 100644 --- a/api/app/clients/specs/AnthropicClient.test.js +++ b/api/app/clients/specs/AnthropicClient.test.js @@ -1,3 +1,4 @@ +const { SplitStreamHandler } = require('@librechat/agents'); const { anthropicSettings } = require('librechat-data-provider'); const AnthropicClient = require('~/app/clients/AnthropicClient'); @@ -405,4 +406,327 @@ describe('AnthropicClient', () => { expect(Number.isNaN(result)).toBe(false); }); }); + + describe('maxOutputTokens handling for different models', () => { + it('should not cap maxOutputTokens for Claude 3.5 Sonnet models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 10; 
+ + client.setOptions({ + modelOptions: { + model: 'claude-3-5-sonnet', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + + // Test with decimal notation + client.setOptions({ + modelOptions: { + model: 'claude-3.5-sonnet', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + }); + + it('should not cap maxOutputTokens for Claude 3.7 models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2; + + client.setOptions({ + modelOptions: { + model: 'claude-3-7-sonnet', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + + // Test with decimal notation + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe(highTokenValue); + }); + + it('should cap maxOutputTokens for Claude 3.5 Haiku models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2; + + client.setOptions({ + modelOptions: { + model: 'claude-3-5-haiku', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + + // Test with decimal notation + client.setOptions({ + modelOptions: { + model: 'claude-3.5-haiku', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + }); + + it('should cap maxOutputTokens for Claude 3 Haiku and Opus models', () => { + const client = new AnthropicClient('test-api-key'); + const highTokenValue = anthropicSettings.legacy.maxOutputTokens.default * 2; + + // Test haiku + client.setOptions({ + modelOptions: { + model: 'claude-3-haiku', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + + // Test opus + client.setOptions({ + modelOptions: { + model: 'claude-3-opus', + maxOutputTokens: highTokenValue, + }, + }); + + expect(client.modelOptions.maxOutputTokens).toBe( + anthropicSettings.legacy.maxOutputTokens.default, + ); + }); + }); + + describe('topK/topP parameters for different models', () => { + beforeEach(() => { + // Mock the SplitStreamHandler + jest.spyOn(SplitStreamHandler.prototype, 'handle').mockImplementation(() => {}); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + it('should include top_k and top_p parameters for non-claude-3.7 models', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3-opus', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; 
+ return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).toHaveProperty('top_k', 10); + expect(capturedOptions).toHaveProperty('top_p', 0.9); + }); + + it('should include top_k and top_p parameters for claude-3-5-sonnet models', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3-5-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).toHaveProperty('top_k', 10); + expect(capturedOptions).toHaveProperty('top_p', 0.9); + }); + + it('should not include top_k and top_p parameters for claude-3-7-sonnet models', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3-7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).not.toHaveProperty('top_k'); + expect(capturedOptions).not.toHaveProperty('top_p'); + }); + + it('should not include top_k and top_p parameters for models with decimal notation (claude-3.7)', async () => { + const client = new AnthropicClient('test-api-key'); + + // Create a mock async generator function + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + // Mock createResponse to return the async generator + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + }); + + // Mock getClient to capture the request options + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; 
+ await client.sendCompletion(payload, {}); + + // Check the options passed to getClient + expect(capturedOptions).not.toHaveProperty('top_k'); + expect(capturedOptions).not.toHaveProperty('top_p'); + }); + }); + + it('should include top_k and top_p parameters for Claude-3.7 models when thinking is explicitly disabled', async () => { + const client = new AnthropicClient('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + async function* mockAsyncGenerator() { + yield { type: 'message_start', message: { usage: {} } }; + yield { delta: { text: 'Test response' } }; + yield { type: 'message_delta', usage: {} }; + } + + jest.spyOn(client, 'createResponse').mockImplementation(() => { + return mockAsyncGenerator(); + }); + + let capturedOptions = null; + jest.spyOn(client, 'getClient').mockImplementation((options) => { + capturedOptions = options; + return {}; + }); + + const payload = [{ role: 'user', content: 'Test message' }]; + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + + client.setOptions({ + modelOptions: { + model: 'claude-3.7-sonnet', + temperature: 0.7, + topK: 10, + topP: 0.9, + }, + thinking: false, + }); + + await client.sendCompletion(payload, {}); + + expect(capturedOptions).toHaveProperty('topK', 10); + expect(capturedOptions).toHaveProperty('topP', 0.9); + }); }); diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index e899449fb9..d620d5f647 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -30,7 +30,9 @@ jest.mock('~/models', () => ({ updateFileUsage: jest.fn(), })); -jest.mock('@langchain/openai', () => { +const { getConvo, saveConvo } = require('~/models'); + +jest.mock('@librechat/agents', () => { return { ChatOpenAI: jest.fn().mockImplementation(() => { return {}; @@ -162,7 +164,7 @@ describe('BaseClient', () => { const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); - expect(result.summaryIndex).toEqual(expectedIndex); + expect(result.messagesToRefine.length - 1).toEqual(expectedIndex); expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens); expect(result.messagesToRefine).toEqual(expectedMessagesToRefine); }); @@ -198,7 +200,7 @@ describe('BaseClient', () => { const result = await TestClient.getMessagesWithinTokenLimit({ messages }); expect(result.context).toEqual(expectedContext); - expect(result.summaryIndex).toEqual(expectedIndex); + expect(result.messagesToRefine.length - 1).toEqual(expectedIndex); expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens); expect(result.messagesToRefine).toEqual(expectedMessagesToRefine); }); @@ -540,10 +542,11 @@ describe('BaseClient', () => { test('saveMessageToDatabase is called with the correct arguments', async () => { const saveOptions = TestClient.getSaveOptions(); - const user = {}; // Mock user + const user = {}; const opts = { user }; + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); await TestClient.sendMessage('Hello, world!', opts); - expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith( + expect(saveSpy).toHaveBeenCalledWith( expect.objectContaining({ sender: expect.any(String), text: expect.any(String), @@ -557,6 +560,157 @@ describe('BaseClient', () => { ); }); + test('should handle 
existing conversation when getConvo retrieves one', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + temperature: 1, + }; + + const { temperature: _temp, ...newConvo } = existingConvo; + + const user = { + id: 'user-id', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(newConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + req: { + user, + }, + }, + [], + ); + + const saveSpy = jest.spyOn(TestClient, 'saveMessageToDatabase'); + + const newMessage = 'New message in existing conversation'; + const response = await TestClient.sendMessage(newMessage, { + user, + conversationId: existingConvo.conversationId, + }); + + expect(getConvo).toHaveBeenCalledWith(user.id, existingConvo.conversationId); + expect(TestClient.conversationId).toBe(existingConvo.conversationId); + expect(response.conversationId).toBe(existingConvo.conversationId); + expect(TestClient.fetchedConvo).toBe(true); + + expect(saveSpy).toHaveBeenCalledWith( + expect.objectContaining({ + conversationId: existingConvo.conversationId, + text: newMessage, + }), + expect.any(Object), + expect.any(Object), + ); + + expect(saveConvo).toHaveBeenCalledTimes(2); + expect(saveConvo).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + conversationId: existingConvo.conversationId, + }), + expect.objectContaining({ + context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveConvo', + unsetFields: { + temperature: 1, + }, + }), + ); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + expect(getConvo).toHaveBeenCalledTimes(1); + }); + + test('should correctly handle existing conversation and unset fields appropriately', async () => { + const existingConvo = { + conversationId: 'existing-convo-id', + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-3.5-turbo', + messages: [ + { role: 'user', content: 'Existing message 1' }, + { role: 'assistant', content: 'Existing response 1' }, + ], + title: 'Existing Conversation', + someExistingField: 'existingValue', + anotherExistingField: 'anotherValue', + temperature: 0.7, + modelLabel: 'GPT-3.5', + }; + + getConvo.mockResolvedValue(existingConvo); + saveConvo.mockResolvedValue(existingConvo); + + TestClient = initializeFakeClient( + apiKey, + { + ...options, + modelOptions: { + model: 'gpt-4', + temperature: 0.5, + }, + }, + [], + ); + + const newMessage = 'New message in existing conversation'; + await TestClient.sendMessage(newMessage, { + conversationId: existingConvo.conversationId, + }); + + expect(saveConvo).toHaveBeenCalledTimes(2); + + const saveConvoCall = saveConvo.mock.calls[0]; + const [, savedFields, saveOptions] = saveConvoCall; + + // Instead of checking all excludedKeys, we'll just check specific fields + // that we know should be excluded + expect(savedFields).not.toHaveProperty('messages'); + expect(savedFields).not.toHaveProperty('title'); + + // Only check that someExistingField is in unsetFields + expect(saveOptions.unsetFields).toHaveProperty('someExistingField', 1); + + // Mock saveConvo to return the expected fields + saveConvo.mockImplementation((req, fields) => { + return Promise.resolve({ + ...fields, + endpoint: 'openai', + endpointType: 'openai', + model: 'gpt-4', + temperature: 0.5, + }); + 
}); + + // Only check the conversationId since that's the only field we can be sure about + expect(savedFields).toHaveProperty('conversationId', 'existing-convo-id'); + + expect(TestClient.fetchedConvo).toBe(true); + + await TestClient.sendMessage('Another message', { + conversationId: existingConvo.conversationId, + }); + + expect(getConvo).toHaveBeenCalledTimes(1); + + const secondSaveConvoCall = saveConvo.mock.calls[1]; + expect(secondSaveConvoCall[2]).toHaveProperty('unsetFields', {}); + }); + test('sendCompletion is called with the correct arguments', async () => { const payload = {}; // Mock payload TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null }); diff --git a/api/app/clients/specs/FakeClient.js b/api/app/clients/specs/FakeClient.js index 7f4b75e1db..a466bb97f9 100644 --- a/api/app/clients/specs/FakeClient.js +++ b/api/app/clients/specs/FakeClient.js @@ -56,7 +56,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { let TestClient = new FakeClient(apiKey); TestClient.options = options; TestClient.abortController = { abort: jest.fn() }; - TestClient.saveMessageToDatabase = jest.fn(); TestClient.loadHistory = jest .fn() .mockImplementation((conversationId, parentMessageId = null) => { @@ -86,7 +85,6 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => { return 'Mock response text'; }); - // eslint-disable-next-line no-unused-vars TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => { return { choices: [ diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 2aaec518eb..579f636eef 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -1,9 +1,7 @@ jest.mock('~/cache/getLogStores'); require('dotenv').config(); -const OpenAI = require('openai'); -const getLogStores = require('~/cache/getLogStores'); const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); -const { genAzureChatCompletion } = require('~/utils/azureUtils'); +const getLogStores = require('~/cache/getLogStores'); const OpenAIClient = require('../OpenAIClient'); jest.mock('meilisearch'); @@ -36,19 +34,21 @@ jest.mock('~/models', () => ({ updateFileUsage: jest.fn(), })); -jest.mock('@langchain/openai', () => { - return { - ChatOpenAI: jest.fn().mockImplementation(() => { - return {}; - }), - }; +// Import the actual module but mock specific parts +const agents = jest.requireActual('@librechat/agents'); +const { CustomOpenAIClient } = agents; + +// Also mock ChatOpenAI to prevent real API calls +agents.ChatOpenAI = jest.fn().mockImplementation(() => { + return {}; +}); +agents.AzureChatOpenAI = jest.fn().mockImplementation(() => { + return {}; }); -jest.mock('openai'); - -jest.spyOn(OpenAI, 'constructor').mockImplementation(function (...options) { - // We can add additional logic here if needed - return new OpenAI(...options); +// Mock only the CustomOpenAIClient constructor +jest.spyOn(CustomOpenAIClient, 'constructor').mockImplementation(function (...options) { + return new CustomOpenAIClient(...options); }); const finalChatCompletion = jest.fn().mockResolvedValue({ @@ -120,7 +120,13 @@ const create = jest.fn().mockResolvedValue({ ], }); -OpenAI.mockImplementation(() => ({ +// Mock the implementation of CustomOpenAIClient instances +jest.spyOn(CustomOpenAIClient.prototype, 'constructor').mockImplementation(function () { + return this; +}); + +// Create a mock for the CustomOpenAIClient class +const 
mockCustomOpenAIClient = jest.fn().mockImplementation(() => ({ beta: { chat: { completions: { @@ -135,11 +141,14 @@ OpenAI.mockImplementation(() => ({ }, })); -describe('OpenAIClient', () => { - const mockSet = jest.fn(); - const mockCache = { set: mockSet }; +CustomOpenAIClient.mockImplementation = mockCustomOpenAIClient; +describe('OpenAIClient', () => { beforeEach(() => { + const mockCache = { + get: jest.fn().mockResolvedValue({}), + set: jest.fn(), + }; getLogStores.mockReturnValue(mockCache); }); let client; @@ -202,14 +211,6 @@ describe('OpenAIClient', () => { expect(client.modelOptions.temperature).toBe(0.7); }); - it('should set apiKey and useOpenRouter if OPENROUTER_API_KEY is present', () => { - process.env.OPENROUTER_API_KEY = 'openrouter-key'; - client.setOptions({}); - expect(client.apiKey).toBe('openrouter-key'); - expect(client.useOpenRouter).toBe(true); - delete process.env.OPENROUTER_API_KEY; // Cleanup - }); - it('should set FORCE_PROMPT based on OPENAI_FORCE_PROMPT or reverseProxyUrl', () => { process.env.OPENAI_FORCE_PROMPT = 'true'; client.setOptions({}); @@ -534,7 +535,6 @@ describe('OpenAIClient', () => { afterEach(() => { delete process.env.AZURE_OPENAI_DEFAULT_MODEL; delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME; - delete process.env.OPENROUTER_API_KEY; }); it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => { @@ -567,41 +567,6 @@ describe('OpenAIClient', () => { expect(requestBody).toHaveProperty('model'); expect(requestBody.model).toBe(model); }); - - it('[Azure OpenAI] should call chatCompletion and OpenAI.stream with correct args', async () => { - // Set a default model - process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt4-turbo'; - - const onProgress = jest.fn().mockImplementation(() => ({})); - client.azure = defaultAzureOptions; - const chatCompletion = jest.spyOn(client, 'chatCompletion'); - await client.sendMessage('Hi mom!', { - replaceOptions: true, - ...defaultOptions, - modelOptions: { model: 'gpt4-turbo', stream: true }, - onProgress, - azure: defaultAzureOptions, - }); - - expect(chatCompletion).toHaveBeenCalled(); - expect(chatCompletion.mock.calls.length).toBe(1); - - const chatCompletionArgs = chatCompletion.mock.calls[0][0]; - const { payload } = chatCompletionArgs; - - expect(payload[0].role).toBe('user'); - expect(payload[0].content).toBe('Hi mom!'); - - // Azure OpenAI does not use the model property, and will error if it's passed - // This check ensures the model property is not present - const streamArgs = stream.mock.calls[0][0]; - expect(streamArgs).not.toHaveProperty('model'); - - // Check if the baseURL is correct - const constructorArgs = OpenAI.mock.calls[0][0]; - const expectedURL = genAzureChatCompletion(defaultAzureOptions).split('/chat')[0]; - expect(constructorArgs.baseURL).toBe(expectedURL); - }); }); describe('checkVisionRequest functionality', () => { diff --git a/api/app/clients/tools/index.js b/api/app/clients/tools/index.js index b8df50c77d..87b1884e88 100644 --- a/api/app/clients/tools/index.js +++ b/api/app/clients/tools/index.js @@ -2,13 +2,15 @@ const availableTools = require('./manifest.json'); // Structured Tools const DALLE3 = require('./structured/DALLE3'); +const FluxAPI = require('./structured/FluxAPI'); const OpenWeather = require('./structured/OpenWeather'); -const createYouTubeTools = require('./structured/YouTube'); const StructuredWolfram = require('./structured/Wolfram'); +const createYouTubeTools = require('./structured/YouTube'); const StructuredACS = 
require('./structured/AzureAISearch'); const StructuredSD = require('./structured/StableDiffusion'); const GoogleSearchAPI = require('./structured/GoogleSearch'); const TraversaalSearch = require('./structured/TraversaalSearch'); +const createOpenAIImageTools = require('./structured/OpenAIImageTools'); const TavilySearchResults = require('./structured/TavilySearchResults'); /** @type {Record} */ @@ -30,6 +32,7 @@ module.exports = { manifestToolMap, // Structured Tools DALLE3, + FluxAPI, OpenWeather, StructuredSD, StructuredACS, @@ -38,4 +41,5 @@ module.exports = { StructuredWolfram, createYouTubeTools, TavilySearchResults, + createOpenAIImageTools, }; diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index 7cb92b8d87..55c1b1c51e 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -44,6 +44,20 @@ } ] }, + { + "name": "OpenAI Image Tools", + "pluginKey": "image_gen_oai", + "toolkit": true, + "description": "Image Generation and Editing using OpenAI's latest state-of-the-art models", + "icon": "/assets/image_gen_oai.png", + "authConfig": [ + { + "authField": "IMAGE_GEN_OAI_API_KEY", + "label": "OpenAI Image Tools API Key", + "description": "Your OpenAI API Key for Image Generation and Editing" + } + ] + }, { "name": "Wolfram", "pluginKey": "wolfram", @@ -164,5 +178,19 @@ "description": "Sign up at OpenWeather, then get your key at API keys." } ] + }, + { + "name": "Flux", + "pluginKey": "flux", + "description": "Generate images using text with the Flux API.", + "icon": "https://blackforestlabs.ai/wp-content/uploads/2024/07/bfl_logo_retraced_blk.png", + "isAuthRequired": "true", + "authConfig": [ + { + "authField": "FLUX_API_KEY", + "label": "Your Flux API Key", + "description": "Provide your Flux API key from your user profile." + } + ] } ] diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index b604ad4ea4..fc0f1851f6 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -1,14 +1,17 @@ const { z } = require('zod'); const path = require('path'); const OpenAI = require('openai'); +const fetch = require('node-fetch'); const { v4: uuidv4 } = require('uuid'); const { Tool } = require('@langchain/core/tools'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); const extractBaseURL = require('~/utils/extractBaseURL'); const { logger } = require('~/config'); +const displayMessage = + 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; class DALLE3 extends Tool { constructor(fields = {}) { super(); @@ -114,10 +117,7 @@ class DALLE3 extends Tool { if (this.isAgent === true && typeof value === 'string') { return [value, {}]; } else if (this.isAgent === true && typeof value === 'object') { - return [ - 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. 
The user may download the images by clicking on them, but do not mention anything about downloading to the user.', - value, - ]; + return [displayMessage, value]; } return value; @@ -160,6 +160,32 @@ Error Message: ${error.message}`); ); } + if (this.isAgent) { + let fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(theImageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } + const imageBasename = getImageBasename(theImageUrl); const imageExt = path.extname(imageBasename); diff --git a/api/app/clients/tools/structured/FluxAPI.js b/api/app/clients/tools/structured/FluxAPI.js new file mode 100644 index 0000000000..80f9772200 --- /dev/null +++ b/api/app/clients/tools/structured/FluxAPI.js @@ -0,0 +1,554 @@ +const { z } = require('zod'); +const axios = require('axios'); +const fetch = require('node-fetch'); +const { v4: uuidv4 } = require('uuid'); +const { Tool } = require('@langchain/core/tools'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { FileContext, ContentTypes } = require('librechat-data-provider'); +const { logger } = require('~/config'); + +const displayMessage = + 'Flux displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + +/** + * FluxAPI - A tool for generating high-quality images from text prompts using the Flux API. + * Each call generates one image. If multiple images are needed, make multiple consecutive calls with the same or varied prompts. + */ +class FluxAPI extends Tool { + // Pricing constants in USD per image + static PRICING = { + FLUX_PRO_1_1_ULTRA: -0.06, // /v1/flux-pro-1.1-ultra + FLUX_PRO_1_1: -0.04, // /v1/flux-pro-1.1 + FLUX_PRO: -0.05, // /v1/flux-pro + FLUX_DEV: -0.025, // /v1/flux-dev + FLUX_PRO_FINETUNED: -0.06, // /v1/flux-pro-finetuned + FLUX_PRO_1_1_ULTRA_FINETUNED: -0.07, // /v1/flux-pro-1.1-ultra-finetuned + }; + + constructor(fields = {}) { + super(); + + /** @type {boolean} Used to initialize the Tool without necessary variables. */ + this.override = fields.override ?? false; + + this.userId = fields.userId; + this.fileStrategy = fields.fileStrategy; + + /** @type {boolean} **/ + this.isAgent = fields.isAgent; + this.returnMetadata = fields.returnMetadata ?? false; + + if (fields.processFileURL) { + /** @type {processFileURL} Necessary for output to contain all image metadata. */ + this.processFileURL = fields.processFileURL.bind(this); + } + + this.apiKey = fields.FLUX_API_KEY || this.getApiKey(); + + this.name = 'flux'; + this.description = + 'Use Flux to generate images from text descriptions. This tool can generate images and list available finetunes. Each generate call creates one image. For multiple images, make multiple consecutive calls.'; + + this.description_for_model = `// Transform any image description into a detailed, high-quality prompt. Never submit a prompt under 3 sentences. Follow these core rules: + // 1. 
ALWAYS enhance basic prompts into 5-10 detailed sentences (e.g., "a cat" becomes: "A close-up photo of a sleek Siamese cat with piercing blue eyes. The cat sits elegantly on a vintage leather armchair, its tail curled gracefully around its paws. Warm afternoon sunlight streams through a nearby window, casting gentle shadows across its face and highlighting the subtle variations in its cream and chocolate-point fur. The background is softly blurred, creating a shallow depth of field that draws attention to the cat's expressive features. The overall composition has a peaceful, contemplative mood with a professional photography style.") + // 2. Each prompt MUST be 3-6 descriptive sentences minimum, focusing on visual elements: lighting, composition, mood, and style + // Use action: 'list_finetunes' to see available custom models. When using finetunes, use endpoint: '/v1/flux-pro-finetuned' (default) or '/v1/flux-pro-1.1-ultra-finetuned' for higher quality and aspect ratio.`; + + // Add base URL from environment variable with fallback + this.baseUrl = process.env.FLUX_API_BASE_URL || 'https://api.us1.bfl.ai'; + + // Define the schema for structured input + this.schema = z.object({ + action: z + .enum(['generate', 'list_finetunes', 'generate_finetuned']) + .default('generate') + .describe( + 'Action to perform: "generate" for image generation, "generate_finetuned" for finetuned model generation, "list_finetunes" to get available custom models', + ), + prompt: z + .string() + .optional() + .describe( + 'Text prompt for image generation. Required when action is "generate". Not used for list_finetunes.', + ), + width: z + .number() + .optional() + .describe( + 'Width of the generated image in pixels. Must be a multiple of 32. Default is 1024.', + ), + height: z + .number() + .optional() + .describe( + 'Height of the generated image in pixels. Must be a multiple of 32. Default is 768.', + ), + prompt_upsampling: z + .boolean() + .optional() + .default(false) + .describe('Whether to perform upsampling on the prompt.'), + steps: z + .number() + .int() + .optional() + .describe('Number of steps to run the model for, a number from 1 to 50. Default is 40.'), + seed: z.number().optional().describe('Optional seed for reproducibility.'), + safety_tolerance: z + .number() + .optional() + .default(6) + .describe( + 'Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + ), + endpoint: z + .enum([ + '/v1/flux-pro-1.1', + '/v1/flux-pro', + '/v1/flux-dev', + '/v1/flux-pro-1.1-ultra', + '/v1/flux-pro-finetuned', + '/v1/flux-pro-1.1-ultra-finetuned', + ]) + .optional() + .default('/v1/flux-pro-1.1') + .describe('Endpoint to use for image generation.'), + raw: z + .boolean() + .optional() + .default(false) + .describe( + 'Generate less processed, more natural-looking images. 
Only works for /v1/flux-pro-1.1-ultra.', + ), + finetune_id: z.string().optional().describe('ID of the finetuned model to use'), + finetune_strength: z + .number() + .optional() + .default(1.1) + .describe('Strength of the finetuning effect (typically between 0.1 and 1.2)'), + guidance: z.number().optional().default(2.5).describe('Guidance scale for finetuned models'), + aspect_ratio: z + .string() + .optional() + .default('16:9') + .describe('Aspect ratio for ultra models (e.g., "16:9")'), + }); + } + + getAxiosConfig() { + const config = {}; + if (process.env.PROXY) { + config.httpsAgent = new HttpsProxyAgent(process.env.PROXY); + } + return config; + } + + /** @param {Object|string} value */ + getDetails(value) { + if (typeof value === 'string') { + return value; + } + return JSON.stringify(value, null, 2); + } + + getApiKey() { + const apiKey = process.env.FLUX_API_KEY || ''; + if (!apiKey && !this.override) { + throw new Error('Missing FLUX_API_KEY environment variable.'); + } + return apiKey; + } + + wrapInMarkdown(imageUrl) { + const serverDomain = process.env.DOMAIN_SERVER || 'http://localhost:3080'; + return `![generated image](${serverDomain}${imageUrl})`; + } + + returnValue(value) { + if (this.isAgent === true && typeof value === 'string') { + return [value, {}]; + } else if (this.isAgent === true && typeof value === 'object') { + if (Array.isArray(value)) { + return value; + } + return [displayMessage, value]; + } + return value; + } + + async _call(data) { + const { action = 'generate', ...imageData } = data; + + // Use provided API key for this request if available, otherwise use default + const requestApiKey = this.apiKey || this.getApiKey(); + + // Handle list_finetunes action + if (action === 'list_finetunes') { + return this.getMyFinetunes(requestApiKey); + } + + // Handle finetuned generation + if (action === 'generate_finetuned') { + return this.generateFinetunedImage(imageData, requestApiKey); + } + + // For generate action, ensure prompt is provided + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${imageData.endpoint || '/v1/flux-pro'}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting task:', details); + + return this.returnValue( + `Something went wrong when trying to generate the image. 
The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in task:', resultResponse.data); + return this.returnValue('An error occurred during image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting result:', details); + return this.returnValue('An error occurred while retrieving the image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + if (this.isAgent) { + try { + // Fetch the image and convert to base64 + const fetchOptions = {}; + if (process.env.PROXY) { + fetchOptions.agent = new HttpsProxyAgent(process.env.PROXY); + } + const imageResponse = await fetch(imageUrl, fetchOptions); + const arrayBuffer = await imageResponse.arrayBuffer(); + const base64 = Buffer.from(arrayBuffer).toString('base64'); + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/png;base64,${base64}`, + }, + }, + ]; + + const response = [ + { + type: ContentTypes.TEXT, + text: displayMessage, + }, + ]; + return [response, { content }]; + } catch (error) { + logger.error('Error processing image for agent:', error); + return this.returnValue(`Failed to process the image. ${error.message}`); + } + } + + try { + logger.debug('[FluxAPI] Saving image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Image saved to path:', result.filepath); + + // Calculate cost based on endpoint + /** + * TODO: Cost handling + const endpoint = imageData.endpoint || '/v1/flux-pro'; + const endpointKey = Object.entries(FluxAPI.PRICING).find(([key, _]) => + endpoint.includes(key.toLowerCase().replace(/_/g, '-')), + )?.[0]; + const cost = FluxAPI.PRICING[endpointKey] || 0; + */ + this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath); + return this.returnValue(this.result); + } catch (error) { + const details = this.getDetails(error?.message ?? 'No additional error details.'); + logger.error('Error while saving the image:', details); + return this.returnValue(`Failed to save the image locally. 
${details}`); + } + } + + async getMyFinetunes(apiKey = null) { + const finetunesUrl = `${this.baseUrl}/v1/my_finetunes`; + const detailsUrl = `${this.baseUrl}/v1/finetune_details`; + + try { + const headers = { + 'x-key': apiKey || this.getApiKey(), + 'Content-Type': 'application/json', + Accept: 'application/json', + }; + + // Get list of finetunes + const response = await axios.get(finetunesUrl, { + headers, + ...this.getAxiosConfig(), + }); + const finetunes = response.data.finetunes; + + // Fetch details for each finetune + const finetuneDetails = await Promise.all( + finetunes.map(async (finetuneId) => { + try { + const detailResponse = await axios.get(`${detailsUrl}?finetune_id=${finetuneId}`, { + headers, + ...this.getAxiosConfig(), + }); + return { + id: finetuneId, + ...detailResponse.data, + }; + } catch (error) { + logger.error(`[FluxAPI] Error fetching details for finetune ${finetuneId}:`, error); + return { + id: finetuneId, + error: 'Failed to fetch details', + }; + } + }), + ); + + if (this.isAgent) { + const formattedDetails = JSON.stringify(finetuneDetails, null, 2); + return [`Here are the available finetunes:\n${formattedDetails}`, null]; + } + return JSON.stringify(finetuneDetails); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetunes:', details); + const errorMsg = `Failed to get finetunes: ${details}`; + return this.isAgent ? this.returnValue([errorMsg, {}]) : new Error(errorMsg); + } + } + + async generateFinetunedImage(imageData, requestApiKey) { + if (!imageData.prompt) { + throw new Error('Missing required field: prompt'); + } + + if (!imageData.finetune_id) { + throw new Error( + 'Missing required field: finetune_id for finetuned generation. Please supply a finetune_id!', + ); + } + + // Validate endpoint is appropriate for finetuned generation + const validFinetunedEndpoints = ['/v1/flux-pro-finetuned', '/v1/flux-pro-1.1-ultra-finetuned']; + const endpoint = imageData.endpoint || '/v1/flux-pro-finetuned'; + + if (!validFinetunedEndpoints.includes(endpoint)) { + throw new Error( + `Invalid endpoint for finetuned generation. 
Must be one of: ${validFinetunedEndpoints.join(', ')}`, + ); + } + + let payload = { + prompt: imageData.prompt, + prompt_upsampling: imageData.prompt_upsampling || false, + safety_tolerance: imageData.safety_tolerance || 6, + output_format: imageData.output_format || 'png', + finetune_id: imageData.finetune_id, + finetune_strength: imageData.finetune_strength || 1.0, + guidance: imageData.guidance || 2.5, + }; + + // Add optional parameters if provided + if (imageData.width) { + payload.width = imageData.width; + } + if (imageData.height) { + payload.height = imageData.height; + } + if (imageData.steps) { + payload.steps = imageData.steps; + } + if (imageData.seed !== undefined) { + payload.seed = imageData.seed; + } + if (imageData.raw) { + payload.raw = imageData.raw; + } + + const generateUrl = `${this.baseUrl}${endpoint}`; + const resultUrl = `${this.baseUrl}/v1/get_result`; + + logger.debug('[FluxAPI] Generating finetuned image with payload:', payload); + logger.debug('[FluxAPI] Using endpoint:', generateUrl); + + let taskResponse; + try { + taskResponse = await axios.post(generateUrl, payload, { + headers: { + 'x-key': requestApiKey, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + ...this.getAxiosConfig(), + }); + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while submitting finetuned task:', details); + return this.returnValue( + `Something went wrong when trying to generate the finetuned image. The Flux API may be unavailable: + Error Message: ${details}`, + ); + } + + const taskId = taskResponse.data.id; + + // Polling for the result + let status = 'Pending'; + let resultData = null; + while (status !== 'Ready' && status !== 'Error') { + try { + // Wait 2 seconds between polls + await new Promise((resolve) => setTimeout(resolve, 2000)); + const resultResponse = await axios.get(resultUrl, { + headers: { + 'x-key': requestApiKey, + Accept: 'application/json', + }, + params: { id: taskId }, + ...this.getAxiosConfig(), + }); + status = resultResponse.data.status; + + if (status === 'Ready') { + resultData = resultResponse.data.result; + break; + } else if (status === 'Error') { + logger.error('[FluxAPI] Error in finetuned task:', resultResponse.data); + return this.returnValue('An error occurred during finetuned image generation.'); + } + } catch (error) { + const details = this.getDetails(error?.response?.data || error.message); + logger.error('[FluxAPI] Error while getting finetuned result:', details); + return this.returnValue('An error occurred while retrieving the finetuned image.'); + } + } + + // If no result data + if (!resultData || !resultData.sample) { + logger.error('[FluxAPI] No image data received from API. Response:', resultData); + return this.returnValue('No image data received from Flux API.'); + } + + // Try saving the image locally + const imageUrl = resultData.sample; + const imageName = `img-${uuidv4()}.png`; + + try { + logger.debug('[FluxAPI] Saving finetuned image:', imageUrl); + const result = await this.processFileURL({ + fileStrategy: this.fileStrategy, + userId: this.userId, + URL: imageUrl, + fileName: imageName, + basePath: 'images', + context: FileContext.image_generation, + }); + + logger.debug('[FluxAPI] Finetuned image saved to path:', result.filepath); + + // Calculate cost based on endpoint + const endpointKey = endpoint.includes('ultra') + ? 
'FLUX_PRO_1_1_ULTRA_FINETUNED'
+        : 'FLUX_PRO_FINETUNED';
+      const cost = FluxAPI.PRICING[endpointKey] || 0;
+      // Return the result based on returnMetadata flag
+      this.result = this.returnMetadata ? result : this.wrapInMarkdown(result.filepath);
+      return this.returnValue(this.result);
+    } catch (error) {
+      const details = this.getDetails(error?.message ?? 'No additional error details.');
+      logger.error('Error while saving the finetuned image:', details);
+      return this.returnValue(`Failed to save the finetuned image locally. ${details}`);
+    }
+  }
+}
+
+module.exports = FluxAPI;
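
A minimal, hypothetical harness may help reviewers trace the agent path of the FluxAPI tool above end to end. The constructor fields (`userId`, `isAgent`, `fileStrategy`), the `FLUX_API_KEY` requirement, and the `[textParts, { content }]` return shape come from the class as added in this diff; the user id, prompt, and logging are invented for illustration.

// Sketch only, not part of the patch. Assumes FLUX_API_KEY is set and that
// the module resolves via the project's `~` path alias.
const FluxAPI = require('~/app/clients/tools/structured/FluxAPI');

async function demoFlux() {
  const tool = new FluxAPI({
    userId: 'user-123', // hypothetical user id
    isAgent: true, // agent path returns [textParts, { content }]
    fileStrategy: 'local', // unused on the agent path, which inlines base64
  });

  // One call generates one image. Calling _call directly bypasses zod schema
  // parsing, so without an explicit endpoint it falls back to /v1/flux-pro.
  const [textParts, artifact] = await tool._call({
    action: 'generate',
    prompt:
      'A close-up photo of a sleek Siamese cat with piercing blue eyes, ' +
      'seated on a vintage leather armchair in warm afternoon light, ' +
      'shallow depth of field, professional photography style.',
    width: 1024,
    height: 768,
  });

  console.log(textParts[0].text); // the displayMessage shown to the model
  console.log(artifact.content[0].image_url.url.slice(0, 30)); // data:image/png;base64,...
}

demoFlux().catch(console.error);
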
diff --git a/api/app/clients/tools/structured/OpenAIImageTools.js b/api/app/clients/tools/structured/OpenAIImageTools.js
new file mode 100644
index 0000000000..85941a779a
--- /dev/null
+++ b/api/app/clients/tools/structured/OpenAIImageTools.js
@@ -0,0 +1,518 @@
+const { z } = require('zod');
+const axios = require('axios');
+const { v4 } = require('uuid');
+const OpenAI = require('openai');
+const FormData = require('form-data');
+const { tool } = require('@langchain/core/tools');
+const { HttpsProxyAgent } = require('https-proxy-agent');
+const { ContentTypes, EImageOutputType } = require('librechat-data-provider');
+const { getStrategyFunctions } = require('~/server/services/Files/strategies');
+const { logAxiosError, extractBaseURL } = require('~/utils');
+const { getFiles } = require('~/models/File');
+const { logger } = require('~/config');
+
+/** Default descriptions for image generation tool */
+const DEFAULT_IMAGE_GEN_DESCRIPTION = `
+Generates high-quality, original images based solely on text, not using any uploaded reference images.
+
+When to use \`image_gen_oai\`:
+- To create entirely new images from detailed text descriptions that do NOT reference any image files.
+
+When NOT to use \`image_gen_oai\`:
+- If the user has uploaded any images and requests modifications, enhancements, or remixing based on those uploads → use \`image_edit_oai\` instead.
+
+Generated image IDs will be returned in the response, so you can refer to them in future requests made to \`image_edit_oai\`.
+`.trim();
+
+/** Default description for image editing tool */
+const DEFAULT_IMAGE_EDIT_DESCRIPTION =
+  `Generates high-quality, original images based on text and one or more uploaded/referenced images.
+
+When to use \`image_edit_oai\`:
+- The user wants to modify, extend, or remix one **or more** uploaded images, either:
+  - Previously generated, or in the current request (both to be included in the \`image_ids\` array).
+- Always when the user refers to uploaded images for editing, enhancement, remixing, style transfer, or combining elements.
+- Any current or existing images are to be used as visual guides.
+- If there are any files in the current request, they are more likely than not expected as references for image edit requests.
+
+When NOT to use \`image_edit_oai\`:
+- Brand-new generations that do not rely on an existing image → use \`image_gen_oai\` instead.
+
+Both generated and referenced image IDs will be returned in the response, so you can refer to them in future requests made to \`image_edit_oai\`.
+`.trim();
+
+/** Default prompt descriptions */
+const DEFAULT_IMAGE_GEN_PROMPT_DESCRIPTION = `Describe the image you want in detail.
+  Be highly specific—break your idea into layers:
+  (1) main concept and subject,
+  (2) composition and position,
+  (3) lighting and mood,
+  (4) style, medium, or camera details,
+  (5) important features (age, expression, clothing, etc.),
+  (6) background.
+  Use positive, descriptive language and specify what should be included, not what to avoid.
+  List number and characteristics of people/objects, and mention style/technical requirements (e.g., "DSLR photo, 85mm lens, golden hour").
+  Do not reference any uploaded images—use for new image creation from text only.`;
+
+const DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION = `Describe the changes, enhancements, or new ideas to apply to the uploaded image(s).
+  Be highly specific—break your request into layers:
+  (1) main concept or transformation,
+  (2) specific edits/replacements or composition guidance,
+  (3) desired style, mood, or technique,
+  (4) features/items to keep, change, or add (such as objects, people, clothing, lighting, etc.).
+  Use positive, descriptive language and clarify what should be included or changed, not what to avoid.
+  Always base this prompt on the most recently uploaded reference images.`;
+
+const displayMessage =
+  'The tool displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
+
+/**
+ * Replaces unwanted characters from the input string
+ * @param {string} inputString - The input string to process
+ * @returns {string} - The processed string
+ */
+function replaceUnwantedChars(inputString) {
+  return inputString
+    .replace(/\r\n|\r|\n/g, ' ')
+    .replace(/"/g, '')
+    .trim();
+}
+
+function returnValue(value) {
+  if (typeof value === 'string') {
+    return [value, {}];
+  } else if (typeof value === 'object') {
+    if (Array.isArray(value)) {
+      return value;
+    }
+    return [displayMessage, value];
+  }
+  return value;
+}
+
+const getImageGenDescription = () => {
+  return process.env.IMAGE_GEN_OAI_DESCRIPTION || DEFAULT_IMAGE_GEN_DESCRIPTION;
+};
+
+const getImageEditDescription = () => {
+  return process.env.IMAGE_EDIT_OAI_DESCRIPTION || DEFAULT_IMAGE_EDIT_DESCRIPTION;
+};
+
+const getImageGenPromptDescription = () => {
+  return process.env.IMAGE_GEN_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_GEN_PROMPT_DESCRIPTION;
+};
+
+const getImageEditPromptDescription = () => {
+  return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
+};
+
+/**
+ * Creates OpenAI Image tools (generation and editing)
+ * @param {Object} fields - Configuration fields
+ * @param {ServerRequest} fields.req - The Express request object
+ * @param {boolean} fields.isAgent - Whether the tool is being used in an agent context
+ * @param {string} fields.IMAGE_GEN_OAI_API_KEY - The OpenAI API key
+ * @param {boolean} [fields.override] - Whether to override the API key check, necessary for app initialization
+ * @param {MongoFile[]} [fields.imageFiles] - The images to be used for editing
+ * @returns {Array} - Array of image tools
+ */
+function createOpenAIImageTools(fields = {}) {
+  /** @type {boolean} Used to initialize the Tool without necessary variables. */
+  const override = fields.override ?? false;
+  /** @type {boolean} */
+  if (!override && !fields.isAgent) {
+    throw new Error('This tool is only available for agents.');
+  }
+  const { req } = fields;
+  const imageOutputType = req?.app.locals.imageOutputType || EImageOutputType.PNG;
+  const appFileStrategy = req?.app.locals.fileStrategy;
+
+  const getApiKey = () => {
+    const apiKey = process.env.IMAGE_GEN_OAI_API_KEY ??
''; + if (!apiKey && !override) { + throw new Error('Missing IMAGE_GEN_OAI_API_KEY environment variable.'); + } + return apiKey; + }; + + let apiKey = fields.IMAGE_GEN_OAI_API_KEY ?? getApiKey(); + const closureConfig = { apiKey }; + + let baseURL = 'https://api.openai.com/v1/'; + if (!override && process.env.IMAGE_GEN_OAI_BASEURL) { + baseURL = extractBaseURL(process.env.IMAGE_GEN_OAI_BASEURL); + closureConfig.baseURL = baseURL; + } + + // Note: Azure may not yet support the latest image generation models + if ( + !override && + process.env.IMAGE_GEN_OAI_AZURE_API_VERSION && + process.env.IMAGE_GEN_OAI_BASEURL + ) { + baseURL = process.env.IMAGE_GEN_OAI_BASEURL; + closureConfig.baseURL = baseURL; + closureConfig.defaultQuery = { 'api-version': process.env.IMAGE_GEN_OAI_AZURE_API_VERSION }; + closureConfig.defaultHeaders = { + 'api-key': process.env.IMAGE_GEN_OAI_API_KEY, + 'Content-Type': 'application/json', + }; + closureConfig.apiKey = process.env.IMAGE_GEN_OAI_API_KEY; + } + + const imageFiles = fields.imageFiles ?? []; + + /** + * Image Generation Tool + */ + const imageGenTool = tool( + async ( + { + prompt, + background = 'auto', + n = 1, + output_compression = 100, + quality = 'auto', + size = 'auto', + }, + runnableConfig, + ) => { + if (!prompt) { + throw new Error('Missing required field: prompt'); + } + const clientConfig = { ...closureConfig }; + if (process.env.PROXY) { + clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY); + } + + /** @type {OpenAI} */ + const openai = new OpenAI(clientConfig); + let output_format = imageOutputType; + if ( + background === 'transparent' && + output_format !== EImageOutputType.PNG && + output_format !== EImageOutputType.WEBP + ) { + logger.warn( + '[ImageGenOAI] Transparent background requires PNG or WebP format, defaulting to PNG', + ); + output_format = EImageOutputType.PNG; + } + + let resp; + try { + const derivedSignal = runnableConfig?.signal + ? AbortSignal.any([runnableConfig.signal]) + : undefined; + resp = await openai.images.generate( + { + model: 'gpt-image-1', + prompt: replaceUnwantedChars(prompt), + n: Math.min(Math.max(1, n), 10), + background, + output_format, + output_compression: + output_format === EImageOutputType.WEBP || output_format === EImageOutputType.JPEG + ? output_compression + : undefined, + quality, + size, + }, + { + signal: derivedSignal, + }, + ); + } catch (error) { + const message = '[image_gen_oai] Problem generating the image:'; + logAxiosError({ error, message }); + return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable: +Error Message: ${error.message}`); + } + + if (!resp) { + return returnValue( + 'Something went wrong when trying to generate the image. The OpenAI API may be unavailable', + ); + } + + // For gpt-image-1, the response contains base64-encoded images + // TODO: handle cost in `resp.usage` + const base64Image = resp.data[0].b64_json; + + if (!base64Image) { + return returnValue( + 'No image data returned from OpenAI API. 
There may be a problem with the API or your configuration.',
+        );
+      }
+
+      const content = [
+        {
+          type: ContentTypes.IMAGE_URL,
+          image_url: {
+            url: `data:image/${output_format};base64,${base64Image}`,
+          },
+        },
+      ];
+
+      const file_ids = [v4()];
+      const response = [
+        {
+          type: ContentTypes.TEXT,
+          text: displayMessage + `\n\ngenerated_image_id: "${file_ids[0]}"`,
+        },
+      ];
+      return [response, { content, file_ids }];
+    },
+    {
+      name: 'image_gen_oai',
+      description: getImageGenDescription(),
+      schema: z.object({
+        prompt: z.string().max(32000).describe(getImageGenPromptDescription()),
+        background: z
+          .enum(['transparent', 'opaque', 'auto'])
+          .optional()
+          .describe(
+            'Sets transparency for the background. Must be one of transparent, opaque or auto (default). When transparent, the output format should be png or webp.',
+          ),
+        /*
+        n: z
+          .number()
+          .int()
+          .min(1)
+          .max(10)
+          .optional()
+          .describe('The number of images to generate. Must be between 1 and 10.'),
+        output_compression: z
+          .number()
+          .int()
+          .min(0)
+          .max(100)
+          .optional()
+          .describe('The compression level (0-100%) for webp or jpeg formats. Defaults to 100.'),
+        */
+        quality: z
+          .enum(['auto', 'high', 'medium', 'low'])
+          .optional()
+          .describe('The quality of the image. One of auto (default), high, medium, or low.'),
+        size: z
+          .enum(['auto', '1024x1024', '1536x1024', '1024x1536'])
+          .optional()
+          .describe(
+            'The size of the generated image. One of 1024x1024, 1536x1024 (landscape), 1024x1536 (portrait), or auto (default).',
+          ),
+      }),
+      responseFormat: 'content_and_artifact',
+    },
+  );
+
+  /**
+   * Image Editing Tool
+   */
+  const imageEditTool = tool(
+    async ({ prompt, image_ids, quality = 'auto', size = 'auto' }, runnableConfig) => {
+      if (!prompt) {
+        throw new Error('Missing required field: prompt');
+      }
+
+      const clientConfig = { ...closureConfig };
+      if (process.env.PROXY) {
+        clientConfig.httpAgent = new HttpsProxyAgent(process.env.PROXY);
+      }
+
+      const formData = new FormData();
+      formData.append('model', 'gpt-image-1');
+      formData.append('prompt', replaceUnwantedChars(prompt));
+      // TODO: `mask` support
+      // TODO: more than 1 image support
+      // formData.append('n', n.toString());
+      formData.append('quality', quality);
+      formData.append('size', size);
+
+      /** @type {Record<FileSources, NodeStreamDownloader>} */
+      const streamMethods = {};
+
+      const requestFilesMap = Object.fromEntries(imageFiles.map((f) => [f.file_id, { ...f }]));
+
+      const orderedFiles = new Array(image_ids.length);
+      const idsToFetch = [];
+      const indexOfMissing = Object.create(null);
+
+      for (let i = 0; i < image_ids.length; i++) {
+        const id = image_ids[i];
+        const file = requestFilesMap[id];
+
+        if (file) {
+          orderedFiles[i] = file;
+        } else {
+          idsToFetch.push(id);
+          indexOfMissing[id] = i;
+        }
+      }
+
+      if (idsToFetch.length) {
+        const fetchedFiles = await getFiles(
+          {
+            user: req.user.id,
+            file_id: { $in: idsToFetch },
+            height: { $exists: true },
+            width: { $exists: true },
+          },
+          {},
+          {},
+        );
+
+        for (const file of fetchedFiles) {
+          requestFilesMap[file.file_id] = file;
+          orderedFiles[indexOfMissing[file.file_id]] = file;
+        }
+      }
+      for (const imageFile of orderedFiles) {
+        if (!imageFile) {
+          continue;
+        }
+        /** @type {NodeStream} */
+        let stream;
+        /** @type {NodeStreamDownloader} */
+        let getDownloadStream;
+        const source = imageFile.source || appFileStrategy;
+        if (!source) {
+          throw new Error('No source found for image file');
+        }
+        if (streamMethods[source]) {
+          getDownloadStream = streamMethods[source];
+        } else {
+          ({ getDownloadStream } =
getStrategyFunctions(source)); + streamMethods[source] = getDownloadStream; + } + if (!getDownloadStream) { + throw new Error(`No download stream method found for source: ${source}`); + } + stream = await getDownloadStream(req, imageFile.filepath); + if (!stream) { + throw new Error('Failed to get download stream for image file'); + } + formData.append('image[]', stream, { + filename: imageFile.filename, + contentType: imageFile.type, + }); + } + + /** @type {import('axios').RawAxiosHeaders} */ + let headers = { + ...formData.getHeaders(), + }; + + if (process.env.IMAGE_GEN_OAI_AZURE_API_VERSION && process.env.IMAGE_GEN_OAI_BASEURL) { + headers['api-key'] = apiKey; + } else { + headers['Authorization'] = `Bearer ${apiKey}`; + } + + try { + const derivedSignal = runnableConfig?.signal + ? AbortSignal.any([runnableConfig.signal]) + : undefined; + + /** @type {import('axios').AxiosRequestConfig} */ + const axiosConfig = { + headers, + ...clientConfig, + signal: derivedSignal, + baseURL, + }; + + if (process.env.IMAGE_GEN_OAI_AZURE_API_VERSION && process.env.IMAGE_GEN_OAI_BASEURL) { + axiosConfig.params = { + 'api-version': process.env.IMAGE_GEN_OAI_AZURE_API_VERSION, + ...axiosConfig.params, + }; + } + const response = await axios.post('/images/edits', formData, axiosConfig); + + if (!response.data || !response.data.data || !response.data.data.length) { + return returnValue( + 'No image data returned from OpenAI API. There may be a problem with the API or your configuration.', + ); + } + + const base64Image = response.data.data[0].b64_json; + if (!base64Image) { + return returnValue( + 'No image data returned from OpenAI API. There may be a problem with the API or your configuration.', + ); + } + + const content = [ + { + type: ContentTypes.IMAGE_URL, + image_url: { + url: `data:image/${imageOutputType};base64,${base64Image}`, + }, + }, + ]; + + const file_ids = [v4()]; + const textResponse = [ + { + type: ContentTypes.TEXT, + text: + displayMessage + + `\n\ngenerated_image_id: "${file_ids[0]}"\nreferenced_image_ids: ["${image_ids.join('", "')}"]`, + }, + ]; + return [textResponse, { content, file_ids }]; + } catch (error) { + const message = '[image_edit_oai] Problem editing the image:'; + logAxiosError({ error, message }); + return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable: +Error Message: ${error.message || 'Unknown error'}`); + } + }, + { + name: 'image_edit_oai', + description: getImageEditDescription(), + schema: z.object({ + image_ids: z + .array(z.string()) + .min(1) + .describe( + ` +IDs (image ID strings) of previously generated or uploaded images that should guide the edit. + +Guidelines: +- If the user's request depends on any prior image(s), copy their image IDs into the \`image_ids\` array (in the same order the user refers to them). +- Never invent or hallucinate IDs; only use IDs that are still visible in the conversation context. +- If no earlier image is relevant, omit the field entirely. +`.trim(), + ), + prompt: z.string().max(32000).describe(getImageEditPromptDescription()), + /* + n: z + .number() + .int() + .min(1) + .max(10) + .optional() + .describe('The number of images to generate. Must be between 1 and 10. Defaults to 1.'), + */ + quality: z + .enum(['auto', 'high', 'medium', 'low']) + .optional() + .describe( + 'The quality of the image. One of auto (default), high, medium, or low. 
High/medium/low only supported for gpt-image-1.',
+          ),
+        size: z
+          .enum(['auto', '1024x1024', '1536x1024', '1024x1536', '256x256', '512x512'])
+          .optional()
+          .describe(
+            'The size of the generated images. For gpt-image-1: auto (default), 1024x1024, 1536x1024, 1024x1536. For dall-e-2: 256x256, 512x512, 1024x1024.',
+          ),
+      }),
+      responseFormat: 'content_and_artifact',
+    },
+  );
+
+  return [imageGenTool, imageEditTool];
+}
+
+module.exports = createOpenAIImageTools;
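
The gen → edit hand-off above is easiest to see as a tool-call round trip. The following sketch is illustrative only: the tool names, the `{ content, file_ids }` artifact, and the `image_ids` argument come from the code in this diff, while the `req` stub, prompts, and call ids are invented. With @langchain/core's `tool()` helper and `responseFormat: 'content_and_artifact'`, invoking with a tool-call object returns a ToolMessage whose `artifact` carries the second tuple element.

// Sketch only, not part of the patch. Assumes IMAGE_GEN_OAI_API_KEY is set
// and that `req` is a real Express request (app.locals configured, req.user set).
const createOpenAIImageTools = require('~/app/clients/tools/structured/OpenAIImageTools');

async function demoImageTools(req) {
  const [imageGen, imageEdit] = createOpenAIImageTools({ req, isAgent: true });

  // Generate from text only; the ToolMessage artifact carries the base64
  // content part plus the generated file_ids.
  const genMessage = await imageGen.invoke({
    type: 'tool_call',
    id: 'call_1', // hypothetical tool-call id
    name: 'image_gen_oai',
    args: {
      prompt: 'A watercolor lighthouse on a rocky coast at dusk, soft warm light.',
      size: '1024x1024',
    },
  });
  const generatedId = genMessage.artifact.file_ids[0];

  // A follow-up edit references that image by id; in the running app, the
  // agent layer persists the generated file so the id can be resolved here.
  const editMessage = await imageEdit.invoke({
    type: 'tool_call',
    id: 'call_2',
    name: 'image_edit_oai',
    args: {
      prompt: 'Add a small sailboat in the foreground, keeping the watercolor style.',
      image_ids: [generatedId],
    },
  });
  return editMessage.artifact.file_ids;
}
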
diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js
index 6309da35d8..25a9e0abd3 100644
--- a/api/app/clients/tools/structured/StableDiffusion.js
+++ b/api/app/clients/tools/structured/StableDiffusion.js
@@ -6,10 +6,13 @@ const axios = require('axios');
 const sharp = require('sharp');
 const { v4: uuidv4 } = require('uuid');
 const { Tool } = require('@langchain/core/tools');
-const { FileContext } = require('librechat-data-provider');
+const { FileContext, ContentTypes } = require('librechat-data-provider');
 const paths = require('~/config/paths');
 const { logger } = require('~/config');
 
+const displayMessage =
+  'Stable Diffusion displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.';
+
 class StableDiffusionAPI extends Tool {
   constructor(fields) {
     super();
@@ -21,6 +24,8 @@ class StableDiffusionAPI extends Tool {
     this.override = fields.override ?? false;
     /** @type {boolean} Necessary for output to contain all image metadata. */
     this.returnMetadata = fields.returnMetadata ?? false;
+    /** @type {boolean} */
+    this.isAgent = fields.isAgent;
     if (fields.uploadImageBuffer) {
       /** @type {uploadImageBuffer} Necessary for output to contain all image metadata. */
       this.uploadImageBuffer = fields.uploadImageBuffer.bind(this);
@@ -66,6 +71,16 @@ class StableDiffusionAPI extends Tool {
     return `![generated image](/${imageUrl})`;
   }
 
+  returnValue(value) {
+    if (this.isAgent === true && typeof value === 'string') {
+      return [value, {}];
+    } else if (this.isAgent === true && typeof value === 'object') {
+      return [displayMessage, value];
+    }
+
+    return value;
+  }
+
   getServerURL() {
     const url = process.env.SD_WEBUI_URL || '';
     if (!url && !this.override) {
@@ -113,6 +128,25 @@
     }
 
     try {
+      if (this.isAgent) {
+        const content = [
+          {
+            type: ContentTypes.IMAGE_URL,
+            image_url: {
+              url: `data:image/png;base64,${image}`,
+            },
+          },
+        ];
+
+        const response = [
+          {
+            type: ContentTypes.TEXT,
+            text: displayMessage,
+          },
+        ];
+        return [response, { content }];
+      }
+
       const buffer = Buffer.from(image.split(',', 1)[0], 'base64');
       if (this.returnMetadata && this.uploadImageBuffer && this.req) {
         const file = await this.uploadImageBuffer({
@@ -154,7 +188,7 @@
       logger.error('[StableDiffusion] Error while saving the image:', error);
     }
 
-    return this.result;
+    return this.returnValue(this.result);
   }
 }
diff --git a/api/app/clients/tools/structured/TavilySearchResults.js b/api/app/clients/tools/structured/TavilySearchResults.js
index 9a62053ff0..9461293371 100644
--- a/api/app/clients/tools/structured/TavilySearchResults.js
+++ b/api/app/clients/tools/structured/TavilySearchResults.js
@@ -43,9 +43,39 @@ class TavilySearchResults extends Tool {
         .boolean()
         .optional()
         .describe('Whether to include answers in the search results. Default is False.'),
-      // include_raw_content: z.boolean().optional().describe('Whether to include raw content in the search results. Default is False.'),
-      // include_domains: z.array(z.string()).optional().describe('A list of domains to specifically include in the search results.'),
-      // exclude_domains: z.array(z.string()).optional().describe('A list of domains to specifically exclude from the search results.'),
+      include_raw_content: z
+        .boolean()
+        .optional()
+        .describe('Whether to include raw content in the search results. Default is False.'),
+      include_domains: z
+        .array(z.string())
+        .optional()
+        .describe('A list of domains to specifically include in the search results.'),
+      exclude_domains: z
+        .array(z.string())
+        .optional()
+        .describe('A list of domains to specifically exclude from the search results.'),
+      topic: z
+        .enum(['general', 'news', 'finance'])
+        .optional()
+        .describe(
+          'The category of the search. Use news ONLY if query SPECIFICALLY mentions the word "news".',
+        ),
+      time_range: z
+        .enum(['day', 'week', 'month', 'year', 'd', 'w', 'm', 'y'])
+        .optional()
+        .describe('The time range back from the current date to filter results.'),
+      days: z
+        .number()
+        .min(1)
+        .optional()
+        .describe('Number of days back from the current date to include. Only if topic is news.'),
+      include_image_descriptions: z
+        .boolean()
+        .optional()
+        .describe(
+          'When include_images is true, also add a descriptive text for each image. 
Default is false.', + ), }); } diff --git a/api/app/clients/tools/util/addOpenAPISpecs.js b/api/app/clients/tools/util/addOpenAPISpecs.js deleted file mode 100644 index 8b87be9941..0000000000 --- a/api/app/clients/tools/util/addOpenAPISpecs.js +++ /dev/null @@ -1,30 +0,0 @@ -const { loadSpecs } = require('./loadSpecs'); - -function transformSpec(input) { - return { - name: input.name_for_human, - pluginKey: input.name_for_model, - description: input.description_for_human, - icon: input?.logo_url ?? 'https://placehold.co/70x70.png', - // TODO: add support for authentication - isAuthRequired: 'false', - authConfig: [], - }; -} - -async function addOpenAPISpecs(availableTools) { - try { - const specs = (await loadSpecs({})).map(transformSpec); - if (specs.length > 0) { - return [...specs, ...availableTools]; - } - return availableTools; - } catch (error) { - return availableTools; - } -} - -module.exports = { - transformSpec, - addOpenAPISpecs, -}; diff --git a/api/app/clients/tools/util/addOpenAPISpecs.spec.js b/api/app/clients/tools/util/addOpenAPISpecs.spec.js deleted file mode 100644 index 21ff4eb8cc..0000000000 --- a/api/app/clients/tools/util/addOpenAPISpecs.spec.js +++ /dev/null @@ -1,76 +0,0 @@ -const { addOpenAPISpecs, transformSpec } = require('./addOpenAPISpecs'); -const { loadSpecs } = require('./loadSpecs'); -const { createOpenAPIPlugin } = require('../dynamic/OpenAPIPlugin'); - -jest.mock('./loadSpecs'); -jest.mock('../dynamic/OpenAPIPlugin'); - -describe('transformSpec', () => { - it('should transform input spec to a desired format', () => { - const input = { - name_for_human: 'Human Name', - name_for_model: 'Model Name', - description_for_human: 'Human Description', - logo_url: 'https://example.com/logo.png', - }; - - const expectedOutput = { - name: 'Human Name', - pluginKey: 'Model Name', - description: 'Human Description', - icon: 'https://example.com/logo.png', - isAuthRequired: 'false', - authConfig: [], - }; - - expect(transformSpec(input)).toEqual(expectedOutput); - }); - - it('should use default icon if logo_url is not provided', () => { - const input = { - name_for_human: 'Human Name', - name_for_model: 'Model Name', - description_for_human: 'Human Description', - }; - - const expectedOutput = { - name: 'Human Name', - pluginKey: 'Model Name', - description: 'Human Description', - icon: 'https://placehold.co/70x70.png', - isAuthRequired: 'false', - authConfig: [], - }; - - expect(transformSpec(input)).toEqual(expectedOutput); - }); -}); - -describe('addOpenAPISpecs', () => { - it('should add specs to available tools', async () => { - const availableTools = ['Tool1', 'Tool2']; - const specs = [ - { - name_for_human: 'Human Name', - name_for_model: 'Model Name', - description_for_human: 'Human Description', - logo_url: 'https://example.com/logo.png', - }, - ]; - - loadSpecs.mockResolvedValue(specs); - createOpenAPIPlugin.mockReturnValue('Plugin'); - - const result = await addOpenAPISpecs(availableTools); - expect(result).toEqual([...specs.map(transformSpec), ...availableTools]); - }); - - it('should return available tools if specs loading fails', async () => { - const availableTools = ['Tool1', 'Tool2']; - - loadSpecs.mockRejectedValue(new Error('Failed to load specs')); - - const result = await addOpenAPISpecs(availableTools); - expect(result).toEqual(availableTools); - }); -}); diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 23ba58bb5a..54da483362 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ 
b/api/app/clients/tools/util/fileSearch.js
@@ -106,18 +106,21 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
 
     const formattedResults = validResults
       .flatMap((result) =>
-        result.data.map(([docInfo, relevanceScore]) => ({
+        result.data.map(([docInfo, distance]) => ({
           filename: docInfo.metadata.source.split('/').pop(),
           content: docInfo.page_content,
-          relevanceScore,
+          distance,
         })),
       )
-      .sort((a, b) => b.relevanceScore - a.relevanceScore);
+      // TODO: results should be sorted by relevance, not distance
+      .sort((a, b) => a.distance - b.distance)
+      // TODO: make this configurable
+      .slice(0, 10);
 
     const formattedString = formattedResults
       .map(
         (result) =>
-          `File: ${result.filename}\nRelevance: ${result.relevanceScore.toFixed(4)}\nContent: ${
+          `File: ${result.filename}\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${
            result.content
          }\n`,
       )
diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js
index f1dfa24a49..e480dd4928 100644
--- a/api/app/clients/tools/util/handleTools.js
+++ b/api/app/clients/tools/util/handleTools.js
@@ -1,7 +1,7 @@
-const { Tools, Constants } = require('librechat-data-provider');
 const { SerpAPI } = require('@langchain/community/tools/serpapi');
 const { Calculator } = require('@langchain/community/tools/calculator');
 const { createCodeExecutionTool, EnvVar } = require('@librechat/agents');
+const { Tools, Constants, EToolResources } = require('librechat-data-provider');
 const { getUserPluginAuthValue } = require('~/server/services/PluginService');
 const {
   availableTools,
@@ -10,6 +10,7 @@ const {
   GoogleSearchAPI,
   // Structured Tools
   DALLE3,
+  FluxAPI,
   OpenWeather,
   StructuredSD,
   StructuredACS,
@@ -17,11 +18,12 @@ const {
   StructuredWolfram,
   createYouTubeTools,
   TavilySearchResults,
+  createOpenAIImageTools,
 } = require('../');
 const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
 const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
+const { loadAuthValues } = require('~/server/services/Tools/credentials');
 const { createMCPTool } = require('~/server/services/MCP');
-const { loadSpecs } = require('./loadSpecs');
 const { logger } = require('~/config');
 
 const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`);
@@ -89,45 +91,6 @@ const validateTools = async (user, tools = []) => {
   }
 };
 
-const loadAuthValues = async ({ userId, authFields, throwError = true }) => {
-  let authValues = {};
-
-  /**
-   * Finds the first non-empty value for the given authentication field, supporting alternate fields.
-   * @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||".
-   * @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found.
-   */
-  const findAuthValue = async (fields) => {
-    for (const field of fields) {
-      let value = process.env[field];
-      if (value) {
-        return { authField: field, authValue: value };
-      }
-      try {
-        value = await getUserPluginAuthValue(userId, field, throwError);
-      } catch (err) {
-        if (field === fields[fields.length - 1] && !value) {
-          throw err;
-        }
-      }
-      if (value) {
-        return { authField: field, authValue: value };
-      }
-    }
-    return null;
-  };
-
-  for (let authField of authFields) {
-    const fields = authField.split('||');
-    const result = await findAuthValue(fields);
-    if (result) {
-      authValues[result.authField] = result.authValue;
-    }
-  }
-
-  return authValues;
-};
-
 /** @typedef {typeof import('@langchain/core/tools').Tool} ToolConstructor */
 /** @typedef {import('@langchain/core/tools').Tool} Tool */
 
@@ -160,7 +123,7 @@ const getAuthFields = (toolKey) => {
  *
  * @param {object} object
  * @param {string} object.user
- * @param {Agent} [object.agent]
+ * @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent]
  * @param {string} [object.model]
  * @param {EModelEndpoint} [object.endpoint]
  * @param {LoadToolOptions} [object.options]
@@ -182,6 +145,7 @@ const loadTools = async ({
   returnMap = false,
 }) => {
   const toolConstructors = {
+    flux: FluxAPI,
     calculator: Calculator,
     google: GoogleSearchAPI,
     open_weather: OpenWeather,
@@ -193,7 +157,7 @@
   };
 
   const customConstructors = {
-    serpapi: async () => {
+    serpapi: async (_toolContextMap) => {
       const authFields = getAuthFields('serpapi');
       let envVar = authFields[0] ?? '';
       let apiKey = process.env[envVar];
@@ -206,11 +170,40 @@
         gl: 'us',
       });
     },
-    youtube: async () => {
+    youtube: async (_toolContextMap) => {
       const authFields = getAuthFields('youtube');
       const authValues = await loadAuthValues({ userId: user, authFields });
       return createYouTubeTools(authValues);
     },
+    image_gen_oai: async (toolContextMap) => {
+      const authFields = getAuthFields('image_gen_oai');
+      const authValues = await loadAuthValues({ userId: user, authFields });
+      const imageFiles = options.tool_resources?.[EToolResources.image_edit]?.files ?? [];
+      let toolContext = '';
+      for (let i = 0; i < imageFiles.length; i++) {
+        const file = imageFiles[i];
+        if (!file) {
+          continue;
+        }
+        if (i === 0) {
+          toolContext =
+            'Image files provided in this request (their image IDs listed in order of appearance) available for image editing:';
+        }
+        toolContext += `\n\t- ${file.file_id}`;
+        if (i === imageFiles.length - 1) {
+          toolContext += `\n\nInclude any you need in the \`image_ids\` array when calling \`${EToolResources.image_edit}_oai\`. You may also include previously referenced or generated image IDs.`;
+        }
+      }
+      if (toolContext) {
+        toolContextMap.image_edit_oai = toolContext;
+      }
+      return createOpenAIImageTools({
+        ...authValues,
+        isAgent: !!agent,
+        req: options.req,
+        imageFiles,
+      });
+    },
   };
 
   const requestedTools = {};
@@ -230,13 +223,14 @@
   };
 
   const toolOptions = {
-    serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' },
+    flux: imageGenOptions,
     dalle: imageGenOptions,
     'stable-diffusion': imageGenOptions,
+    serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' },
   };
 
+  /** @type {Record<string, string>} */
   const toolContextMap = {};
-  const remainingTools = [];
 
   const appTools = options.req?.app?.locals?.availableTools ??
{}; for (const tool of tools) { @@ -281,7 +275,7 @@ const loadTools = async ({ } if (customConstructors[tool]) { - requestedTools[tool] = customConstructors[tool]; + requestedTools[tool] = async () => customConstructors[tool](toolContextMap); continue; } @@ -296,30 +290,6 @@ const loadTools = async ({ requestedTools[tool] = toolInstance; continue; } - - if (functions === true) { - remainingTools.push(tool); - } - } - - let specs = null; - if (useSpecs === true && functions === true && remainingTools.length > 0) { - specs = await loadSpecs({ - llm: model, - user, - message: options.message, - memory: options.memory, - signal: options.signal, - tools: remainingTools, - map: true, - verbose: false, - }); - } - - for (const tool of remainingTools) { - if (specs && specs[tool]) { - requestedTools[tool] = specs[tool]; - } } if (returnMap) { @@ -345,7 +315,6 @@ const loadTools = async ({ module.exports = { loadToolWithAuth, - loadAuthValues, validateTools, loadTools, }; diff --git a/api/app/clients/tools/util/index.js b/api/app/clients/tools/util/index.js index 73d10270b6..ea67bb4ced 100644 --- a/api/app/clients/tools/util/index.js +++ b/api/app/clients/tools/util/index.js @@ -1,9 +1,8 @@ -const { validateTools, loadTools, loadAuthValues } = require('./handleTools'); +const { validateTools, loadTools } = require('./handleTools'); const handleOpenAIErrors = require('./handleOpenAIErrors'); module.exports = { handleOpenAIErrors, - loadAuthValues, validateTools, loadTools, }; diff --git a/api/app/clients/tools/util/loadSpecs.js b/api/app/clients/tools/util/loadSpecs.js deleted file mode 100644 index e5b543132a..0000000000 --- a/api/app/clients/tools/util/loadSpecs.js +++ /dev/null @@ -1,117 +0,0 @@ -const fs = require('fs'); -const path = require('path'); -const { z } = require('zod'); -const { logger } = require('~/config'); -const { createOpenAPIPlugin } = require('~/app/clients/tools/dynamic/OpenAPIPlugin'); - -// The minimum Manifest definition -const ManifestDefinition = z.object({ - schema_version: z.string().optional(), - name_for_human: z.string(), - name_for_model: z.string(), - description_for_human: z.string(), - description_for_model: z.string(), - auth: z.object({}).optional(), - api: z.object({ - // Spec URL or can be the filename of the OpenAPI spec yaml file, - // located in api\app\clients\tools\.well-known\openapi - url: z.string(), - type: z.string().optional(), - is_user_authenticated: z.boolean().nullable().optional(), - has_user_authentication: z.boolean().nullable().optional(), - }), - // use to override any params that the LLM will consistently get wrong - params: z.object({}).optional(), - logo_url: z.string().optional(), - contact_email: z.string().optional(), - legal_info_url: z.string().optional(), -}); - -function validateJson(json) { - try { - return ManifestDefinition.parse(json); - } catch (error) { - logger.debug('[validateJson] manifest parsing error', error); - return false; - } -} - -// omit the LLM to return the well known jsons as objects -async function loadSpecs({ llm, user, message, tools = [], map = false, memory, signal }) { - const directoryPath = path.join(__dirname, '..', '.well-known'); - let files = []; - - for (let i = 0; i < tools.length; i++) { - const filePath = path.join(directoryPath, tools[i] + '.json'); - - try { - // If the access Promise is resolved, it means that the file exists - // Then we can add it to the files array - await fs.promises.access(filePath, fs.constants.F_OK); - files.push(tools[i] + '.json'); - } catch (err) { - 
logger.error(`[loadSpecs] File ${tools[i] + '.json'} does not exist`, err); - } - } - - if (files.length === 0) { - files = (await fs.promises.readdir(directoryPath)).filter( - (file) => path.extname(file) === '.json', - ); - } - - const validJsons = []; - const constructorMap = {}; - - logger.debug('[validateJson] files', files); - - for (const file of files) { - if (path.extname(file) === '.json') { - const filePath = path.join(directoryPath, file); - const fileContent = await fs.promises.readFile(filePath, 'utf8'); - const json = JSON.parse(fileContent); - - if (!validateJson(json)) { - logger.debug('[validateJson] Invalid json', json); - continue; - } - - if (llm && map) { - constructorMap[json.name_for_model] = async () => - await createOpenAPIPlugin({ - data: json, - llm, - message, - memory, - signal, - user, - }); - continue; - } - - if (llm) { - validJsons.push(createOpenAPIPlugin({ data: json, llm })); - continue; - } - - validJsons.push(json); - } - } - - if (map) { - return constructorMap; - } - - const plugins = (await Promise.all(validJsons)).filter((plugin) => plugin); - - // logger.debug('[validateJson] plugins', plugins); - // logger.debug(plugins[0].name); - - return plugins; -} - -module.exports = { - loadSpecs, - validateJson, - ManifestDefinition, -}; diff --git a/api/app/clients/tools/util/loadSpecs.spec.js b/api/app/clients/tools/util/loadSpecs.spec.js deleted file mode 100644 index 7b906d86f0..0000000000 --- a/api/app/clients/tools/util/loadSpecs.spec.js +++ /dev/null @@ -1,101 +0,0 @@ -const fs = require('fs'); -const { validateJson, loadSpecs, ManifestDefinition } = require('./loadSpecs'); -const { createOpenAPIPlugin } = require('../dynamic/OpenAPIPlugin'); - -jest.mock('../dynamic/OpenAPIPlugin'); - -describe('ManifestDefinition', () => { - it('should validate correct json', () => { - const json = { - name_for_human: 'Test', - name_for_model: 'Test', - description_for_human: 'Test', - description_for_model: 'Test', - api: { - url: 'http://test.com', - }, - }; - - expect(() => ManifestDefinition.parse(json)).not.toThrow(); - }); - - it('should not validate incorrect json', () => { - const json = { - name_for_human: 'Test', - name_for_model: 'Test', - description_for_human: 'Test', - description_for_model: 'Test', - api: { - url: 123, // incorrect type - }, - }; - - expect(() => ManifestDefinition.parse(json)).toThrow(); - }); -}); - -describe('validateJson', () => { - it('should return parsed json if valid', () => { - const json = { - name_for_human: 'Test', - name_for_model: 'Test', - description_for_human: 'Test', - description_for_model: 'Test', - api: { - url: 'http://test.com', - }, - }; - - expect(validateJson(json)).toEqual(json); - }); - - it('should return false if json is not valid', () => { - const json = { - name_for_human: 'Test', - name_for_model: 'Test', - description_for_human: 'Test', - description_for_model: 'Test', - api: { - url: 123, // incorrect type - }, - }; - - expect(validateJson(json)).toEqual(false); - }); -}); - -describe('loadSpecs', () => { - beforeEach(() => { - jest.spyOn(fs.promises, 'readdir').mockResolvedValue(['test.json']); - jest.spyOn(fs.promises, 'readFile').mockResolvedValue( - JSON.stringify({ - name_for_human: 'Test', - name_for_model: 'Test', - description_for_human: 'Test', - description_for_model: 'Test', - api: { - url: 'http://test.com', - }, - }), - ); - createOpenAPIPlugin.mockResolvedValue({}); - }); - - afterEach(() => { - jest.restoreAllMocks(); - }); - - it('should return plugins', async () => { - const 
plugins = await loadSpecs({ llm: true, verbose: false }); - - expect(plugins).toHaveLength(1); - expect(createOpenAPIPlugin).toHaveBeenCalledTimes(1); - }); - - it('should return constructorMap if map is true', async () => { - const plugins = await loadSpecs({ llm: {}, map: true, verbose: false }); - - expect(plugins).toHaveProperty('Test'); - expect(createOpenAPIPlugin).not.toHaveBeenCalled(); - }); -}); diff --git a/api/cache/clearPendingReq.js b/api/cache/clearPendingReq.js index 122638d7f9..54db8e9690 100644 --- a/api/cache/clearPendingReq.js +++ b/api/cache/clearPendingReq.js @@ -1,7 +1,8 @@ +const { Time, CacheKeys } = require('librechat-data-provider'); +const { isEnabled } = require('~/server/utils'); const getLogStores = require('./getLogStores'); -const { isEnabled } = require('../server/utils'); + const { USE_REDIS, LIMIT_CONCURRENT_MESSAGES } = process.env ?? {}; -const ttl = 1000 * 60 * 1; /** * Clear or decrement pending requests from the cache. @@ -28,7 +29,7 @@ const clearPendingReq = async ({ userId, cache: _cache }) => { return; } - const namespace = 'pending_req'; + const namespace = CacheKeys.PENDING_REQ; const cache = _cache ?? getLogStores(namespace); if (!cache) { @@ -39,7 +40,7 @@ const clearPendingReq = async ({ userId, cache: _cache }) => { const currentReq = +((await cache.get(key)) ?? 0); if (currentReq && currentReq >= 1) { - await cache.set(key, currentReq - 1, ttl); + await cache.set(key, currentReq - 1, Time.ONE_MINUTE); } else { await cache.delete(key); } diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 6592371f02..612638b97b 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -1,4 +1,4 @@ -const Keyv = require('keyv'); +const { Keyv } = require('keyv'); const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider'); const { logFile, violationFile } = require('./keyvFiles'); const { math, isEnabled } = require('~/server/utils'); @@ -19,7 +19,7 @@ const createViolationInstance = (namespace) => { // Serve cache from memory so no need to clear it on startup/exit const pending_req = isRedisEnabled ? new Keyv({ store: keyvRedis }) - : new Keyv({ namespace: 'pending_req' }); + : new Keyv({ namespace: CacheKeys.PENDING_REQ }); const config = isRedisEnabled ? new Keyv({ store: keyvRedis }) @@ -49,6 +49,10 @@ const genTitle = isRedisEnabled ? new Keyv({ store: keyvRedis, ttl: Time.TWO_MINUTES }) : new Keyv({ namespace: CacheKeys.GEN_TITLE, ttl: Time.TWO_MINUTES }); +const s3ExpiryInterval = isRedisEnabled + ? new Keyv({ store: keyvRedis, ttl: Time.THIRTY_MINUTES }) + : new Keyv({ namespace: CacheKeys.S3_EXPIRY_INTERVAL, ttl: Time.THIRTY_MINUTES }); + const modelQueries = isEnabled(process.env.USE_REDIS) ? 
new Keyv({ store: keyvRedis }) : new Keyv({ namespace: CacheKeys.MODEL_QUERIES }); @@ -60,7 +64,7 @@ const abortKeys = isRedisEnabled const namespaces = { [CacheKeys.ROLES]: roles, [CacheKeys.CONFIG_STORE]: config, - pending_req, + [CacheKeys.PENDING_REQ]: pending_req, [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }), [CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, @@ -89,6 +93,7 @@ const namespaces = { [CacheKeys.ABORT_KEYS]: abortKeys, [CacheKeys.TOKEN_CONFIG]: tokenConfig, [CacheKeys.GEN_TITLE]: genTitle, + [CacheKeys.S3_EXPIRY_INTERVAL]: s3ExpiryInterval, [CacheKeys.MODEL_QUERIES]: modelQueries, [CacheKeys.AUDIO_RUNS]: audioRuns, [CacheKeys.MESSAGES]: messages, diff --git a/api/cache/ioredisClient.js b/api/cache/ioredisClient.js new file mode 100644 index 0000000000..cd48459ab4 --- /dev/null +++ b/api/cache/ioredisClient.js @@ -0,0 +1,92 @@ +const fs = require('fs'); +const Redis = require('ioredis'); +const { isEnabled } = require('~/server/utils'); +const logger = require('~/config/winston'); + +const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_MAX_LISTENERS } = process.env; + +/** @type {import('ioredis').Redis | import('ioredis').Cluster} */ +let ioredisClient; +const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40; + +function mapURI(uri) { + const regex = + /^(?:(?\w+):\/\/)?(?:(?[^:@]+)(?::(?[^@]+))?@)?(?[\w.-]+)(?::(?\d{1,5}))?$/; + const match = uri.match(regex); + + if (match) { + const { scheme, user, password, host, port } = match.groups; + + return { + scheme: scheme || 'none', + user: user || null, + password: password || null, + host: host || null, + port: port || null, + }; + } else { + const parts = uri.split(':'); + if (parts.length === 2) { + return { + scheme: 'none', + user: null, + password: null, + host: parts[0], + port: parts[1], + }; + } + + return { + scheme: 'none', + user: null, + password: null, + host: uri, + port: null, + }; + } +} + +if (REDIS_URI && isEnabled(USE_REDIS)) { + let redisOptions = null; + + if (REDIS_CA) { + const ca = fs.readFileSync(REDIS_CA); + redisOptions = { tls: { ca } }; + } + + if (isEnabled(USE_REDIS_CLUSTER)) { + const hosts = REDIS_URI.split(',').map((item) => { + var value = mapURI(item); + + return { + host: value.host, + port: value.port, + }; + }); + ioredisClient = new Redis.Cluster(hosts, { redisOptions }); + } else { + ioredisClient = new Redis(REDIS_URI, redisOptions); + } + + ioredisClient.on('ready', () => { + logger.info('IoRedis connection ready'); + }); + ioredisClient.on('reconnecting', () => { + logger.info('IoRedis connection reconnecting'); + }); + ioredisClient.on('end', () => { + logger.info('IoRedis connection ended'); + }); + ioredisClient.on('close', () => { + logger.info('IoRedis connection closed'); + }); + ioredisClient.on('error', (err) => logger.error('IoRedis connection error:', err)); + ioredisClient.setMaxListeners(redis_max_listeners); + logger.info( + '[Optional] IoRedis initialized for rate limiters. 
If you have issues, disable Redis or restart the server.', + ); +} else { + logger.info('[Optional] IoRedis not initialized for rate limiters.'); +} + +module.exports = ioredisClient; diff --git a/api/cache/keyvFiles.js b/api/cache/keyvFiles.js index f969174b7d..1476b60cb8 100644 --- a/api/cache/keyvFiles.js +++ b/api/cache/keyvFiles.js @@ -1,11 +1,9 @@ const { KeyvFile } = require('keyv-file'); -const logFile = new KeyvFile({ filename: './data/logs.json' }); -const pendingReqFile = new KeyvFile({ filename: './data/pendingReqCache.json' }); -const violationFile = new KeyvFile({ filename: './data/violations.json' }); +const logFile = new KeyvFile({ filename: './data/logs.json' }).setMaxListeners(20); +const violationFile = new KeyvFile({ filename: './data/violations.json' }).setMaxListeners(20); module.exports = { logFile, - pendingReqFile, violationFile, }; diff --git a/api/cache/keyvMongo.js b/api/cache/keyvMongo.js index 8f5b9fd8d8..1606e98eb8 100644 --- a/api/cache/keyvMongo.js +++ b/api/cache/keyvMongo.js @@ -1,9 +1,272 @@ -const KeyvMongo = require('@keyv/mongo'); +// api/cache/keyvMongo.js +const mongoose = require('mongoose'); +const EventEmitter = require('events'); +const { GridFSBucket } = require('mongodb'); const { logger } = require('~/config'); -const { MONGO_URI } = process.env ?? {}; +const storeMap = new Map(); + +class KeyvMongoCustom extends EventEmitter { + constructor(url, options = {}) { + super(); + + url = url || {}; + if (typeof url === 'string') { + url = { url }; + } + if (url.uri) { + url = { url: url.uri, ...url }; + } + + this.opts = { + url: 'mongodb://127.0.0.1:27017', + collection: 'keyv', + ...url, + ...options, + }; + + this.ttlSupport = false; + + // Filter valid options + const keyvMongoKeys = new Set([ + 'url', + 'collection', + 'namespace', + 'serialize', + 'deserialize', + 'uri', + 'useGridFS', + 'dialect', + ]); + this.opts = Object.fromEntries(Object.entries(this.opts).filter(([k]) => keyvMongoKeys.has(k))); + } + + // Helper to access the store WITHOUT storing a promise on the instance + _getClient() { + const storeKey = `${this.opts.collection}:${this.opts.useGridFS ? 'gridfs' : 'collection'}`; + + // If we already have the store initialized, return it directly + if (storeMap.has(storeKey)) { + return Promise.resolve(storeMap.get(storeKey)); + } + + // Check mongoose connection state + if (mongoose.connection.readyState !== 1) { + return Promise.reject( + new Error('Mongoose connection not ready. 
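For cluster mode, ioredisClient splits REDIS_URI on commas and maps each entry to a { host, port } pair via mapURI. A rough equivalent using Node's built-in URL parser (the URIs below are example values; mapURI additionally tolerates bare host:port strings without a scheme):

const Redis = require('ioredis');

const REDIS_URI = 'redis://10.0.0.1:6379,redis://10.0.0.2:6379'; // example node list

const hosts = REDIS_URI.split(',').map((item) => {
  const url = new URL(item);
  return { host: url.hostname, port: Number(url.port) || 6379 };
});

// Per-node options (e.g. a TLS CA) are passed under `redisOptions`.
const cluster = new Redis.Cluster(hosts, { redisOptions: null });
cluster.on('error', (err) => console.error('cluster error:', err));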
Ensure connectDb() is called first.'), + ); + } + + try { + const db = mongoose.connection.db; + let client; + + if (this.opts.useGridFS) { + const bucket = new GridFSBucket(db, { + readPreference: this.opts.readPreference, + bucketName: this.opts.collection, + }); + const store = db.collection(`${this.opts.collection}.files`); + client = { bucket, store, db }; + } else { + const collection = this.opts.collection || 'keyv'; + const store = db.collection(collection); + client = { store, db }; + } + + storeMap.set(storeKey, client); + return Promise.resolve(client); + } catch (error) { + this.emit('error', error); + return Promise.reject(error); + } + } + + async get(key) { + const client = await this._getClient(); + + if (this.opts.useGridFS) { + await client.store.updateOne( + { + filename: key, + }, + { + $set: { + 'metadata.lastAccessed': new Date(), + }, + }, + ); + + const stream = client.bucket.openDownloadStreamByName(key); + + return new Promise((resolve) => { + const resp = []; + stream.on('error', () => { + resolve(undefined); + }); + + stream.on('end', () => { + const data = Buffer.concat(resp).toString('utf8'); + resolve(data); + }); + + stream.on('data', (chunk) => { + resp.push(chunk); + }); + }); + } + + const document = await client.store.findOne({ key: { $eq: key } }); + + if (!document) { + return undefined; + } + + return document.value; + } + + async getMany(keys) { + const client = await this._getClient(); + + if (this.opts.useGridFS) { + const promises = []; + for (const key of keys) { + promises.push(this.get(key)); + } + + const values = await Promise.allSettled(promises); + const data = []; + for (const value of values) { + data.push(value.value); + } + + return data; + } + + const values = await client.store + .find({ key: { $in: keys } }) + .project({ _id: 0, value: 1, key: 1 }) + .toArray(); + + const results = [...keys]; + let i = 0; + for (const key of keys) { + const rowIndex = values.findIndex((row) => row.key === key); + results[i] = rowIndex > -1 ? values[rowIndex].value : undefined; + i++; + } + + return results; + } + + async set(key, value, ttl) { + const client = await this._getClient(); + const expiresAt = typeof ttl === 'number' ? 
new Date(Date.now() + ttl) : null; + + if (this.opts.useGridFS) { + const stream = client.bucket.openUploadStream(key, { + metadata: { + expiresAt, + lastAccessed: new Date(), + }, + }); + + return new Promise((resolve) => { + stream.on('finish', () => { + resolve(stream); + }); + stream.end(value); + }); + } + + await client.store.updateOne( + { key: { $eq: key } }, + { $set: { key, value, expiresAt } }, + { upsert: true }, + ); + } + + async delete(key) { + if (typeof key !== 'string') { + return false; + } + + const client = await this._getClient(); + + if (this.opts.useGridFS) { + try { + const bucket = new GridFSBucket(client.db, { + bucketName: this.opts.collection, + }); + const files = await bucket.find({ filename: key }).toArray(); + await client.bucket.delete(files[0]._id); + return true; + } catch { + return false; + } + } + + const object = await client.store.deleteOne({ key: { $eq: key } }); + return object.deletedCount > 0; + } + + async deleteMany(keys) { + const client = await this._getClient(); + + if (this.opts.useGridFS) { + const bucket = new GridFSBucket(client.db, { + bucketName: this.opts.collection, + }); + const files = await bucket.find({ filename: { $in: keys } }).toArray(); + if (files.length === 0) { + return false; + } + + await Promise.all(files.map(async (file) => client.bucket.delete(file._id))); + return true; + } + + const object = await client.store.deleteMany({ key: { $in: keys } }); + return object.deletedCount > 0; + } + + async clear() { + const client = await this._getClient(); + + if (this.opts.useGridFS) { + try { + await client.bucket.drop(); + } catch (error) { + // Throw error if not "namespace not found" error + if (!(error.code === 26)) { + throw error; + } + } + } + + await client.store.deleteMany({ + key: { $regex: this.namespace ? `^${this.namespace}:*` : '' }, + }); + } + + async has(key) { + const client = await this._getClient(); + const filter = { [this.opts.useGridFS ? 
'filename' : 'key']: { $eq: key } }; + const document = await client.store.countDocuments(filter, { limit: 1 }); + return document !== 0; + } + + // No-op disconnect + async disconnect() { + // This is a no-op since we don't want to close the shared mongoose connection + return true; + } +} + +const keyvMongo = new KeyvMongoCustom({ + collection: 'logs', +}); -const keyvMongo = new KeyvMongo(MONGO_URI, { collection: 'logs' }); keyvMongo.on('error', (err) => logger.error('KeyvMongo connection error:', err)); module.exports = keyvMongo; diff --git a/api/cache/keyvRedis.js b/api/cache/keyvRedis.js index d544b50a11..cb9d837e21 100644 --- a/api/cache/keyvRedis.js +++ b/api/cache/keyvRedis.js @@ -1,20 +1,106 @@ -const KeyvRedis = require('@keyv/redis'); +const fs = require('fs'); +const ioredis = require('ioredis'); +const KeyvRedis = require('@keyv/redis').default; const { isEnabled } = require('~/server/utils'); const logger = require('~/config/winston'); -const { REDIS_URI, USE_REDIS } = process.env; +const { REDIS_URI, USE_REDIS, USE_REDIS_CLUSTER, REDIS_CA, REDIS_KEY_PREFIX, REDIS_MAX_LISTENERS } = + process.env; let keyvRedis; +const redis_prefix = REDIS_KEY_PREFIX || ''; +const redis_max_listeners = Number(REDIS_MAX_LISTENERS) || 40; + +function mapURI(uri) { + const regex = + /^(?:(?<scheme>\w+):\/\/)?(?:(?<user>[^:@]+)(?::(?<password>[^@]+))?@)?(?<host>[\w.-]+)(?::(?<port>\d{1,5}))?$/; + const match = uri.match(regex); + + if (match) { + const { scheme, user, password, host, port } = match.groups; + + return { + scheme: scheme || 'none', + user: user || null, + password: password || null, + host: host || null, + port: port || null, + }; + } else { + const parts = uri.split(':'); + if (parts.length === 2) { + return { + scheme: 'none', + user: null, + password: null, + host: parts[0], + port: parts[1], + }; + } + + return { + scheme: 'none', + user: null, + password: null, + host: uri, + port: null, + }; + } +} if (REDIS_URI && isEnabled(USE_REDIS)) { - keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets: false }); + let redisOptions = null; + /** @type {import('@keyv/redis').KeyvRedisOptions} */ + let keyvOpts = { + useRedisSets: false, + keyPrefix: redis_prefix, + }; + + if (REDIS_CA) { + const ca = fs.readFileSync(REDIS_CA); + redisOptions = { tls: { ca } }; + } + + if (isEnabled(USE_REDIS_CLUSTER)) { + const hosts = REDIS_URI.split(',').map((item) => { + var value = mapURI(item); + + return { + host: value.host, + port: value.port, + }; + }); + const cluster = new ioredis.Cluster(hosts, { redisOptions }); + keyvRedis = new KeyvRedis(cluster, keyvOpts); + } else { + keyvRedis = new KeyvRedis(REDIS_URI, keyvOpts); + } + + const pingInterval = setInterval(() => { + logger.debug('KeyvRedis ping'); + keyvRedis.client.ping().catch(err => logger.error('Redis keep-alive ping failed:', err)); + }, 5 * 60 * 1000); + + keyvRedis.on('ready', () => { + logger.info('KeyvRedis connection ready'); + }); + keyvRedis.on('reconnecting', () => { + logger.info('KeyvRedis connection reconnecting'); + }); + keyvRedis.on('end', () => { + logger.info('KeyvRedis connection ended'); + }); + keyvRedis.on('close', () => { + clearInterval(pingInterval); + logger.info('KeyvRedis connection closed'); + }); keyvRedis.on('error', (err) => logger.error('KeyvRedis connection error:', err)); - keyvRedis.setMaxListeners(20); + keyvRedis.setMaxListeners(redis_max_listeners); logger.info( - '[Optional] Redis initialized. Note: Redis support is experimental. If you have issues, disable it. 
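Consumers never touch the Redis client directly; they wrap the exported store in their own namespaced Keyv instance, as getLogStores does above. A usage sketch (the require path and namespace are illustrative):

const { Keyv } = require('keyv');
const keyvRedis = require('~/cache/keyvRedis');

// Namespacing and TTL live on the Keyv wrapper, not on the shared store.
const titleCache = new Keyv({ store: keyvRedis, namespace: 'gen_title', ttl: 2 * 60 * 1000 });

async function demo() {
  await titleCache.set('convo:123', 'Draft title');
  console.log(await titleCache.get('convo:123')); // 'Draft title'
}

demo().catch(console.error);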
Cache needs to be flushed for values to refresh.', + '[Optional] Redis initialized. If you have issues, or seeing older values, disable it or flush cache to refresh values.', ); } else { - logger.info('[Optional] Redis not initialized. Note: Redis support is experimental.'); + logger.info('[Optional] Redis not initialized.'); } module.exports = keyvRedis; diff --git a/api/cache/redis.js b/api/cache/redis.js deleted file mode 100644 index adf291d02b..0000000000 --- a/api/cache/redis.js +++ /dev/null @@ -1,4 +0,0 @@ -const Redis = require('ioredis'); -const { REDIS_URI } = process.env ?? {}; -const redis = new Redis.Cluster(REDIS_URI); -module.exports = redis; diff --git a/api/config/index.js b/api/config/index.js index aaf8bb2764..e238f700be 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,31 +1,35 @@ +const axios = require('axios'); const { EventSource } = require('eventsource'); const { Time, CacheKeys } = require('librechat-data-provider'); +const { MCPManager, FlowStateManager } = require('librechat-mcp'); const logger = require('./winston'); global.EventSource = EventSource; +/** @type {MCPManager} */ let mcpManager = null; let flowManager = null; /** - * @returns {Promise} + * @param {string} [userId] - Optional user ID, to avoid disconnecting the current user. + * @returns {MCPManager} */ -async function getMCPManager() { +function getMCPManager(userId) { if (!mcpManager) { - const { MCPManager } = await import('librechat-mcp'); mcpManager = MCPManager.getInstance(logger); + } else { + mcpManager.checkIdleConnections(userId); } return mcpManager; } /** - * @param {(key: string) => Keyv} getLogStores - * @returns {Promise} + * @param {Keyv} flowsCache + * @returns {FlowStateManager} */ -async function getFlowStateManager(getLogStores) { +function getFlowStateManager(flowsCache) { if (!flowManager) { - const { FlowStateManager } = await import('librechat-mcp'); - flowManager = new FlowStateManager(getLogStores(CacheKeys.FLOWS), { + flowManager = new FlowStateManager(flowsCache, { ttl: Time.ONE_MINUTE * 3, logger, }); @@ -47,9 +51,46 @@ const sendEvent = (res, event) => { res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); }; +/** + * Creates and configures an Axios instance with optional proxy settings. 
+ * + * @typedef {import('axios').AxiosInstance} AxiosInstance + * @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig + * + * @returns {AxiosInstance} A configured Axios instance + * @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL + */ +function createAxiosInstance() { + const instance = axios.create(); + + if (process.env.proxy) { + try { + const url = new URL(process.env.proxy); + + /** @type {AxiosProxyConfig} */ + const proxyConfig = { + host: url.hostname.replace(/^\[|\]$/g, ''), + protocol: url.protocol.replace(':', ''), + }; + + if (url.port) { + proxyConfig.port = parseInt(url.port, 10); + } + + instance.defaults.proxy = proxyConfig; + } catch (error) { + console.error('Error parsing proxy URL:', error); + throw new Error(`Invalid proxy URL: ${process.env.proxy}`); + } + } + + return instance; +} + module.exports = { logger, sendEvent, getMCPManager, + createAxiosInstance, getFlowStateManager, }; diff --git a/api/config/index.spec.js b/api/config/index.spec.js new file mode 100644 index 0000000000..36ed8302f3 --- /dev/null +++ b/api/config/index.spec.js @@ -0,0 +1,126 @@ +const axios = require('axios'); +const { createAxiosInstance } = require('./index'); + +// Mock axios +jest.mock('axios', () => ({ + interceptors: { + request: { use: jest.fn(), eject: jest.fn() }, + response: { use: jest.fn(), eject: jest.fn() }, + }, + create: jest.fn().mockReturnValue({ + defaults: { + proxy: null, + }, + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + }), + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + reset: jest.fn().mockImplementation(function () { + this.get.mockClear(); + this.post.mockClear(); + this.put.mockClear(); + this.delete.mockClear(); + this.create.mockClear(); + }), +})); + +describe('createAxiosInstance', () => { + const originalEnv = process.env; + + beforeEach(() => { + // Reset mocks + jest.clearAllMocks(); + // Create a clean copy of process.env + process.env = { ...originalEnv }; + // Default: no proxy + delete process.env.proxy; + }); + + afterAll(() => { + // Restore original process.env + process.env = originalEnv; + }); + + test('creates an axios instance without proxy when no proxy env is set', () => { + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toBeNull(); + }); + + test('configures proxy correctly with hostname and protocol', () => { + process.env.proxy = 'http://example.com'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'example.com', + protocol: 'http', + }); + }); + + test('configures proxy correctly with hostname, protocol and port', () => { + process.env.proxy = 'https://proxy.example.com:8080'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'https', + port: 8080, + }); + }); + + test('handles proxy URLs with authentication', () => { + process.env.proxy = 'http://user:pass@proxy.example.com:3128'; + + const instance = createAxiosInstance(); + + expect(axios.create).toHaveBeenCalledTimes(1); 
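Usage is straightforward: the returned instance behaves like the default axios export, but with the proxy baked into its defaults. A sketch, assuming the `~/config` path alias used elsewhere in this codebase (endpoint and proxy URL are example values):

const { createAxiosInstance } = require('~/config');

process.env.proxy = 'http://proxy.example.com:8080'; // example proxy
const http = createAxiosInstance();

// All requests made through this instance are routed via the proxy.
http
  .get('https://api.example.com/health')
  .then((res) => console.log(res.status))
  .catch((err) => console.error(err.message));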
+ expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'http', + port: 3128, + // Note: The current implementation doesn't handle auth - if needed, add this functionality + }); + }); + + test('throws error when proxy URL is invalid', () => { + process.env.proxy = 'invalid-url'; + + expect(() => createAxiosInstance()).toThrow('Invalid proxy URL'); + expect(axios.create).toHaveBeenCalledTimes(1); + }); + + // If you want to test the actual URL parsing more thoroughly + test('handles edge case proxy URLs correctly', () => { + // IPv6 address + process.env.proxy = 'http://[::1]:8080'; + + let instance = createAxiosInstance(); + + expect(instance.defaults.proxy).toEqual({ + host: '::1', + protocol: 'http', + port: 8080, + }); + + // URL with path (which should be ignored for proxy config) + process.env.proxy = 'http://proxy.example.com:8080/some/path'; + + instance = createAxiosInstance(); + + expect(instance.defaults.proxy).toEqual({ + host: 'proxy.example.com', + protocol: 'http', + port: 8080, + }); + }); +}); diff --git a/api/config/meiliLogger.js b/api/config/meiliLogger.js index 195b387ae5..c5e60ea157 100644 --- a/api/config/meiliLogger.js +++ b/api/config/meiliLogger.js @@ -4,7 +4,11 @@ require('winston-daily-rotate-file'); const logDir = path.join(__dirname, '..', 'logs'); -const { NODE_ENV } = process.env; +const { NODE_ENV, DEBUG_LOGGING = false } = process.env; + +const useDebugLogging = + (typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') || + DEBUG_LOGGING === true; const levels = { error: 0, @@ -36,9 +40,10 @@ const fileFormat = winston.format.combine( winston.format.splat(), ); +const logLevel = useDebugLogging ? 'debug' : 'error'; const transports = [ new winston.transports.DailyRotateFile({ - level: 'debug', + level: logLevel, filename: `${logDir}/meiliSync-%DATE%.log`, datePattern: 'YYYY-MM-DD', zippedArchive: true, @@ -48,14 +53,6 @@ const transports = [ }), ]; -// if (NODE_ENV !== 'production') { -// transports.push( -// new winston.transports.Console({ -// format: winston.format.combine(winston.format.colorize(), winston.format.simple()), -// }), -// ); -// } - const consoleFormat = winston.format.combine( winston.format.colorize({ all: true }), winston.format.timestamp({ format: 'YYYY-MM-DD HH:mm:ss' }), diff --git a/api/config/winston.js b/api/config/winston.js index 8f51b9963c..12f6053723 100644 --- a/api/config/winston.js +++ b/api/config/winston.js @@ -5,7 +5,7 @@ const { redactFormat, redactMessage, debugTraverse, jsonTruncateFormat } = requi const logDir = path.join(__dirname, '..', 'logs'); -const { NODE_ENV, DEBUG_LOGGING = true, DEBUG_CONSOLE = false, CONSOLE_JSON = false } = process.env; +const { NODE_ENV, DEBUG_LOGGING = true, CONSOLE_JSON = false, DEBUG_CONSOLE = false } = process.env; const useConsoleJson = (typeof CONSOLE_JSON === 'string' && CONSOLE_JSON?.toLowerCase() === 'true') || @@ -15,6 +15,10 @@ const useDebugConsole = (typeof DEBUG_CONSOLE === 'string' && DEBUG_CONSOLE?.toLowerCase() === 'true') || DEBUG_CONSOLE === true; +const useDebugLogging = + (typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') || + DEBUG_LOGGING === true; + const levels = { error: 0, warn: 1, @@ -57,28 +61,9 @@ const transports = [ maxFiles: '14d', format: fileFormat, }), - // new winston.transports.DailyRotateFile({ - // level: 'info', - // filename: `${logDir}/info-%DATE%.log`, - // datePattern: 'YYYY-MM-DD', - // zippedArchive: true, - // maxSize: '20m', - // maxFiles: '14d', - // }), ]; -// 
if (NODE_ENV !== 'production') { -// transports.push( -// new winston.transports.Console({ -// format: winston.format.combine(winston.format.colorize(), winston.format.simple()), -// }), -// ); -// } - -if ( - (typeof DEBUG_LOGGING === 'string' && DEBUG_LOGGING?.toLowerCase() === 'true') || - DEBUG_LOGGING === true -) { +if (useDebugLogging) { transports.push( new winston.transports.DailyRotateFile({ level: 'debug', @@ -107,10 +92,16 @@ const consoleFormat = winston.format.combine( }), ); +// Determine console log level +let consoleLogLevel = 'info'; +if (useDebugConsole) { + consoleLogLevel = 'debug'; +} + if (useDebugConsole) { transports.push( new winston.transports.Console({ - level: 'debug', + level: consoleLogLevel, format: useConsoleJson ? winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json()) : winston.format.combine(fileFormat, debugTraverse), @@ -119,14 +110,14 @@ if (useDebugConsole) { } else if (useConsoleJson) { transports.push( new winston.transports.Console({ - level: 'info', + level: consoleLogLevel, format: winston.format.combine(fileFormat, jsonTruncateFormat(), winston.format.json()), }), ); } else { transports.push( new winston.transports.Console({ - level: 'info', + level: consoleLogLevel, format: consoleFormat, }), ); diff --git a/api/jest.config.js b/api/jest.config.js index ec44bd7f56..2df7790b7b 100644 --- a/api/jest.config.js +++ b/api/jest.config.js @@ -5,7 +5,6 @@ module.exports = { coverageDirectory: 'coverage', setupFiles: [ './test/jestSetup.js', - './test/__mocks__/KeyvMongo.js', './test/__mocks__/logger.js', './test/__mocks__/fetchEventSource.js', ], diff --git a/api/lib/db/indexSync.js b/api/lib/db/indexSync.js index 86c909419d..75acd9d231 100644 --- a/api/lib/db/indexSync.js +++ b/api/lib/db/indexSync.js @@ -1,9 +1,11 @@ const { MeiliSearch } = require('meilisearch'); -const Conversation = require('~/models/schema/convoSchema'); -const Message = require('~/models/schema/messageSchema'); +const { Conversation } = require('~/models/Conversation'); +const { Message } = require('~/models/Message'); +const { isEnabled } = require('~/server/utils'); const { logger } = require('~/config'); -const searchEnabled = process.env?.SEARCH?.toLowerCase() === 'true'; +const searchEnabled = isEnabled(process.env.SEARCH); +const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); let currentTimeout = null; class MeiliSearchClient { @@ -23,8 +25,7 @@ class MeiliSearchClient { } } -// eslint-disable-next-line no-unused-vars -async function indexSync(req, res, next) { +async function indexSync() { if (!searchEnabled) { return; } @@ -33,10 +34,15 @@ async function indexSync(req, res, next) { const client = MeiliSearchClient.getInstance(); const { status } = await client.health(); - if (status !== 'available' || !process.env.SEARCH) { + if (status !== 'available') { throw new Error('Meilisearch not available'); } + if (indexingDisabled === true) { + logger.info('[indexSync] Indexing is disabled, skipping...'); + return; + } + const messageCount = await Message.countDocuments(); const convoCount = await Conversation.countDocuments(); const messages = await client.index('messages').getStats(); @@ -71,7 +77,6 @@ async function indexSync(req, res, next) { logger.info('[indexSync] Meilisearch not configured, search will be disabled.'); } else { logger.error('[indexSync] error', err); - // res.status(500).json({ error: 'Server error' }); } } } diff --git a/api/lib/utils/reduceHits.js b/api/lib/utils/reduceHits.js deleted file mode 100644 index 
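Both logger configs now funnel string/boolean env flags through the same truthiness check, and indexSync switches to the shared isEnabled helper for SEARCH and MEILI_NO_SYNC. That helper is not shown in this diff; a plausible minimal form, assuming it only needs to accept the string 'true' (any casing) or boolean true:

// Sketch of an isEnabled-style flag check; the real ~/server/utils
// implementation may accept additional spellings.
function isEnabled(value) {
  if (typeof value === 'boolean') {
    return value;
  }
  return typeof value === 'string' && value.trim().toLowerCase() === 'true';
}

console.log(isEnabled('TRUE')); // true
console.log(isEnabled('1')); // false under this sketch
console.log(isEnabled(undefined)); // false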
77b2f9d57d..0000000000 --- a/api/lib/utils/reduceHits.js +++ /dev/null @@ -1,59 +0,0 @@ -const mergeSort = require('./mergeSort'); -const { cleanUpPrimaryKeyValue } = require('./misc'); - -function reduceMessages(hits) { - const counts = {}; - - for (const hit of hits) { - if (!counts[hit.conversationId]) { - counts[hit.conversationId] = 1; - } else { - counts[hit.conversationId]++; - } - } - - const result = []; - - for (const [conversationId, count] of Object.entries(counts)) { - result.push({ - conversationId, - count, - }); - } - - return mergeSort(result, (a, b) => b.count - a.count); -} - -function reduceHits(hits, titles = []) { - const counts = {}; - const titleMap = {}; - const convos = [...hits, ...titles]; - - for (const convo of convos) { - const currentId = cleanUpPrimaryKeyValue(convo.conversationId); - if (!counts[currentId]) { - counts[currentId] = 1; - } else { - counts[currentId]++; - } - - if (convo.title) { - // titleMap[currentId] = convo._formatted.title; - titleMap[currentId] = convo.title; - } - } - - const result = []; - - for (const [conversationId, count] of Object.entries(counts)) { - result.push({ - conversationId, - count, - title: titleMap[conversationId] ? titleMap[conversationId] : null, - }); - } - - return mergeSort(result, (a, b) => b.count - a.count); -} - -module.exports = { reduceMessages, reduceHits }; diff --git a/api/models/Action.js b/api/models/Action.js index 299b3bf20a..677b4d78df 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -1,5 +1,5 @@ const mongoose = require('mongoose'); -const actionSchema = require('./schema/action'); +const { actionSchema } = require('@librechat/data-schemas'); const Action = mongoose.model('action', actionSchema); diff --git a/api/models/Agent.js b/api/models/Agent.js index 6fa00f56bc..9b34eeae65 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -1,6 +1,8 @@ const mongoose = require('mongoose'); -const { SystemRoles } = require('librechat-data-provider'); -const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; +const { agentSchema } = require('@librechat/data-schemas'); +const { SystemRoles, Tools } = require('librechat-data-provider'); +const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } = + require('librechat-data-provider').Constants; const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys; const { getProjectByName, @@ -9,7 +11,6 @@ const { removeAgentFromAllProjects, } = require('./Project'); const getLogStores = require('~/cache/getLogStores'); -const agentSchema = require('./schema/agent'); const Agent = mongoose.model('agent', agentSchema); @@ -39,13 +40,69 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter) * @param {Object} params * @param {ServerRequest} params.req * @param {string} params.agent_id + * @param {string} params.endpoint + * @param {import('@librechat/agents').ClientOptions} [params.model_parameters] + * @returns {Agent|null} The agent document as a plain object, or null if not found. 
+ */ +const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) => { + const { model, ...model_parameters } = _m; + /** @type {Record} */ + const availableTools = req.app.locals.availableTools; + const mcpServers = new Set(req.body.ephemeralAgent?.mcp); + /** @type {string[]} */ + const tools = []; + if (req.body.ephemeralAgent?.execute_code === true) { + tools.push(Tools.execute_code); + } + + if (mcpServers.size > 0) { + for (const toolName of Object.keys(availableTools)) { + if (!toolName.includes(mcp_delimiter)) { + continue; + } + const mcpServer = toolName.split(mcp_delimiter)?.[1]; + if (mcpServer && mcpServers.has(mcpServer)) { + tools.push(toolName); + } + } + } + + const instructions = req.body.promptPrefix; + return { + id: agent_id, + instructions, + provider: endpoint, + model_parameters, + model, + tools, + }; +}; + +/** + * Load an agent based on the provided ID + * + * @param {Object} params + * @param {ServerRequest} params.req + * @param {string} params.agent_id + * @param {string} params.endpoint + * @param {import('@librechat/agents').ClientOptions} [params.model_parameters] * @returns {Promise} The agent document as a plain object, or null if not found. */ -const loadAgent = async ({ req, agent_id }) => { +const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => { + if (!agent_id) { + return null; + } + if (agent_id === EPHEMERAL_AGENT_ID) { + return loadEphemeralAgent({ req, agent_id, endpoint, model_parameters }); + } const agent = await getAgent({ id: agent_id, }); + if (!agent) { + return null; + } + if (agent.author.toString() === req.user.id) { return agent; } @@ -96,12 +153,30 @@ const updateAgent = async (searchParameter, updateData) => { */ const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => { const searchParameter = { id: agent_id }; - - // build the update to push or create the file ids set + let agent = await getAgent(searchParameter); + if (!agent) { + throw new Error('Agent not found for adding resource file'); + } const fileIdsPath = `tool_resources.${tool_resource}.file_ids`; - const updateData = { $addToSet: { [fileIdsPath]: file_id } }; + await Agent.updateOne( + { + id: agent_id, + [`${fileIdsPath}`]: { $exists: false }, + }, + { + $set: { + [`${fileIdsPath}`]: [], + }, + }, + ); + + const updateData = { + $addToSet: { + tools: tool_resource, + [fileIdsPath]: file_id, + }, + }; - // return the updated agent or throw if no agent matches const updatedAgent = await updateAgent(searchParameter, updateData); if (updatedAgent) { return updatedAgent; @@ -111,16 +186,17 @@ const addAgentResourceFile = async ({ agent_id, tool_resource, file_id }) => { }; /** - * Removes multiple resource files from an agent in a single update. + * Removes multiple resource files from an agent using atomic operations. * @param {object} params * @param {string} params.agent_id * @param {Array<{tool_resource: string, file_id: string}>} params.files * @returns {Promise} The updated agent. + * @throws {Error} If the agent is not found or update fails. 
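The reworked addAgentResourceFile splits the write into two atomic steps: an updateOne that creates the file_ids array only when it is missing, then an $addToSet that both registers the tool and appends the file id without duplicates. A generic mongoose sketch of the same pattern (model and field names illustrative):

const mongoose = require('mongoose');

const DemoAgent = mongoose.model(
  'DemoAgent',
  new mongoose.Schema({
    id: String,
    tools: [String],
    tool_resources: { type: mongoose.Schema.Types.Mixed, default: {} },
  }),
);

async function addResourceFile(id, resource, fileId) {
  const path = `tool_resources.${resource}.file_ids`;
  // Step 1: initialize the array only if absent (a no-op when it already exists).
  await DemoAgent.updateOne({ id, [path]: { $exists: false } }, { $set: { [path]: [] } });
  // Step 2: $addToSet is atomic and deduplicating, so concurrent adds cannot clobber each other.
  return DemoAgent.findOneAndUpdate(
    { id },
    { $addToSet: { tools: resource, [path]: fileId } },
    { new: true },
  ).lean();
}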
*/ const removeAgentResourceFiles = async ({ agent_id, files }) => { const searchParameter = { id: agent_id }; - // associate each tool resource with the respective file ids array + // Group files to remove by resource const filesByResource = files.reduce((acc, { tool_resource, file_id }) => { if (!acc[tool_resource]) { acc[tool_resource] = []; @@ -129,42 +205,35 @@ const removeAgentResourceFiles = async ({ agent_id, files }) => { return acc; }, {}); - // build the update aggregation pipeline wich removes file ids from tool resources array - // and eventually deletes empty tool resources - const updateData = []; - Object.entries(filesByResource).forEach(([resource, fileIds]) => { - const toolResourcePath = `tool_resources.${resource}`; - const fileIdsPath = `${toolResourcePath}.file_ids`; - - // file ids removal stage - updateData.push({ - $set: { - [fileIdsPath]: { - $filter: { - input: `$${fileIdsPath}`, - cond: { $not: [{ $in: ['$$this', fileIds] }] }, - }, - }, - }, - }); - - // empty tool resource deletion stage - updateData.push({ - $set: { - [toolResourcePath]: { - $cond: [{ $eq: [`$${fileIdsPath}`, []] }, '$$REMOVE', `$${toolResourcePath}`], - }, - }, - }); - }); - - // return the updated agent or throw if no agent matches - const updatedAgent = await updateAgent(searchParameter, updateData); - if (updatedAgent) { - return updatedAgent; - } else { - throw new Error('Agent not found for removing resource files'); + // Step 1: Atomically remove file IDs using $pull + const pullOps = {}; + const resourcesToCheck = new Set(); + for (const [resource, fileIds] of Object.entries(filesByResource)) { + const fileIdsPath = `tool_resources.${resource}.file_ids`; + pullOps[fileIdsPath] = { $in: fileIds }; + resourcesToCheck.add(resource); } + + const updatePullData = { $pull: pullOps }; + const agentAfterPull = await Agent.findOneAndUpdate(searchParameter, updatePullData, { + new: true, + }).lean(); + + if (!agentAfterPull) { + // Agent might have been deleted concurrently, or never existed. + // Check if it existed before trying to throw. + const agentExists = await getAgent(searchParameter); + if (!agentExists) { + throw new Error('Agent not found for removing resource files'); + } + // If it existed but findOneAndUpdate returned null, something else went wrong. + throw new Error('Failed to update agent during file removal (pull step)'); + } + + // Return the agent state directly after the $pull operation. + // Skipping the $unset step for now to simplify and test core $pull atomicity. + // Empty arrays might remain, but the removal itself should be correct. + return agentAfterPull; }; /** @@ -239,7 +308,7 @@ const getListAgents = async (searchParameter) => { * This function also updates the corresponding projects to include or exclude the agent ID. * * @param {Object} params - Parameters for updating the agent's projects. - * @param {import('librechat-data-provider').TUser} params.user - Parameters for updating the agent's projects. + * @param {MongoUser} params.user - Parameters for updating the agent's projects. * @param {string} params.agentId - The ID of the agent to update. * @param {string[]} [params.projectIds] - Array of project IDs to add to the agent. * @param {string[]} [params.removeProjectIds] - Array of project IDs to remove from the agent. 
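removeAgentResourceFiles takes the complementary route: instead of the previous aggregation-pipeline update that filtered arrays and pruned empty resources, it issues a single $pull, trading empty-object cleanup for atomicity under concurrent writers. Reduced to its core (names illustrative):

// One $pull document can target several array paths at once; MongoDB applies
// it as a single atomic update, so concurrent pulls and adds interleave safely.
async function removeResourceFiles(Model, id, filesByResource) {
  const pullOps = {};
  for (const [resource, fileIds] of Object.entries(filesByResource)) {
    pullOps[`tool_resources.${resource}.file_ids`] = { $in: fileIds };
  }
  return Model.findOneAndUpdate({ id }, { $pull: pullOps }, { new: true }).lean();
}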
@@ -290,6 +359,7 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds }; module.exports = { + Agent, getAgent, loadAgent, createAgent, diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js new file mode 100644 index 0000000000..051cb6800f --- /dev/null +++ b/api/models/Agent.spec.js @@ -0,0 +1,334 @@ +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { Agent, addAgentResourceFile, removeAgentResourceFiles } = require('./Agent'); + +describe('Agent Resource File Operations', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + const createBasicAgent = async () => { + const agentId = `agent_${uuidv4()}`; + const agent = await Agent.create({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + return agent; + }; + + test('should add tool_resource to tools if missing', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + const toolResource = 'file_search'; + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId, + }); + + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + // Should not duplicate + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + expect(count).toBe(1); + }); + + test('should not duplicate tool_resource in tools if already present', async () => { + const agent = await createBasicAgent(); + const fileId1 = uuidv4(); + const fileId2 = uuidv4(); + const toolResource = 'file_search'; + + // First add + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId1, + }); + + // Second add (should not duplicate) + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId2, + }); + + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + expect(count).toBe(1); + }); + + test('should handle concurrent file additions', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); + + // Concurrent additions + const additionPromises = fileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ); + + await Promise.all(additionPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); + expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); + }); + + test('should handle concurrent additions and removals', async () => { + const agent = await createBasicAgent(); + const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); + + await Promise.all( + initialFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + 
file_id: fileId, + }), + ), + ); + + const newFileIds = Array.from({ length: 5 }, () => uuidv4()); + const operations = [ + ...newFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ...initialFileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ), + ]; + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); + }); + + test('should initialize array when adding to non-existent tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'new_tool', + file_id: fileId, + }); + + expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); + expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); + }); + + test('should handle rapid sequential modifications to same tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + for (let i = 0; i < 10; i++) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: `${fileId}_${i}`, + }); + + if (i % 2 === 0) { + await removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], + }); + } + } + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); + }); + + test('should handle multiple tool resources concurrently', async () => { + const agent = await createBasicAgent(); + const toolResources = ['tool1', 'tool2', 'tool3']; + const operations = []; + + toolResources.forEach((tool) => { + const fileIds = Array.from({ length: 5 }, () => uuidv4()); + fileIds.forEach((fileId) => { + operations.push( + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: tool, + file_id: fileId, + }), + ); + }); + }); + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + toolResources.forEach((tool) => { + expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); + expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); + }); + }); + + test('should handle concurrent duplicate additions', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + // Concurrent additions of the same file + const additionPromises = Array.from({ length: 5 }).map(() => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ); + + await Promise.all(additionPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + // Should only contain one instance of the fileId + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(1); + expect(updatedAgent.tool_resources.test_tool.file_ids[0]).toBe(fileId); + }); + + test('should handle concurrent add and remove of the same file', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + // First, ensure the 
file exists (or test might be trivial if remove runs first) + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }); + + // Concurrent add (which should be ignored) and remove + const operations = [ + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ]; + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + // Either operation may win the race; the invariant under test is consistency: + // no error is thrown and the fileId ends up present at most once. + // (The add path uses an atomic $addToSet, the remove path an atomic $pull.) + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids; + const count = finalFileIds.filter((id) => id === fileId).length; + expect(count).toBeLessThanOrEqual(1); // Should be 0 or 1, never more + // Optional: Check overall length is consistent with the count + if (count === 0) { + expect(finalFileIds).toHaveLength(0); + } else { + expect(finalFileIds).toHaveLength(1); + expect(finalFileIds[0]).toBe(fileId); + } + }); + + test('should handle concurrent duplicate removals', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + // Add the file first + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }); + + // Concurrent removals of the same file + const removalPromises = Array.from({ length: 5 }).map(() => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(removalPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + // Check if the array is empty or the tool resource itself is removed + const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; + expect(fileIds).toHaveLength(0); + expect(fileIds).not.toContain(fileId); + }); + + test('should handle concurrent removals of different files', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); + + // Add all files first + await Promise.all( + fileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ); + + // Concurrently remove all files + const removalPromises = fileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(removalPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + // Check if the array is empty or the tool resource itself is removed + const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? 
[]; + expect(finalFileIds).toHaveLength(0); + }); +}); diff --git a/api/models/Assistant.js b/api/models/Assistant.js index d0e73ad4e7..a8a5b98157 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -1,5 +1,5 @@ const mongoose = require('mongoose'); -const assistantSchema = require('./schema/assistant'); +const { assistantSchema } = require('@librechat/data-schemas'); const Assistant = mongoose.model('assistant', assistantSchema); diff --git a/api/models/Balance.js b/api/models/Balance.js index 24d9087b77..226f6ef508 100644 --- a/api/models/Balance.js +++ b/api/models/Balance.js @@ -1,44 +1,4 @@ const mongoose = require('mongoose'); -const balanceSchema = require('./schema/balance'); -const { getMultiplier } = require('./tx'); -const { logger } = require('~/config'); - -balanceSchema.statics.check = async function ({ - user, - model, - endpoint, - valueKey, - tokenType, - amount, - endpointTokenConfig, -}) { - const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig }); - const tokenCost = amount * multiplier; - const { tokenCredits: balance } = (await this.findOne({ user }, 'tokenCredits').lean()) ?? {}; - - logger.debug('[Balance.check]', { - user, - model, - endpoint, - valueKey, - tokenType, - amount, - balance, - multiplier, - endpointTokenConfig: !!endpointTokenConfig, - }); - - if (!balance) { - return { - canSpend: false, - balance: 0, - tokenCost, - }; - } - - logger.debug('[Balance.check]', { tokenCost }); - - return { canSpend: balance >= tokenCost, balance, tokenCost }; -}; +const { balanceSchema } = require('@librechat/data-schemas'); module.exports = mongoose.model('Balance', balanceSchema); diff --git a/api/models/Banner.js b/api/models/Banner.js index 8d439dae28..399a8e72ee 100644 --- a/api/models/Banner.js +++ b/api/models/Banner.js @@ -1,5 +1,9 @@ -const Banner = require('./schema/banner'); +const mongoose = require('mongoose'); const logger = require('~/config/winston'); +const { bannerSchema } = require('@librechat/data-schemas'); + +const Banner = mongoose.model('Banner', bannerSchema); + /** * Retrieves the current active banner. * @returns {Promise} The active banner object or null if no active banner is found. 
@@ -24,4 +28,4 @@ const getBanner = async (user) => { } }; -module.exports = { getBanner }; +module.exports = { Banner, getBanner }; diff --git a/api/models/Categories.js b/api/models/Categories.js index 0f7f29703f..5da1f4b2da 100644 --- a/api/models/Categories.js +++ b/api/models/Categories.js @@ -1,40 +1,40 @@ const { logger } = require('~/config'); -// const { Categories } = require('./schema/categories'); + const options = [ { - label: 'idea', + label: 'com_ui_idea', value: 'idea', }, { - label: 'travel', + label: 'com_ui_travel', value: 'travel', }, { - label: 'teach_or_explain', + label: 'com_ui_teach_or_explain', value: 'teach_or_explain', }, { - label: 'write', + label: 'com_ui_write', value: 'write', }, { - label: 'shop', + label: 'com_ui_shop', value: 'shop', }, { - label: 'code', + label: 'com_ui_code', value: 'code', }, { - label: 'misc', + label: 'com_ui_misc', value: 'misc', }, { - label: 'roleplay', + label: 'com_ui_roleplay', value: 'roleplay', }, { - label: 'finance', + label: 'com_ui_finance', value: 'finance', }, ]; diff --git a/api/models/Conversation.js b/api/models/Conversation.js index d6365e99ce..51081a6491 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -15,19 +15,6 @@ const searchConversation = async (conversationId) => { throw new Error('Error searching conversation'); } }; -/** - * Searches for a conversation by conversationId and returns associated file ids. - * @param {string} conversationId - The conversation's ID. - * @returns {Promise} - */ -const getConvoFiles = async (conversationId) => { - try { - return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? []; - } catch (error) { - logger.error('[getConvoFiles] Error getting conversation files', error); - throw new Error('Error getting conversation files'); - } -}; /** * Retrieves a single conversation for a given user and conversation ID. @@ -73,6 +60,20 @@ const deleteNullOrEmptyConversations = async () => { } }; +/** + * Searches for a conversation by conversationId and returns associated file ids. + * @param {string} conversationId - The conversation's ID. + * @returns {Promise} + */ +const getConvoFiles = async (conversationId) => { + try { + return (await Conversation.findOne({ conversationId }, 'files').lean())?.files ?? 
[]; + } catch (error) { + logger.error('[getConvoFiles] Error getting conversation files', error); + throw new Error('Error getting conversation files'); + } +}; + module.exports = { Conversation, getConvoFiles, @@ -87,11 +88,13 @@ module.exports = { */ saveConvo: async (req, { conversationId, newConversationId, ...convo }, metadata) => { try { - if (metadata && metadata?.context) { + if (metadata?.context) { logger.debug(`[saveConvo] ${metadata.context}`); } + const messages = await getMessages({ conversationId }, '_id'); const update = { ...convo, messages, user: req.user.id }; + if (newConversationId) { update.conversationId = newConversationId; } @@ -104,10 +107,16 @@ module.exports = { update.expiredAt = null; } + /** @type {{ $set: Partial; $unset?: Record }} */ + const updateOperation = { $set: update }; + if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) { + updateOperation.$unset = metadata.unsetFields; + } + /** Note: the resulting Model object is necessary for Meilisearch operations */ const conversation = await Conversation.findOneAndUpdate( { conversationId, user: req.user.id }, - update, + updateOperation, { new: true, upsert: true, @@ -141,75 +150,102 @@ module.exports = { throw new Error('Failed to save conversations in bulk.'); } }, - getConvosByPage: async (user, pageNumber = 1, pageSize = 25, isArchived = false, tags) => { - const query = { user }; + getConvosByCursor: async ( + user, + { cursor, limit = 25, isArchived = false, tags, search, order = 'desc' } = {}, + ) => { + const filters = [{ user }]; + if (isArchived) { - query.isArchived = true; + filters.push({ isArchived: true }); } else { - query.$or = [{ isArchived: false }, { isArchived: { $exists: false } }]; - } - if (Array.isArray(tags) && tags.length > 0) { - query.tags = { $in: tags }; + filters.push({ $or: [{ isArchived: false }, { isArchived: { $exists: false } }] }); } - query.$and = [{ $or: [{ expiredAt: null }, { expiredAt: { $exists: false } }] }]; + if (Array.isArray(tags) && tags.length > 0) { + filters.push({ tags: { $in: tags } }); + } + + filters.push({ $or: [{ expiredAt: null }, { expiredAt: { $exists: false } }] }); + + if (search) { + try { + const meiliResults = await Conversation.meiliSearch(search); + const matchingIds = Array.isArray(meiliResults.hits) + ? meiliResults.hits.map((result) => result.conversationId) + : []; + if (!matchingIds.length) { + return { conversations: [], nextCursor: null }; + } + filters.push({ conversationId: { $in: matchingIds } }); + } catch (error) { + logger.error('[getConvosByCursor] Error during meiliSearch', error); + return { message: 'Error during meiliSearch' }; + } + } + + if (cursor) { + filters.push({ updatedAt: { $lt: new Date(cursor) } }); + } + + const query = filters.length === 1 ? filters[0] : { $and: filters }; try { - const totalConvos = (await Conversation.countDocuments(query)) || 1; - const totalPages = Math.ceil(totalConvos / pageSize); const convos = await Conversation.find(query) - .sort({ updatedAt: -1 }) - .skip((pageNumber - 1) * pageSize) - .limit(pageSize) + .select( + 'conversationId endpoint title createdAt updatedAt user model agent_id assistant_id spec iconURL', + ) + .sort({ updatedAt: order === 'asc' ? 
1 : -1 }) + .limit(limit + 1) .lean(); - return { conversations: convos, pages: totalPages, pageNumber, pageSize }; + + let nextCursor = null; + if (convos.length > limit) { + const lastConvo = convos.pop(); + nextCursor = lastConvo.updatedAt.toISOString(); + } + + return { conversations: convos, nextCursor }; } catch (error) { - logger.error('[getConvosByPage] Error getting conversations', error); + logger.error('[getConvosByCursor] Error getting conversations', error); return { message: 'Error getting conversations' }; } }, - getConvosQueried: async (user, convoIds, pageNumber = 1, pageSize = 25) => { + getConvosQueried: async (user, convoIds, cursor = null, limit = 25) => { try { - if (!convoIds || convoIds.length === 0) { - return { conversations: [], pages: 1, pageNumber, pageSize }; + if (!convoIds?.length) { + return { conversations: [], nextCursor: null, convoMap: {} }; + } + + const conversationIds = convoIds.map((convo) => convo.conversationId); + + const results = await Conversation.find({ + user, + conversationId: { $in: conversationIds }, + $or: [{ expiredAt: { $exists: false } }, { expiredAt: null }], + }).lean(); + + results.sort((a, b) => new Date(b.updatedAt) - new Date(a.updatedAt)); + + let filtered = results; + if (cursor && cursor !== 'start') { + const cursorDate = new Date(cursor); + filtered = results.filter((convo) => new Date(convo.updatedAt) < cursorDate); + } + + const limited = filtered.slice(0, limit + 1); + let nextCursor = null; + if (limited.length > limit) { + const lastConvo = limited.pop(); + nextCursor = lastConvo.updatedAt.toISOString(); } - const cache = {}; const convoMap = {}; - const promises = []; - - convoIds.forEach((convo) => - promises.push( - Conversation.findOne({ - user, - conversationId: convo.conversationId, - $or: [{ expiredAt: { $exists: false } }, { expiredAt: null }], - }).lean(), - ), - ); - - const results = (await Promise.all(promises)).filter(Boolean); - - results.forEach((convo, i) => { - const page = Math.floor(i / pageSize) + 1; - if (!cache[page]) { - cache[page] = []; - } - cache[page].push(convo); + limited.forEach((convo) => { convoMap[convo.conversationId] = convo; }); - const totalPages = Math.ceil(results.length / pageSize); - cache.pages = totalPages; - cache.pageSize = pageSize; - return { - cache, - conversations: cache[pageNumber] || [], - pages: totalPages || 1, - pageNumber, - pageSize, - convoMap, - }; + return { conversations: limited, nextCursor, convoMap }; } catch (error) { logger.error('[getConvosQueried] Error getting conversations', error); return { message: 'Error fetching conversations' }; @@ -250,10 +286,26 @@ module.exports = { * logger.error(result); // { n: 5, ok: 1, deletedCount: 5, messages: { n: 10, ok: 1, deletedCount: 10 } } */ deleteConvos: async (user, filter) => { - let toRemove = await Conversation.find({ ...filter, user }).select('conversationId'); - const ids = toRemove.map((instance) => instance.conversationId); - let deleteCount = await Conversation.deleteMany({ ...filter, user }); - deleteCount.messages = await deleteMessages({ conversationId: { $in: ids } }); - return deleteCount; + try { + const userFilter = { ...filter, user }; + + const conversations = await Conversation.find(userFilter).select('conversationId'); + const conversationIds = conversations.map((c) => c.conversationId); + + if (!conversationIds.length) { + throw new Error('Conversation not found or already deleted.'); + } + + const deleteConvoResult = await Conversation.deleteMany(userFilter); + + const 
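getConvosByCursor replaces page/offset pagination with an updatedAt cursor: it over-fetches by one row, and if the extra row comes back, pops it and uses its timestamp as nextCursor. The same idea in isolation (model and fields illustrative):

async function paginateByCursor(Model, { cursor, limit = 25 } = {}) {
  const filter = cursor ? { updatedAt: { $lt: new Date(cursor) } } : {};
  // Fetch limit + 1: the extra document only signals that another page exists.
  const docs = await Model.find(filter)
    .sort({ updatedAt: -1 })
    .limit(limit + 1)
    .lean();

  let nextCursor = null;
  if (docs.length > limit) {
    nextCursor = docs.pop().updatedAt.toISOString();
  }
  return { items: docs, nextCursor };
}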
deleteMessagesResult = await deleteMessages({ + conversationId: { $in: conversationIds }, + }); + + return { ...deleteConvoResult, messages: deleteMessagesResult }; + } catch (error) { + logger.error('[deleteConvos] Error deleting conversations and messages', error); + throw error; + } }, }; diff --git a/api/models/ConversationTag.js b/api/models/ConversationTag.js index 53d144e1f5..f0cac8620e 100644 --- a/api/models/ConversationTag.js +++ b/api/models/ConversationTag.js @@ -1,7 +1,11 @@ -const ConversationTag = require('./schema/conversationTagSchema'); +const mongoose = require('mongoose'); const Conversation = require('./schema/convoSchema'); const logger = require('~/config/winston'); +const { conversationTagSchema } = require('@librechat/data-schemas'); + +const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema); + /** * Retrieves all conversation tags for a user. * @param {string} user - The user ID. diff --git a/api/models/File.js b/api/models/File.js index 17f8506600..4d94994478 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -1,5 +1,7 @@ const mongoose = require('mongoose'); -const fileSchema = require('./schema/fileSchema'); +const { EToolResources } = require('librechat-data-provider'); +const { fileSchema } = require('@librechat/data-schemas'); +const { logger } = require('~/config'); const File = mongoose.model('File', fileSchema); @@ -17,11 +19,50 @@ const findFileById = async (file_id, options = {}) => { * Retrieves files matching a given filter, sorted by the most recently updated. * @param {Object} filter - The filter criteria to apply. * @param {Object} [_sortOptions] - Optional sort parameters. + * @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results. + * Default excludes the 'text' field. + * @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents. 
*/ -const getFiles = async (filter, _sortOptions) => { +const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => { const sortOptions = { updatedAt: -1, ..._sortOptions }; - return await File.find(filter).sort(sortOptions).lean(); + return await File.find(filter).select(selectFields).sort(sortOptions).lean(); +}; + +/** + * Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs + * @param {string[]} fileIds - Array of file_id strings to search for + * @param {Set<EToolResources>} toolResourceSet - Optional filter for tool resources + * @returns {Promise<Array<MongoFile>>} Files that match the criteria + */ +const getToolFilesByIds = async (fileIds, toolResourceSet) => { + if (!fileIds || !fileIds.length) { + return []; + } + + try { + const filter = { + file_id: { $in: fileIds }, + }; + + if (toolResourceSet.size) { + filter.$or = []; + } + + if (toolResourceSet.has(EToolResources.file_search)) { + filter.$or.push({ embedded: true }); + } + if (toolResourceSet.has(EToolResources.execute_code)) { + filter.$or.push({ 'metadata.fileIdentifier': { $exists: true } }); + } + + const selectFields = { text: 0 }; + const sortOptions = { updatedAt: -1 }; + + return await getFiles(filter, sortOptions, selectFields); + } catch (error) { + logger.error('[getToolFilesByIds] Error retrieving tool files:', error); + throw new Error('Error retrieving tool files'); + } }; /** @@ -105,14 +146,38 @@ const deleteFiles = async (file_ids, user) => { return await File.deleteMany(deleteQuery); }; +/** + * Batch updates files with new signed URLs in MongoDB + * + * @param {MongoFile[]} updates - Array of updates in the format { file_id, filepath } + * @returns {Promise<void>} + */ +async function batchUpdateFiles(updates) { + if (!updates || updates.length === 0) { + return; + } + + const bulkOperations = updates.map((update) => ({ + updateOne: { + filter: { file_id: update.file_id }, + update: { $set: { filepath: update.filepath } }, + }, + })); + + const result = await File.bulkWrite(bulkOperations); + logger.info(`Updated ${result.modifiedCount} files with new S3 URLs`); +} + module.exports = { File, findFileById, getFiles, + getToolFilesByIds, createFile, updateFile, updateFileUsage, deleteFile, deleteFiles, deleteFileByFilter, + batchUpdateFiles, }; diff --git a/api/models/Key.js b/api/models/Key.js index 58fb0ac3a9..c69c350a42 100644 --- a/api/models/Key.js +++ b/api/models/Key.js @@ -1,4 +1,4 @@ const mongoose = require('mongoose'); -const keySchema = require('./schema/key'); +const { keySchema } = require('@librechat/data-schemas'); module.exports = mongoose.model('Key', keySchema); diff --git a/api/models/Message.js b/api/models/Message.js index e651b20ad0..86fd2fd549 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -61,6 +61,14 @@ async function saveMessage(req, params, metadata) { update.expiredAt = null; } + if (update.tokenCount != null && isNaN(update.tokenCount)) { + logger.warn( + `Resetting invalid \`tokenCount\` for message \`${params.messageId}\`: ${update.tokenCount}`, + ); + logger.info(`---\`saveMessage\` context: ${metadata?.context}`); + update.tokenCount = 0; + } + const message = await Message.findOneAndUpdate( { messageId: params.messageId, user: req.user.id }, update, @@ -71,7 +79,44 @@ async function saveMessage(req, params, metadata) { } catch (err) { logger.error('Error saving message:', err); logger.info(`---\`saveMessage\` context: ${metadata?.context}`); - throw err; + + // Check if this is a duplicate key error (MongoDB error code 11000) + 
if (err.code === 11000 && err.message.includes('duplicate key error')) { + // Log the duplicate key error but don't crash the application + logger.warn(`Duplicate messageId detected: ${params.messageId}. Continuing execution.`); + + try { + // Try to find the existing message with this ID + const existingMessage = await Message.findOne({ + messageId: params.messageId, + user: req.user.id, + }); + + // If we found it, return it + if (existingMessage) { + return existingMessage.toObject(); + } + + // If we can't find it (unlikely but possible in race conditions) + return { + ...params, + messageId: params.messageId, + user: req.user.id, + }; + } catch (findError) { + // If the findOne also fails, log it but don't crash + logger.warn( + `Could not retrieve existing message with ID ${params.messageId}: ${findError.message}`, + ); + return { + ...params, + messageId: params.messageId, + user: req.user.id, + }; + } + } + + throw err; // Re-throw other errors } } diff --git a/api/models/Project.js b/api/models/Project.js index 17ef3093a5..43d7263723 100644 --- a/api/models/Project.js +++ b/api/models/Project.js @@ -1,6 +1,6 @@ const { model } = require('mongoose'); const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; -const projectSchema = require('~/models/schema/projectSchema'); +const { projectSchema } = require('@librechat/data-schemas'); const Project = model('Project', projectSchema); @@ -9,7 +9,7 @@ const Project = model('Project', projectSchema); * * @param {string} projectId - The ID of the project to find and return as a plain object. * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. - * @returns {Promise} A plain object representing the project document, or `null` if no project is found. + * @returns {Promise} A plain object representing the project document, or `null` if no project is found. */ const getProjectById = async function (projectId, fieldsToSelect = null) { const query = Project.findById(projectId); @@ -27,7 +27,7 @@ const getProjectById = async function (projectId, fieldsToSelect = null) { * * @param {string} projectName - The name of the project to find or create. * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. - * @returns {Promise} A plain object representing the project document. + * @returns {Promise} A plain object representing the project document. */ const getProjectByName = async function (projectName, fieldsToSelect = null) { const query = { name: projectName }; @@ -47,7 +47,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) { * * @param {string} projectId - The ID of the project to update. * @param {string[]} promptGroupIds - The array of prompt group IDs to add to the project. - * @returns {Promise} The updated project document. + * @returns {Promise} The updated project document. */ const addGroupIdsToProject = async function (projectId, promptGroupIds) { return await Project.findByIdAndUpdate( @@ -62,7 +62,7 @@ const addGroupIdsToProject = async function (projectId, promptGroupIds) { * * @param {string} projectId - The ID of the project to update. * @param {string[]} promptGroupIds - The array of prompt group IDs to remove from the project. - * @returns {Promise} The updated project document. + * @returns {Promise} The updated project document. 
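+ * @example
+ * // Illustrative (IDs assumed): detach a prompt group from a project
+ * // await removeGroupIdsFromProject(project._id, [groupId]);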
*/ const removeGroupIdsFromProject = async function (projectId, promptGroupIds) { return await Project.findByIdAndUpdate( @@ -87,7 +87,7 @@ const removeGroupFromAllProjects = async (promptGroupId) => { * * @param {string} projectId - The ID of the project to update. * @param {string[]} agentIds - The array of agent IDs to add to the project. - * @returns {Promise} The updated project document. + * @returns {Promise} The updated project document. */ const addAgentIdsToProject = async function (projectId, agentIds) { return await Project.findByIdAndUpdate( @@ -102,7 +102,7 @@ const addAgentIdsToProject = async function (projectId, agentIds) { * * @param {string} projectId - The ID of the project to update. * @param {string[]} agentIds - The array of agent IDs to remove from the project. - * @returns {Promise} The updated project document. + * @returns {Promise} The updated project document. */ const removeAgentIdsFromProject = async function (projectId, agentIds) { return await Project.findByIdAndUpdate( diff --git a/api/models/Prompt.js b/api/models/Prompt.js index 60456884a8..43dc3ec22b 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,3 +1,4 @@ +const mongoose = require('mongoose'); const { ObjectId } = require('mongodb'); const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider'); const { @@ -6,10 +7,13 @@ const { removeGroupIdsFromProject, removeGroupFromAllProjects, } = require('./Project'); -const { Prompt, PromptGroup } = require('./schema/promptSchema'); +const { promptGroupSchema, promptSchema } = require('@librechat/data-schemas'); const { escapeRegExp } = require('~/server/utils'); const { logger } = require('~/config'); +const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema); +const Prompt = mongoose.model('Prompt', promptSchema); + /** * Create a pipeline for the aggregation to get prompt groups * @param {Object} query diff --git a/api/models/Role.js b/api/models/Role.js index 9c160512b7..07bf5a2ccb 100644 --- a/api/models/Role.js +++ b/api/models/Role.js @@ -1,29 +1,30 @@ +const mongoose = require('mongoose'); const { CacheKeys, SystemRoles, roleDefaults, PermissionTypes, + permissionsSchema, removeNullishValues, - agentPermissionsSchema, - promptPermissionsSchema, - bookmarkPermissionsSchema, - multiConvoPermissionsSchema, } = require('librechat-data-provider'); const getLogStores = require('~/cache/getLogStores'); -const Role = require('~/models/schema/roleSchema'); +const { roleSchema } = require('@librechat/data-schemas'); const { logger } = require('~/config'); +const Role = mongoose.model('Role', roleSchema); + /** * Retrieve a role by name and convert the found role document to a plain object. - * If the role with the given name doesn't exist and the name is a system defined role, create it and return the lean version. + * If the role with the given name doesn't exist and the name is a system defined role, + * create it and return the lean version. * * @param {string} roleName - The name of the role to find or create. * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. * @returns {Promise} A plain object representing the role document. 
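+ * @example
+ * // Illustrative: the first call hits MongoDB, later calls are served from cache
+ * // const userRole = await getRoleByName(SystemRoles.USER);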
*/ const getRoleByName = async function (roleName, fieldsToSelect = null) { + const cache = getLogStores(CacheKeys.ROLES); try { - const cache = getLogStores(CacheKeys.ROLES); const cachedRole = await cache.get(roleName); if (cachedRole) { return cachedRole; @@ -35,8 +36,7 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) { let role = await query.lean().exec(); if (!role && SystemRoles[roleName]) { - role = roleDefaults[roleName]; - role = await new Role(role).save(); + role = await new Role(roleDefaults[roleName]).save(); await cache.set(roleName, role); return role.toObject(); } @@ -55,8 +55,8 @@ const getRoleByName = async function (roleName, fieldsToSelect = null) { * @returns {Promise} Updated role document. */ const updateRoleByName = async function (roleName, updates) { + const cache = getLogStores(CacheKeys.ROLES); try { - const cache = getLogStores(CacheKeys.ROLES); const role = await Role.findOneAndUpdate( { name: roleName }, { $set: updates }, @@ -72,27 +72,20 @@ const updateRoleByName = async function (roleName, updates) { } }; -const permissionSchemas = { - [PermissionTypes.AGENTS]: agentPermissionsSchema, - [PermissionTypes.PROMPTS]: promptPermissionsSchema, - [PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema, - [PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema, -}; - /** * Updates access permissions for a specific role and multiple permission types. - * @param {SystemRoles} roleName - The role to update. + * @param {string} roleName - The role to update. * @param {Object.>} permissionsUpdate - Permissions to update and their values. */ async function updateAccessPermissions(roleName, permissionsUpdate) { + // Filter and clean the permission updates based on our schema definition. const updates = {}; for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) { - if (permissionSchemas[permissionType]) { + if (permissionsSchema.shape && permissionsSchema.shape[permissionType]) { updates[permissionType] = removeNullishValues(permissions); } } - - if (Object.keys(updates).length === 0) { + if (!Object.keys(updates).length) { return; } @@ -102,26 +95,75 @@ async function updateAccessPermissions(roleName, permissionsUpdate) { return; } - const updatedPermissions = {}; + const currentPermissions = role.permissions || {}; + const updatedPermissions = { ...currentPermissions }; let hasChanges = false; + const unsetFields = {}; + const permissionTypes = Object.keys(permissionsSchema.shape || {}); + for (const permType of permissionTypes) { + if (role[permType] && typeof role[permType] === 'object') { + logger.info( + `Migrating '${roleName}' role from old schema: found '${permType}' at top level`, + ); + + updatedPermissions[permType] = { + ...updatedPermissions[permType], + ...role[permType], + }; + + unsetFields[permType] = 1; + hasChanges = true; + } + } + + // Process the current updates for (const [permissionType, permissions] of Object.entries(updates)) { - const currentPermissions = role[permissionType] || {}; - updatedPermissions[permissionType] = { ...currentPermissions }; + const currentTypePermissions = currentPermissions[permissionType] || {}; + updatedPermissions[permissionType] = { ...currentTypePermissions }; for (const [permission, value] of Object.entries(permissions)) { - if (currentPermissions[permission] !== value) { + if (currentTypePermissions[permission] !== value) { updatedPermissions[permissionType][permission] = value; hasChanges = true; logger.info( - `Updating '${roleName}' role ${permissionType} 
'${permission}' permission from ${currentPermissions[permission]} to: ${value}`, + `Updating '${roleName}' role permission '${permissionType}' '${permission}' from ${currentTypePermissions[permission]} to: ${value}`, ); } } } if (hasChanges) { - await updateRoleByName(roleName, updatedPermissions); + const updateObj = { permissions: updatedPermissions }; + + if (Object.keys(unsetFields).length > 0) { + logger.info( + `Unsetting old schema fields for '${roleName}' role: ${Object.keys(unsetFields).join(', ')}`, + ); + + try { + await Role.updateOne( + { name: roleName }, + { + $set: updateObj, + $unset: unsetFields, + }, + ); + + const cache = getLogStores(CacheKeys.ROLES); + const updatedRole = await Role.findOne({ name: roleName }).select('-__v').lean().exec(); + await cache.set(roleName, updatedRole); + + logger.info(`Updated role '${roleName}' and removed old schema fields`); + } catch (updateError) { + logger.error(`Error during role migration update: ${updateError.message}`); + throw updateError; + } + } else { + // Standard update if no migration needed + await updateRoleByName(roleName, updateObj); + } + logger.info(`Updated '${roleName}' role permissions`); } else { logger.info(`No changes needed for '${roleName}' role permissions`); @@ -139,33 +181,111 @@ async function updateAccessPermissions(roleName, permissionsUpdate) { * @returns {Promise} */ const initializeRoles = async function () { - const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER]; - - for (const roleName of defaultRoles) { + for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) { let role = await Role.findOne({ name: roleName }); + const defaultPerms = roleDefaults[roleName].permissions; if (!role) { - // Create new role if it doesn't exist + // Create new role if it doesn't exist. role = new Role(roleDefaults[roleName]); } else { - // Add missing permission types - let isUpdated = false; - for (const permType of Object.values(PermissionTypes)) { - if (!role[permType]) { - role[permType] = roleDefaults[roleName][permType]; - isUpdated = true; + // Ensure role.permissions is defined. + role.permissions = role.permissions || {}; + // For each permission type in defaults, add it if missing. + for (const permType of Object.keys(defaultPerms)) { + if (role.permissions[permType] == null) { + role.permissions[permType] = defaultPerms[permType]; } } - if (isUpdated) { - await role.save(); - } } await role.save(); } }; + +/** + * Migrates roles from old schema to new schema structure. + * This can be called directly to fix existing roles. + * + * @param {string} [roleName] - Optional specific role to migrate. If not provided, migrates all roles. + * @returns {Promise} Number of roles migrated. + */ +const migrateRoleSchema = async function (roleName) { + try { + // Get roles to migrate + let roles; + if (roleName) { + const role = await Role.findOne({ name: roleName }); + roles = role ? 
[role] : []; + } else { + roles = await Role.find({}); + } + + logger.info(`Migrating ${roles.length} roles to new schema structure`); + let migratedCount = 0; + + for (const role of roles) { + const permissionTypes = Object.keys(permissionsSchema.shape || {}); + const unsetFields = {}; + let hasOldSchema = false; + + // Check for old schema fields + for (const permType of permissionTypes) { + if (role[permType] && typeof role[permType] === 'object') { + hasOldSchema = true; + + // Ensure permissions object exists + role.permissions = role.permissions || {}; + + // Migrate permissions from old location to new + role.permissions[permType] = { + ...role.permissions[permType], + ...role[permType], + }; + + // Mark field for removal + unsetFields[permType] = 1; + } + } + + if (hasOldSchema) { + try { + logger.info(`Migrating role '${role.name}' from old schema structure`); + + // Simple update operation + await Role.updateOne( + { _id: role._id }, + { + $set: { permissions: role.permissions }, + $unset: unsetFields, + }, + ); + + // Refresh cache + const cache = getLogStores(CacheKeys.ROLES); + const updatedRole = await Role.findById(role._id).lean().exec(); + await cache.set(role.name, updatedRole); + + migratedCount++; + logger.info(`Migrated role '${role.name}'`); + } catch (error) { + logger.error(`Failed to migrate role '${role.name}': ${error.message}`); + } + } + } + + logger.info(`Migration complete: ${migratedCount} roles migrated`); + return migratedCount; + } catch (error) { + logger.error(`Role schema migration failed: ${error.message}`); + throw error; + } +}; + module.exports = { + Role, getRoleByName, initializeRoles, updateRoleByName, updateAccessPermissions, + migrateRoleSchema, }; diff --git a/api/models/Role.spec.js b/api/models/Role.spec.js index 92386f0fa9..a8b60801ca 100644 --- a/api/models/Role.spec.js +++ b/api/models/Role.spec.js @@ -2,22 +2,21 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); const { SystemRoles, - PermissionTypes, - roleDefaults, Permissions, + roleDefaults, + PermissionTypes, } = require('librechat-data-provider'); -const { updateAccessPermissions, initializeRoles } = require('~/models/Role'); +const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role'); const getLogStores = require('~/cache/getLogStores'); -const Role = require('~/models/schema/roleSchema'); // Mock the cache -jest.mock('~/cache/getLogStores', () => { - return jest.fn().mockReturnValue({ +jest.mock('~/cache/getLogStores', () => + jest.fn().mockReturnValue({ get: jest.fn(), set: jest.fn(), del: jest.fn(), - }); -}); + }), +); let mongoServer; @@ -41,10 +40,12 @@ describe('updateAccessPermissions', () => { it('should update permissions when changes are needed', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - CREATE: true, - USE: true, - SHARED_GLOBAL: false, + permissions: { + [PermissionTypes.PROMPTS]: { + CREATE: true, + USE: true, + SHARED_GLOBAL: false, + }, }, }).save(); @@ -56,8 +57,8 @@ describe('updateAccessPermissions', () => { }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({ CREATE: true, USE: true, SHARED_GLOBAL: true, @@ -67,10 +68,12 @@ describe('updateAccessPermissions', () => { it('should not update permissions when no 
changes are needed', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - CREATE: true, - USE: true, - SHARED_GLOBAL: false, + permissions: { + [PermissionTypes.PROMPTS]: { + CREATE: true, + USE: true, + SHARED_GLOBAL: false, + }, }, }).save(); @@ -82,8 +85,8 @@ describe('updateAccessPermissions', () => { }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({ CREATE: true, USE: true, SHARED_GLOBAL: false, @@ -92,11 +95,8 @@ describe('updateAccessPermissions', () => { it('should handle non-existent roles', async () => { await updateAccessPermissions('NON_EXISTENT_ROLE', { - [PermissionTypes.PROMPTS]: { - CREATE: true, - }, + [PermissionTypes.PROMPTS]: { CREATE: true }, }); - const role = await Role.findOne({ name: 'NON_EXISTENT_ROLE' }); expect(role).toBeNull(); }); @@ -104,21 +104,21 @@ describe('updateAccessPermissions', () => { it('should update only specified permissions', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - CREATE: true, - USE: true, - SHARED_GLOBAL: false, + permissions: { + [PermissionTypes.PROMPTS]: { + CREATE: true, + USE: true, + SHARED_GLOBAL: false, + }, }, }).save(); await updateAccessPermissions(SystemRoles.USER, { - [PermissionTypes.PROMPTS]: { - SHARED_GLOBAL: true, - }, + [PermissionTypes.PROMPTS]: { SHARED_GLOBAL: true }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({ CREATE: true, USE: true, SHARED_GLOBAL: true, @@ -128,21 +128,21 @@ describe('updateAccessPermissions', () => { it('should handle partial updates', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - CREATE: true, - USE: true, - SHARED_GLOBAL: false, + permissions: { + [PermissionTypes.PROMPTS]: { + CREATE: true, + USE: true, + SHARED_GLOBAL: false, + }, }, }).save(); await updateAccessPermissions(SystemRoles.USER, { - [PermissionTypes.PROMPTS]: { - USE: false, - }, + [PermissionTypes.PROMPTS]: { USE: false }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({ CREATE: true, USE: false, SHARED_GLOBAL: false, @@ -152,13 +152,9 @@ describe('updateAccessPermissions', () => { it('should update multiple permission types at once', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - CREATE: true, - USE: true, - SHARED_GLOBAL: false, - }, - [PermissionTypes.BOOKMARKS]: { - USE: true, + permissions: { + [PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false }, + [PermissionTypes.BOOKMARKS]: { USE: true }, }, }).save(); @@ -167,24 +163,20 @@ describe('updateAccessPermissions', () => { [PermissionTypes.BOOKMARKS]: { USE: false }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({ CREATE: 
true, USE: false, SHARED_GLOBAL: true, }); - expect(updatedRole[PermissionTypes.BOOKMARKS]).toEqual({ - USE: false, - }); + expect(updatedRole.permissions[PermissionTypes.BOOKMARKS]).toEqual({ USE: false }); }); it('should handle updates for a single permission type', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - CREATE: true, - USE: true, - SHARED_GLOBAL: false, + permissions: { + [PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false }, }, }).save(); @@ -192,8 +184,8 @@ describe('updateAccessPermissions', () => { [PermissionTypes.PROMPTS]: { USE: false, SHARED_GLOBAL: true }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({ CREATE: true, USE: false, SHARED_GLOBAL: true, @@ -203,33 +195,25 @@ describe('updateAccessPermissions', () => { it('should update MULTI_CONVO permissions', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.MULTI_CONVO]: { - USE: false, + permissions: { + [PermissionTypes.MULTI_CONVO]: { USE: false }, }, }).save(); await updateAccessPermissions(SystemRoles.USER, { - [PermissionTypes.MULTI_CONVO]: { - USE: true, - }, + [PermissionTypes.MULTI_CONVO]: { USE: true }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({ - USE: true, - }); + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true }); }); it('should update MULTI_CONVO permissions along with other permission types', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - CREATE: true, - USE: true, - SHARED_GLOBAL: false, - }, - [PermissionTypes.MULTI_CONVO]: { - USE: false, + permissions: { + [PermissionTypes.PROMPTS]: { CREATE: true, USE: true, SHARED_GLOBAL: false }, + [PermissionTypes.MULTI_CONVO]: { USE: false }, }, }).save(); @@ -238,35 +222,29 @@ describe('updateAccessPermissions', () => { [PermissionTypes.MULTI_CONVO]: { USE: true }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.PROMPTS]).toEqual({ + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.PROMPTS]).toEqual({ CREATE: true, USE: true, SHARED_GLOBAL: true, }); - expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({ - USE: true, - }); + expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true }); }); it('should not update MULTI_CONVO permissions when no changes are needed', async () => { await new Role({ name: SystemRoles.USER, - [PermissionTypes.MULTI_CONVO]: { - USE: true, + permissions: { + [PermissionTypes.MULTI_CONVO]: { USE: true }, }, }).save(); await updateAccessPermissions(SystemRoles.USER, { - [PermissionTypes.MULTI_CONVO]: { - USE: true, - }, + [PermissionTypes.MULTI_CONVO]: { USE: true }, }); - const updatedRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - expect(updatedRole[PermissionTypes.MULTI_CONVO]).toEqual({ - USE: true, - }); + const updatedRole = await getRoleByName(SystemRoles.USER); + expect(updatedRole.permissions[PermissionTypes.MULTI_CONVO]).toEqual({ USE: true }); }); }); @@ -278,65 +256,69 @@ describe('initializeRoles', () => { it('should create 
default roles if they do not exist', async () => { await initializeRoles(); - const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); - const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + const adminRole = await getRoleByName(SystemRoles.ADMIN); + const userRole = await getRoleByName(SystemRoles.USER); expect(adminRole).toBeTruthy(); expect(userRole).toBeTruthy(); - // Check if all permission types exist + // Check if all permission types exist in the permissions field Object.values(PermissionTypes).forEach((permType) => { - expect(adminRole[permType]).toBeDefined(); - expect(userRole[permType]).toBeDefined(); + expect(adminRole.permissions[permType]).toBeDefined(); + expect(userRole.permissions[permType]).toBeDefined(); }); - // Check if permissions match defaults (example for ADMIN role) - expect(adminRole[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true); - expect(adminRole[PermissionTypes.BOOKMARKS].USE).toBe(true); - expect(adminRole[PermissionTypes.AGENTS].CREATE).toBe(true); + // Example: Check default values for ADMIN role + expect(adminRole.permissions[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true); + expect(adminRole.permissions[PermissionTypes.BOOKMARKS].USE).toBe(true); + expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBe(true); }); it('should not modify existing permissions for existing roles', async () => { const customUserRole = { name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: { - [Permissions.USE]: false, - [Permissions.CREATE]: true, - [Permissions.SHARED_GLOBAL]: true, - }, - [PermissionTypes.BOOKMARKS]: { - [Permissions.USE]: false, + permissions: { + [PermissionTypes.PROMPTS]: { + [Permissions.USE]: false, + [Permissions.CREATE]: true, + [Permissions.SHARED_GLOBAL]: true, + }, + [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, }, }; await new Role(customUserRole).save(); - await initializeRoles(); - const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - - expect(userRole[PermissionTypes.PROMPTS]).toEqual(customUserRole[PermissionTypes.PROMPTS]); - expect(userRole[PermissionTypes.BOOKMARKS]).toEqual(customUserRole[PermissionTypes.BOOKMARKS]); - expect(userRole[PermissionTypes.AGENTS]).toBeDefined(); + const userRole = await getRoleByName(SystemRoles.USER); + expect(userRole.permissions[PermissionTypes.PROMPTS]).toEqual( + customUserRole.permissions[PermissionTypes.PROMPTS], + ); + expect(userRole.permissions[PermissionTypes.BOOKMARKS]).toEqual( + customUserRole.permissions[PermissionTypes.BOOKMARKS], + ); + expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined(); }); it('should add new permission types to existing roles', async () => { const partialUserRole = { name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS], - [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS], + permissions: { + [PermissionTypes.PROMPTS]: + roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS], + [PermissionTypes.BOOKMARKS]: + roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS], + }, }; await new Role(partialUserRole).save(); - await initializeRoles(); - const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - - expect(userRole[PermissionTypes.AGENTS]).toBeDefined(); - expect(userRole[PermissionTypes.AGENTS].CREATE).toBeDefined(); - expect(userRole[PermissionTypes.AGENTS].USE).toBeDefined(); - 
expect(userRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); + const userRole = await getRoleByName(SystemRoles.USER); + expect(userRole.permissions[PermissionTypes.AGENTS]).toBeDefined(); + expect(userRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined(); + expect(userRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined(); + expect(userRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); }); it('should handle multiple runs without duplicating or modifying data', async () => { @@ -349,72 +331,73 @@ describe('initializeRoles', () => { expect(adminRoles).toHaveLength(1); expect(userRoles).toHaveLength(1); - const adminRole = adminRoles[0].toObject(); - const userRole = userRoles[0].toObject(); - - // Check if all permission types exist + const adminPerms = adminRoles[0].toObject().permissions; + const userPerms = userRoles[0].toObject().permissions; Object.values(PermissionTypes).forEach((permType) => { - expect(adminRole[permType]).toBeDefined(); - expect(userRole[permType]).toBeDefined(); + expect(adminPerms[permType]).toBeDefined(); + expect(userPerms[permType]).toBeDefined(); }); }); it('should update roles with missing permission types from roleDefaults', async () => { const partialAdminRole = { name: SystemRoles.ADMIN, - [PermissionTypes.PROMPTS]: { - [Permissions.USE]: false, - [Permissions.CREATE]: false, - [Permissions.SHARED_GLOBAL]: false, + permissions: { + [PermissionTypes.PROMPTS]: { + [Permissions.USE]: false, + [Permissions.CREATE]: false, + [Permissions.SHARED_GLOBAL]: false, + }, + [PermissionTypes.BOOKMARKS]: + roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.BOOKMARKS], }, - [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.ADMIN][PermissionTypes.BOOKMARKS], }; await new Role(partialAdminRole).save(); - await initializeRoles(); - const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); - - expect(adminRole[PermissionTypes.PROMPTS]).toEqual(partialAdminRole[PermissionTypes.PROMPTS]); - expect(adminRole[PermissionTypes.AGENTS]).toBeDefined(); - expect(adminRole[PermissionTypes.AGENTS].CREATE).toBeDefined(); - expect(adminRole[PermissionTypes.AGENTS].USE).toBeDefined(); - expect(adminRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); + const adminRole = await getRoleByName(SystemRoles.ADMIN); + expect(adminRole.permissions[PermissionTypes.PROMPTS]).toEqual( + partialAdminRole.permissions[PermissionTypes.PROMPTS], + ); + expect(adminRole.permissions[PermissionTypes.AGENTS]).toBeDefined(); + expect(adminRole.permissions[PermissionTypes.AGENTS].CREATE).toBeDefined(); + expect(adminRole.permissions[PermissionTypes.AGENTS].USE).toBeDefined(); + expect(adminRole.permissions[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); }); it('should include MULTI_CONVO permissions when creating default roles', async () => { await initializeRoles(); - const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); - const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + const adminRole = await getRoleByName(SystemRoles.ADMIN); + const userRole = await getRoleByName(SystemRoles.USER); - expect(adminRole[PermissionTypes.MULTI_CONVO]).toBeDefined(); - expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined(); - - // Check if MULTI_CONVO permissions match defaults - expect(adminRole[PermissionTypes.MULTI_CONVO].USE).toBe( - roleDefaults[SystemRoles.ADMIN][PermissionTypes.MULTI_CONVO].USE, + expect(adminRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined(); + 
expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined(); + expect(adminRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe( + roleDefaults[SystemRoles.ADMIN].permissions[PermissionTypes.MULTI_CONVO].USE, ); - expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBe( - roleDefaults[SystemRoles.USER][PermissionTypes.MULTI_CONVO].USE, + expect(userRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBe( + roleDefaults[SystemRoles.USER].permissions[PermissionTypes.MULTI_CONVO].USE, ); }); it('should add MULTI_CONVO permissions to existing roles without them', async () => { const partialUserRole = { name: SystemRoles.USER, - [PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS], - [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS], + permissions: { + [PermissionTypes.PROMPTS]: + roleDefaults[SystemRoles.USER].permissions[PermissionTypes.PROMPTS], + [PermissionTypes.BOOKMARKS]: + roleDefaults[SystemRoles.USER].permissions[PermissionTypes.BOOKMARKS], + }, }; await new Role(partialUserRole).save(); - await initializeRoles(); - const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); - - expect(userRole[PermissionTypes.MULTI_CONVO]).toBeDefined(); - expect(userRole[PermissionTypes.MULTI_CONVO].USE).toBeDefined(); + const userRole = await getRoleByName(SystemRoles.USER); + expect(userRole.permissions[PermissionTypes.MULTI_CONVO]).toBeDefined(); + expect(userRole.permissions[PermissionTypes.MULTI_CONVO].USE).toBeDefined(); }); }); diff --git a/api/models/Session.js b/api/models/Session.js index dbb66ed8ff..38821b77dd 100644 --- a/api/models/Session.js +++ b/api/models/Session.js @@ -1,7 +1,7 @@ const mongoose = require('mongoose'); const signPayload = require('~/server/services/signPayload'); const { hashToken } = require('~/server/utils/crypto'); -const sessionSchema = require('./schema/session'); +const { sessionSchema } = require('@librechat/data-schemas'); const { logger } = require('~/config'); const Session = mongoose.model('Session', sessionSchema); diff --git a/api/models/Share.js b/api/models/Share.js index 041927ec61..8611d01bc0 100644 --- a/api/models/Share.js +++ b/api/models/Share.js @@ -1,7 +1,9 @@ +const mongoose = require('mongoose'); const { nanoid } = require('nanoid'); const { Constants } = require('librechat-data-provider'); const { Conversation } = require('~/models/Conversation'); -const SharedLink = require('./schema/shareSchema'); +const { shareSchema } = require('@librechat/data-schemas'); +const SharedLink = mongoose.model('SharedLink', shareSchema); const { getMessages } = require('./Message'); const logger = require('~/config/winston'); @@ -50,6 +52,14 @@ function anonymizeMessages(messages, newConvoId) { const newMessageId = anonymizeMessageId(message.messageId); idMap.set(message.messageId, newMessageId); + const anonymizedAttachments = message.attachments?.map((attachment) => { + return { + ...attachment, + messageId: newMessageId, + conversationId: newConvoId, + }; + }); + return { ...message, messageId: newMessageId, @@ -59,6 +69,7 @@ function anonymizeMessages(messages, newConvoId) { model: message.model?.startsWith('asst_') ? 
anonymizeAssistantId(message.model) : message.model, + attachments: anonymizedAttachments, }; }); } diff --git a/api/models/Token.js b/api/models/Token.js index 210666ddd7..c89abb8c84 100644 --- a/api/models/Token.js +++ b/api/models/Token.js @@ -1,6 +1,6 @@ const mongoose = require('mongoose'); const { encryptV2 } = require('~/server/utils/crypto'); -const tokenSchema = require('./schema/tokenSchema'); +const { tokenSchema } = require('@librechat/data-schemas'); const { logger } = require('~/config'); /** @@ -13,6 +13,13 @@ const Token = mongoose.model('Token', tokenSchema); */ async function fixIndexes() { try { + if ( + process.env.NODE_ENV === 'CI' || + process.env.NODE_ENV === 'development' || + process.env.NODE_ENV === 'test' + ) { + return; + } const indexes = await Token.collection.indexes(); logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2)); const unwantedTTLIndexes = indexes.filter( diff --git a/api/models/ToolCall.js b/api/models/ToolCall.js index e1d7b0cc84..7bc0f157dc 100644 --- a/api/models/ToolCall.js +++ b/api/models/ToolCall.js @@ -1,9 +1,11 @@ -const ToolCall = require('./schema/toolCallSchema'); +const mongoose = require('mongoose'); +const { toolCallSchema } = require('@librechat/data-schemas'); +const ToolCall = mongoose.model('ToolCall', toolCallSchema); /** * Create a new tool call - * @param {ToolCallData} toolCallData - The tool call data - * @returns {Promise} The created tool call document + * @param {IToolCallData} toolCallData - The tool call data + * @returns {Promise} The created tool call document */ async function createToolCall(toolCallData) { try { @@ -16,7 +18,7 @@ async function createToolCall(toolCallData) { /** * Get a tool call by ID * @param {string} id - The tool call document ID - * @returns {Promise} The tool call document or null if not found + * @returns {Promise} The tool call document or null if not found */ async function getToolCallById(id) { try { @@ -44,7 +46,7 @@ async function getToolCallsByMessage(messageId, userId) { * Get tool calls by conversation ID and user * @param {string} conversationId - The conversation ID * @param {string} userId - The user's ObjectId - * @returns {Promise} Array of tool call documents + * @returns {Promise} Array of tool call documents */ async function getToolCallsByConvo(conversationId, userId) { try { @@ -57,8 +59,8 @@ async function getToolCallsByConvo(conversationId, userId) { /** * Update a tool call * @param {string} id - The tool call document ID - * @param {Partial} updateData - The data to update - * @returns {Promise} The updated tool call document or null if not found + * @param {Partial} updateData - The data to update + * @returns {Promise} The updated tool call document or null if not found */ async function updateToolCall(id, updateData) { try { diff --git a/api/models/Transaction.js b/api/models/Transaction.js index 8435a812c4..e171241b61 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,11 +1,144 @@ const mongoose = require('mongoose'); -const { isEnabled } = require('~/server/utils/handleText'); -const transactionSchema = require('./schema/transaction'); +const { transactionSchema } = require('@librechat/data-schemas'); +const { getBalanceConfig } = require('~/server/services/Config'); const { getMultiplier, getCacheMultiplier } = require('./tx'); const { logger } = require('~/config'); const Balance = require('./Balance'); + const cancelRate = 1.15; +/** + * Updates a user's token balance based on a transaction using optimistic 
concurrency control + * without schema changes. Compatible with DocumentDB. + * @async + * @function + * @param {Object} params - The function parameters. + * @param {string|mongoose.Types.ObjectId} params.user - The user ID. + * @param {number} params.incrementValue - The value to increment the balance by (can be negative). + * @param {import('mongoose').UpdateQuery['$set']} [params.setValues] - Optional additional fields to set. + * @returns {Promise} Returns the updated balance document (lean). + * @throws {Error} Throws an error if the update fails after multiple retries. + */ +const updateBalance = async ({ user, incrementValue, setValues }) => { + let maxRetries = 10; // Number of times to retry on conflict + let delay = 50; // Initial retry delay in ms + let lastError = null; + + for (let attempt = 1; attempt <= maxRetries; attempt++) { + let currentBalanceDoc; + try { + // 1. Read the current document state + currentBalanceDoc = await Balance.findOne({ user }).lean(); + const currentCredits = currentBalanceDoc ? currentBalanceDoc.tokenCredits : 0; + + // 2. Calculate the desired new state + const potentialNewCredits = currentCredits + incrementValue; + const newCredits = Math.max(0, potentialNewCredits); // Ensure balance doesn't go below zero + + // 3. Prepare the update payload + const updatePayload = { + $set: { + tokenCredits: newCredits, + ...(setValues || {}), // Merge other values to set + }, + }; + + // 4. Attempt the conditional update or upsert + let updatedBalance = null; + if (currentBalanceDoc) { + // --- Document Exists: Perform Conditional Update --- + // Try to update only if the tokenCredits match the value we read (currentCredits) + updatedBalance = await Balance.findOneAndUpdate( + { + user: user, + tokenCredits: currentCredits, // Optimistic lock: condition based on the read value + }, + updatePayload, + { + new: true, // Return the modified document + // lean: true, // .lean() is applied after query execution in Mongoose >= 6 + }, + ).lean(); // Use lean() for plain JS object + + if (updatedBalance) { + // Success! The update was applied based on the expected current state. + return updatedBalance; + } + // If updatedBalance is null, it means tokenCredits changed between read and write (conflict). + lastError = new Error(`Concurrency conflict for user ${user} on attempt ${attempt}.`); + // Proceed to retry logic below. + } else { + // --- Document Does Not Exist: Perform Conditional Upsert --- + // Try to insert the document, but only if it still doesn't exist. + // Using tokenCredits: {$exists: false} helps prevent race conditions where + // another process creates the doc between our findOne and findOneAndUpdate. + try { + updatedBalance = await Balance.findOneAndUpdate( + { + user: user, + // Attempt to match only if the document doesn't exist OR was just created + // without tokenCredits (less likely but possible). A simple { user } filter + // might also work, relying on the retry for conflicts. + // Let's use a simpler filter and rely on retry for races. + // tokenCredits: { $exists: false } // This condition might be too strict if doc exists with 0 credits + }, + updatePayload, + { + upsert: true, // Create if doesn't exist + new: true, // Return the created/updated document + // setDefaultsOnInsert: true, // Ensure schema defaults are applied on insert + // lean: true, + }, + ).lean(); + + if (updatedBalance) { + // Upsert succeeded (likely created the document) + return updatedBalance; + } + // If null, potentially a rare race condition during upsert. 
Retry should handle it. + lastError = new Error( + `Upsert race condition suspected for user ${user} on attempt ${attempt}.`, + ); + } catch (error) { + if (error.code === 11000) { + // E11000 duplicate key error on index + // This means another process created the document *just* before our upsert. + // It's a concurrency conflict during creation. We should retry. + lastError = error; // Store the error + // Proceed to retry logic below. + } else { + // Different error, rethrow + throw error; + } + } + } // End if/else (document exists?) + } catch (error) { + // Catch errors from findOne or unexpected findOneAndUpdate errors + logger.error(`[updateBalance] Error during attempt ${attempt} for user ${user}:`, error); + lastError = error; // Store the error + // Consider stopping retries for non-transient errors, but for now, we retry. + } + + // If we reached here, it means the update failed (conflict or error), wait and retry + if (attempt < maxRetries) { + const jitter = Math.random() * delay * 0.5; // Add jitter to delay + await new Promise((resolve) => setTimeout(resolve, delay + jitter)); + delay = Math.min(delay * 2, 2000); // Exponential backoff with cap + } + } // End for loop (retries) + + // If loop finishes without success, throw the last encountered error or a generic one + logger.error( + `[updateBalance] Failed to update balance for user ${user} after ${maxRetries} attempts.`, + ); + throw ( + lastError || + new Error( + `Failed to update balance for user ${user} after maximum retries due to persistent conflicts.`, + ) + ); +}; + /** Method to calculate and set the tokenValue for a transaction */ transactionSchema.methods.calculateTokenValue = function () { if (!this.valueKey || !this.tokenType) { @@ -21,6 +154,39 @@ transactionSchema.methods.calculateTokenValue = function () { } }; +/** + * New static method to create an auto-refill transaction that does NOT trigger a balance update. + * @param {object} txData - Transaction data. + * @param {string} txData.user - The user ID. + * @param {string} txData.tokenType - The type of token. + * @param {string} txData.context - The context of the transaction. + * @param {number} txData.rawAmount - The raw amount of tokens. + * @returns {Promise} - The created transaction. + */ +transactionSchema.statics.createAutoRefillTransaction = async function (txData) { + if (txData.rawAmount != null && isNaN(txData.rawAmount)) { + return; + } + const transaction = new this(txData); + transaction.endpointTokenConfig = txData.endpointTokenConfig; + transaction.calculateTokenValue(); + await transaction.save(); + + const balanceResponse = await updateBalance({ + user: transaction.user, + incrementValue: txData.rawAmount, + setValues: { lastRefill: new Date() }, + }); + const result = { + rate: transaction.rate, + user: transaction.user.toString(), + balance: balanceResponse.tokenCredits, + }; + logger.debug('[Balance.check] Auto-refill performed', result); + result.transaction = transaction; + return result; +}; + /** * Static method to create a transaction and update the balance * @param {txData} txData - Transaction data. 
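+ * @example
+ * // Sketch (field values assumed): record a 1,000-token completion debit; the
+ * // balance update below only applies when the balance config reports enabled.
+ * // await Transaction.create({ user, conversationId, model: 'gpt-3.5-turbo',
+ * //   tokenType: 'completion', rawAmount: -1000, context: 'message' });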
@@ -37,27 +203,22 @@ transactionSchema.statics.create = async function (txData) { await transaction.save(); - if (!isEnabled(process.env.CHECK_BALANCE)) { + const balance = await getBalanceConfig(); + if (!balance?.enabled) { return; } - let balance = await Balance.findOne({ user: transaction.user }).lean(); let incrementValue = transaction.tokenValue; - if (balance && balance?.tokenCredits + incrementValue < 0) { - incrementValue = -balance.tokenCredits; - } - - balance = await Balance.findOneAndUpdate( - { user: transaction.user }, - { $inc: { tokenCredits: incrementValue } }, - { upsert: true, new: true }, - ).lean(); + const balanceResponse = await updateBalance({ + user: transaction.user, + incrementValue, + }); return { rate: transaction.rate, user: transaction.user.toString(), - balance: balance.tokenCredits, + balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; }; @@ -78,27 +239,22 @@ transactionSchema.statics.createStructured = async function (txData) { await transaction.save(); - if (!isEnabled(process.env.CHECK_BALANCE)) { + const balance = await getBalanceConfig(); + if (!balance?.enabled) { return; } - let balance = await Balance.findOne({ user: transaction.user }).lean(); let incrementValue = transaction.tokenValue; - if (balance && balance?.tokenCredits + incrementValue < 0) { - incrementValue = -balance.tokenCredits; - } - - balance = await Balance.findOneAndUpdate( - { user: transaction.user }, - { $inc: { tokenCredits: incrementValue } }, - { upsert: true, new: true }, - ).lean(); + const balanceResponse = await updateBalance({ + user: transaction.user, + incrementValue, + }); return { rate: transaction.rate, user: transaction.user.toString(), - balance: balance.tokenCredits, + balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; }; diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index b8c69e13f4..43f3c004b2 100644 --- a/api/models/Transaction.spec.js +++ b/api/models/Transaction.spec.js @@ -1,9 +1,13 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); +const { spendTokens, spendStructuredTokens } = require('./spendTokens'); +const { getBalanceConfig } = require('~/server/services/Config'); +const { getMultiplier, getCacheMultiplier } = require('./tx'); const { Transaction } = require('./Transaction'); const Balance = require('./Balance'); -const { spendTokens, spendStructuredTokens } = require('./spendTokens'); -const { getMultiplier, getCacheMultiplier } = require('./tx'); + +// Mock the custom config module so we can control the balance flag. +jest.mock('~/server/services/Config'); let mongoServer; @@ -20,6 +24,8 @@ afterAll(async () => { beforeEach(async () => { await mongoose.connection.dropDatabase(); + // Default: enable balance updates in tests. 
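+ // `getBalanceConfig` is mocked via `jest.mock` above, so individual tests can
+ // toggle the balance feature without touching environment variables.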
+ getBalanceConfig.mockResolvedValue({ enabled: true }); }); describe('Regular Token Spending Tests', () => { @@ -44,34 +50,22 @@ describe('Regular Token Spending Tests', () => { }; // Act - process.env.CHECK_BALANCE = 'true'; await spendTokens(txData, tokenUsage); // Assert - console.log('Initial Balance:', initialBalance); - const updatedBalance = await Balance.findOne({ user: userId }); - console.log('Updated Balance:', updatedBalance.tokenCredits); - const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const completionMultiplier = getMultiplier({ model, tokenType: 'completion' }); - - const expectedPromptCost = tokenUsage.promptTokens * promptMultiplier; - const expectedCompletionCost = tokenUsage.completionTokens * completionMultiplier; - const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + const expectedTotalCost = 100 * promptMultiplier + 50 * completionMultiplier; const expectedBalance = initialBalance - expectedTotalCost; - expect(updatedBalance.tokenCredits).toBeLessThan(initialBalance); expect(updatedBalance.tokenCredits).toBeCloseTo(expectedBalance, 0); - - console.log('Expected Total Cost:', expectedTotalCost); - console.log('Actual Balance Decrease:', initialBalance - updatedBalance.tokenCredits); }); test('spendTokens should handle zero completion tokens', async () => { // Arrange const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; // $10.00 + const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); const model = 'gpt-3.5-turbo'; @@ -89,24 +83,19 @@ describe('Regular Token Spending Tests', () => { }; // Act - process.env.CHECK_BALANCE = 'true'; await spendTokens(txData, tokenUsage); // Assert const updatedBalance = await Balance.findOne({ user: userId }); - const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); - const expectedCost = tokenUsage.promptTokens * promptMultiplier; + const expectedCost = 100 * promptMultiplier; expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); - - console.log('Initial Balance:', initialBalance); - console.log('Updated Balance:', updatedBalance.tokenCredits); - console.log('Expected Cost:', expectedCost); }); test('spendTokens should handle undefined token counts', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; // $10.00 + const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); const model = 'gpt-3.5-turbo'; @@ -120,14 +109,17 @@ describe('Regular Token Spending Tests', () => { const tokenUsage = {}; + // Act const result = await spendTokens(txData, tokenUsage); + // Assert: No transaction should be created expect(result).toBeUndefined(); }); test('spendTokens should handle only prompt tokens', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); - const initialBalance = 10000000; // $10.00 + const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); const model = 'gpt-3.5-turbo'; @@ -141,14 +133,44 @@ describe('Regular Token Spending Tests', () => { const tokenUsage = { promptTokens: 100 }; + // Act await spendTokens(txData, tokenUsage); + // Assert const updatedBalance = await Balance.findOne({ user: userId }); - const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const expectedCost = 100 * promptMultiplier; expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); }); + 
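+ // The next test exercises the config-driven balance gate (`getBalanceConfig`)
+ // that replaces the removed `process.env.CHECK_BALANCE` checks.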
+ test('spendTokens should not update balance when balance feature is disabled', async () => { + // Arrange: Override the config to disable balance updates. + getBalanceConfig.mockResolvedValue({ enabled: false }); + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 10000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-3.5-turbo'; + const txData = { + user: userId, + conversationId: 'test-conversation-id', + model, + context: 'test', + endpointTokenConfig: null, + }; + + const tokenUsage = { + promptTokens: 100, + completionTokens: 50, + }; + + // Act + await spendTokens(txData, tokenUsage); + + // Assert: Balance should remain unchanged. + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBe(initialBalance); + }); }); describe('Structured Token Spending Tests', () => { @@ -164,7 +186,7 @@ describe('Structured Token Spending Tests', () => { conversationId: 'c23a18da-706c-470a-ac28-ec87ed065199', model, context: 'message', - endpointTokenConfig: null, // We'll use the default rates + endpointTokenConfig: null, }; const tokenUsage = { @@ -176,28 +198,15 @@ completionTokens: 5, }; - // Get the actual multipliers const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const completionMultiplier = getMultiplier({ model, tokenType: 'completion' }); const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); - console.log('Multipliers:', { - promptMultiplier, - completionMultiplier, - writeMultiplier, - readMultiplier, - }); - // Act - process.env.CHECK_BALANCE = 'true'; const result = await spendStructuredTokens(txData, tokenUsage); - // Assert - console.log('Initial Balance:', initialBalance); - console.log('Updated Balance:', result.completion.balance); - console.log('Transaction Result:', result); - + // Calculate expected costs. 
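+ // promptCost = input * promptRate + write * writeRate + read * readRate;
+ // completionCost = completionTokens * completionRate (rates from the tx multipliers).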
const expectedPromptCost = tokenUsage.promptTokens.input * promptMultiplier + tokenUsage.promptTokens.write * writeMultiplier + @@ -206,37 +215,21 @@ describe('Structured Token Spending Tests', () => { const expectedTotalCost = expectedPromptCost + expectedCompletionCost; const expectedBalance = initialBalance - expectedTotalCost; - console.log('Expected Cost:', expectedTotalCost); - console.log('Expected Balance:', expectedBalance); - + // Assert expect(result.completion.balance).toBeLessThan(initialBalance); - - // Allow for a small difference (e.g., 100 token credits, which is $0.0001) const allowedDifference = 100; expect(Math.abs(result.completion.balance - expectedBalance)).toBeLessThan(allowedDifference); - - // Check if the decrease is approximately as expected const balanceDecrease = initialBalance - result.completion.balance; expect(balanceDecrease).toBeCloseTo(expectedTotalCost, 0); - // Check token values - const expectedPromptTokenValue = -( - tokenUsage.promptTokens.input * promptMultiplier + - tokenUsage.promptTokens.write * writeMultiplier + - tokenUsage.promptTokens.read * readMultiplier - ); - const expectedCompletionTokenValue = -tokenUsage.completionTokens * completionMultiplier; - + const expectedPromptTokenValue = -expectedPromptCost; + const expectedCompletionTokenValue = -expectedCompletionCost; expect(result.prompt.prompt).toBeCloseTo(expectedPromptTokenValue, 1); expect(result.completion.completion).toBe(expectedCompletionTokenValue); - - console.log('Expected prompt tokenValue:', expectedPromptTokenValue); - console.log('Actual prompt tokenValue:', result.prompt.prompt); - console.log('Expected completion tokenValue:', expectedCompletionTokenValue); - console.log('Actual completion tokenValue:', result.completion.completion); }); test('should handle zero completion tokens in structured spending', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -258,15 +251,17 @@ describe('Structured Token Spending Tests', () => { completionTokens: 0, }; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); + // Assert expect(result.prompt).toBeDefined(); expect(result.completion).toBeUndefined(); expect(result.prompt.prompt).toBeLessThan(0); }); test('should handle only prompt tokens in structured spending', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -287,15 +282,17 @@ describe('Structured Token Spending Tests', () => { }, }; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); + // Assert expect(result.prompt).toBeDefined(); expect(result.completion).toBeUndefined(); expect(result.prompt.prompt).toBeLessThan(0); }); test('should handle undefined token counts in structured spending', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -310,9 +307,10 @@ describe('Structured Token Spending Tests', () => { const tokenUsage = {}; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); + // Assert expect(result).toEqual({ prompt: undefined, completion: undefined, @@ -320,6 +318,7 @@ describe('Structured Token Spending 
Tests', () => { }); test('should handle incomplete context for completion tokens', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 17613154.55; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -341,15 +340,18 @@ describe('Structured Token Spending Tests', () => { completionTokens: 50, }; - process.env.CHECK_BALANCE = 'true'; + // Act const result = await spendStructuredTokens(txData, tokenUsage); - expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); // Assuming multiplier is 15 and cancelRate is 1.15 + // Assert: + // (Assuming a multiplier for completion of 15 and a cancel rate of 1.15 as noted in the original test.) + expect(result.completion.completion).toBeCloseTo(-50 * 15 * 1.15, 0); }); }); describe('NaN Handling Tests', () => { test('should skip transaction creation when rawAmount is NaN', async () => { + // Arrange const userId = new mongoose.Types.ObjectId(); const initialBalance = 10000000; await Balance.create({ user: userId, tokenCredits: initialBalance }); @@ -365,9 +367,11 @@ describe('NaN Handling Tests', () => { tokenType: 'prompt', }; + // Act const result = await Transaction.create(txData); - expect(result).toBeUndefined(); + // Assert: No transaction should be created and balance remains unchanged. + expect(result).toBeUndefined(); const balance = await Balance.findOne({ user: userId }); expect(balance.tokenCredits).toBe(initialBalance); }); diff --git a/api/models/User.js b/api/models/User.js index 55750b4ae5..f4e8b0ec5b 100644 --- a/api/models/User.js +++ b/api/models/User.js @@ -1,5 +1,5 @@ const mongoose = require('mongoose'); -const userSchema = require('~/models/schema/userSchema'); +const { userSchema } = require('@librechat/data-schemas'); const User = mongoose.model('User', userSchema); diff --git a/api/models/balanceMethods.js b/api/models/balanceMethods.js new file mode 100644 index 0000000000..4b788160aa --- /dev/null +++ b/api/models/balanceMethods.js @@ -0,0 +1,156 @@ +const { ViolationTypes } = require('librechat-data-provider'); +const { Transaction } = require('./Transaction'); +const { logViolation } = require('~/cache'); +const { getMultiplier } = require('./tx'); +const { logger } = require('~/config'); +const Balance = require('./Balance'); + +function isInvalidDate(date) { + return isNaN(date); +} + +/** + * Simple check method that calculates token cost and returns balance info. + * The auto-refill logic has been moved to balanceMethods.js to prevent circular dependencies. 
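+ * @param {Object} params - Transaction details: user, model, endpoint, valueKey,
+ * tokenType, amount, and endpointTokenConfig.
+ * @returns {Promise<{ canSpend: boolean, balance: number, tokenCost: number }>}
+ * Balance info; auto-refill may run first when the projected balance is <= 0.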
+ */ +const checkBalanceRecord = async function ({ + user, + model, + endpoint, + valueKey, + tokenType, + amount, + endpointTokenConfig, +}) { + const multiplier = getMultiplier({ valueKey, tokenType, model, endpoint, endpointTokenConfig }); + const tokenCost = amount * multiplier; + + // Retrieve the balance record + let record = await Balance.findOne({ user }).lean(); + if (!record) { + logger.debug('[Balance.check] No balance record found for user', { user }); + return { + canSpend: false, + balance: 0, + tokenCost, + }; + } + let balance = record.tokenCredits; + + logger.debug('[Balance.check] Initial state', { + user, + model, + endpoint, + valueKey, + tokenType, + amount, + balance, + multiplier, + endpointTokenConfig: !!endpointTokenConfig, + }); + + // Only perform auto-refill if spending would bring the balance to 0 or below + if (balance - tokenCost <= 0 && record.autoRefillEnabled && record.refillAmount > 0) { + const lastRefillDate = new Date(record.lastRefill); + const now = new Date(); + if ( + isInvalidDate(lastRefillDate) || + now >= + addIntervalToDate(lastRefillDate, record.refillIntervalValue, record.refillIntervalUnit) + ) { + try { + /** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */ + const result = await Transaction.createAutoRefillTransaction({ + user: user, + tokenType: 'credits', + context: 'autoRefill', + rawAmount: record.refillAmount, + }); + balance = result.balance; + } catch (error) { + logger.error('[Balance.check] Failed to record transaction for auto-refill', error); + } + } + } + + logger.debug('[Balance.check] Token cost', { tokenCost }); + return { canSpend: balance >= tokenCost, balance, tokenCost }; +}; + +/** + * Adds a time interval to a given date. + * @param {Date} date - The starting date. + * @param {number} value - The numeric value of the interval. + * @param {'seconds'|'minutes'|'hours'|'days'|'weeks'|'months'} unit - The unit of time. + * @returns {Date} A new Date representing the starting date plus the interval. + */ +const addIntervalToDate = (date, value, unit) => { + const result = new Date(date); + switch (unit) { + case 'seconds': + result.setSeconds(result.getSeconds() + value); + break; + case 'minutes': + result.setMinutes(result.getMinutes() + value); + break; + case 'hours': + result.setHours(result.getHours() + value); + break; + case 'days': + result.setDate(result.getDate() + value); + break; + case 'weeks': + result.setDate(result.getDate() + value * 7); + break; + case 'months': + result.setMonth(result.getMonth() + value); + break; + default: + break; + } + return result; +}; + +/** + * Checks the balance for a user and determines if they can spend a certain amount. + * If the user cannot spend the amount, it logs a violation and denies the request. + * + * @async + * @function + * @param {Object} params - The function parameters. + * @param {Express.Request} params.req - The Express request object. + * @param {Express.Response} params.res - The Express response object. + * @param {Object} params.txData - The transaction data. + * @param {string} params.txData.user - The user ID or identifier. + * @param {('prompt' | 'completion')} params.txData.tokenType - The type of token. + * @param {number} params.txData.amount - The amount of tokens. + * @param {string} params.txData.model - The model name or identifier. + * @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint. 
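+ * @param {Array} [params.txData.generations] - Optional generations, included in the logged violation when present.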
+ * @returns {Promise} Throws error if the user cannot spend the amount. + * @throws {Error} Throws an error if there's an issue with the balance check. + */ +const checkBalance = async ({ req, res, txData }) => { + const { canSpend, balance, tokenCost } = await checkBalanceRecord(txData); + if (canSpend) { + return true; + } + + const type = ViolationTypes.TOKEN_BALANCE; + const errorMessage = { + type, + balance, + tokenCost, + promptTokens: txData.amount, + }; + + if (txData.generations && txData.generations.length > 0) { + errorMessage.generations = txData.generations; + } + + await logViolation(req, res, type, errorMessage, 0); + throw new Error(JSON.stringify(errorMessage)); +}; + +module.exports = { + checkBalance, +}; diff --git a/api/models/checkBalance.js b/api/models/checkBalance.js deleted file mode 100644 index 5af77bbb19..0000000000 --- a/api/models/checkBalance.js +++ /dev/null @@ -1,45 +0,0 @@ -const { ViolationTypes } = require('librechat-data-provider'); -const { logViolation } = require('~/cache'); -const Balance = require('./Balance'); -/** - * Checks the balance for a user and determines if they can spend a certain amount. - * If the user cannot spend the amount, it logs a violation and denies the request. - * - * @async - * @function - * @param {Object} params - The function parameters. - * @param {Express.Request} params.req - The Express request object. - * @param {Express.Response} params.res - The Express response object. - * @param {Object} params.txData - The transaction data. - * @param {string} params.txData.user - The user ID or identifier. - * @param {('prompt' | 'completion')} params.txData.tokenType - The type of token. - * @param {number} params.txData.amount - The amount of tokens. - * @param {string} params.txData.model - The model name or identifier. - * @param {string} [params.txData.endpointTokenConfig] - The token configuration for the endpoint. - * @returns {Promise} Returns true if the user can spend the amount, otherwise denies the request. - * @throws {Error} Throws an error if there's an issue with the balance check. - */ -const checkBalance = async ({ req, res, txData }) => { - const { canSpend, balance, tokenCost } = await Balance.check(txData); - - if (canSpend) { - return true; - } - - const type = ViolationTypes.TOKEN_BALANCE; - const errorMessage = { - type, - balance, - tokenCost, - promptTokens: txData.amount, - }; - - if (txData.generations && txData.generations.length > 0) { - errorMessage.generations = txData.generations; - } - - await logViolation(req, res, type, errorMessage, 0); - throw new Error(JSON.stringify(errorMessage)); -}; - -module.exports = checkBalance; diff --git a/api/models/plugins/mongoMeili.js b/api/models/plugins/mongoMeili.js index df96338302..75e3738e5d 100644 --- a/api/models/plugins/mongoMeili.js +++ b/api/models/plugins/mongoMeili.js @@ -1,12 +1,32 @@ const _ = require('lodash'); const mongoose = require('mongoose'); const { MeiliSearch } = require('meilisearch'); +const { parseTextParts, ContentTypes } = require('librechat-data-provider'); const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); const logger = require('~/config/meiliLogger'); +// Environment flags +/** + * Flag to indicate if search is enabled based on environment variables. + * @type {boolean} + */ const searchEnabled = process.env.SEARCH && process.env.SEARCH.toLowerCase() === 'true'; + +/** + * Flag to indicate if MeiliSearch is enabled based on required environment variables. 
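+ * (Requires MEILI_HOST and MEILI_MASTER_KEY to be set, and SEARCH to be enabled.)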
+ * @type {boolean} + */ const meiliEnabled = process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY && searchEnabled; +/** + * Validates the required options for configuring the mongoMeili plugin. + * + * @param {Object} options - The configuration options. + * @param {string} options.host - The MeiliSearch host. + * @param {string} options.apiKey - The MeiliSearch API key. + * @param {string} options.indexName - The name of the index. + * @throws {Error} Throws an error if any required option is missing. + */ const validateOptions = function (options) { const requiredKeys = ['host', 'apiKey', 'indexName']; requiredKeys.forEach((key) => { @@ -16,53 +36,64 @@ const validateOptions = function (options) { }); }; -// const createMeiliMongooseModel = function ({ index, indexName, client, attributesToIndex }) { +/** + * Factory function to create a MeiliMongooseModel class which extends a Mongoose model. + * This class contains static and instance methods to synchronize and manage the MeiliSearch index + * corresponding to the MongoDB collection. + * + * @param {Object} config - Configuration object. + * @param {Object} config.index - The MeiliSearch index object. + * @param {Array} config.attributesToIndex - List of attributes to index. + * @returns {Function} A class definition that will be loaded into the Mongoose schema. + */ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { + // The primary key is assumed to be the first attribute in the attributesToIndex array. const primaryKey = attributesToIndex[0]; - // MeiliMongooseModel is of type Mongoose.Model + class MeiliMongooseModel { /** - * `syncWithMeili`: synchronizes the data between a MongoDB collection and a MeiliSearch index, - * only triggered if there's ever a discrepancy determined by `api\lib\db\indexSync.js`. + * Synchronizes the data between the MongoDB collection and the MeiliSearch index. * - * 1. Fetches all documents from the MongoDB collection and the MeiliSearch index. - * 2. Compares the documents from both sources. - * 3. If a document exists in MeiliSearch but not in MongoDB, it's deleted from MeiliSearch. - * 4. If a document exists in MongoDB but not in MeiliSearch, it's added to MeiliSearch. - * 5. If a document exists in both but has different `text` or `title` fields (depending on the `primaryKey`), it's updated in MeiliSearch. - * 6. After all operations, it updates the `_meiliIndex` field in MongoDB to indicate whether the document is indexed in MeiliSearch. + * The synchronization process involves: + * 1. Fetching all documents from the MongoDB collection and MeiliSearch index. + * 2. Comparing documents from both sources. + * 3. Deleting documents from MeiliSearch that no longer exist in MongoDB. + * 4. Adding documents to MeiliSearch that exist in MongoDB but not in the index. + * 5. Updating documents in MeiliSearch if key fields (such as `text` or `title`) differ. + * 6. Updating the `_meiliIndex` field in MongoDB to indicate the indexing status. * - * Note: This strategy does not use batch operations for Meilisearch as the `index.addDocuments` will discard - * the entire batch if there's an error with one document, and will not throw an error if there's an issue. - * Also, `index.getDocuments` needs an exact limit on the amount of documents to return, so we build the map in batches. + * Note: The function processes documents in batches because MeiliSearch's + * `index.getDocuments` requires an exact limit and `index.addDocuments` does not handle + * partial failures in a batch. 
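+ *
+ * For example, with batchSize = 1000 the index is read as
+ * getDocuments({ limit: 1000, offset: 0 }), then offset 1000, 2000, and so on,
+ * until a batch comes back empty.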
* - * @returns {Promise} A promise that resolves when the synchronization is complete. - * - * @throws {Error} Throws an error if there's an issue with adding a document to MeiliSearch. + * @returns {Promise} Resolves when the synchronization is complete. */ static async syncWithMeili() { try { let moreDocuments = true; + // Retrieve all MongoDB documents from the collection as plain JavaScript objects. const mongoDocuments = await this.find().lean(); - const format = (doc) => _.pick(doc, attributesToIndex); - // Prepare for comparison + // Helper function to format a document by selecting only the attributes to index + // and omitting keys starting with '$'. + const format = (doc) => + _.omitBy(_.pick(doc, attributesToIndex), (v, k) => k.startsWith('$')); + + // Build a map of MongoDB documents for quick lookup based on the primary key. const mongoMap = new Map(mongoDocuments.map((doc) => [doc[primaryKey], format(doc)])); const indexMap = new Map(); let offset = 0; const batchSize = 1000; + // Fetch documents from the MeiliSearch index in batches. while (moreDocuments) { const batch = await index.getDocuments({ limit: batchSize, offset }); - if (batch.results.length === 0) { moreDocuments = false; } - for (const doc of batch.results) { indexMap.set(doc[primaryKey], format(doc)); } - offset += batchSize; } @@ -70,13 +101,12 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { const updateOps = []; - // Iterate over Meili index documents + // Process documents present in the MeiliSearch index. for (const [id, doc] of indexMap) { const update = {}; update[primaryKey] = id; if (mongoMap.has(id)) { - // Case: Update - // If document also exists in MongoDB, would be update case + // If document exists in MongoDB, check for discrepancies in key fields. if ( (doc.text && doc.text !== mongoMap.get(id).text) || (doc.title && doc.title !== mongoMap.get(id).title) @@ -92,8 +122,7 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { await index.addDocuments([doc]); } } else { - // Case: Delete - // If document does not exist in MongoDB, its a delete case from meili index + // If the document does not exist in MongoDB, delete it from MeiliSearch. await index.deleteDocument(id); updateOps.push({ updateOne: { filter: update, update: { $set: { _meiliIndex: false } } }, @@ -101,24 +130,25 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { } } - // Iterate over MongoDB documents + // Process documents present in MongoDB. for (const [id, doc] of mongoMap) { const update = {}; update[primaryKey] = id; - // Case: Insert - // If document does not exist in Meili Index, Its an insert case + // If the document is missing in the Meili index, add it. if (!indexMap.has(id)) { await index.addDocuments([doc]); updateOps.push({ updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, }); } else if (doc._meiliIndex === false) { + // If the document exists but is marked as not indexed, update the flag. updateOps.push({ updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, }); } } + // Execute bulk update operations in MongoDB to update the _meiliIndex flags. if (updateOps.length > 0) { await this.collection.bulkWrite(updateOps); logger.debug( @@ -132,34 +162,47 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { } } - // Set one or more settings of the meili index + /** + * Updates settings for the MeiliSearch index. 
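+ * Typical usage (illustrative): await Model.setMeiliIndexSettings({ searchableAttributes: ['title', 'text'] });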
+ * + * @param {Object} settings - The settings to update on the MeiliSearch index. + * @returns {Promise} Promise resolving to the update result. + */ static async setMeiliIndexSettings(settings) { return await index.updateSettings(settings); } - // Search the index + /** + * Searches the MeiliSearch index and optionally populates the results with data from MongoDB. + * + * @param {string} q - The search query. + * @param {Object} params - Additional search parameters for MeiliSearch. + * @param {boolean} populate - Whether to populate search hits with full MongoDB documents. + * @returns {Promise} The search results with populated hits if requested. + */ static async meiliSearch(q, params, populate) { const data = await index.search(q, params); - // Populate hits with content from mongodb if (populate) { - // Find objects into mongodb matching `objectID` from Meili search + // Build a query using the primary key values from the search hits. const query = {}; - // query[primaryKey] = { $in: _.map(data.hits, primaryKey) }; query[primaryKey] = _.map(data.hits, (hit) => cleanUpPrimaryKeyValue(hit[primaryKey])); - // logger.debug('query', query); - const hitsFromMongoose = await this.find( - query, - _.reduce( - this.schema.obj, - function (results, value, key) { - return { ...results, [key]: 1 }; - }, - { _id: 1, __v: 1 }, - ), - ).lean(); - // Add additional data from mongodb into Meili search hits + // Build a projection object, including only keys that do not start with '$'. + const projection = Object.keys(this.schema.obj).reduce( + (results, key) => { + if (!key.startsWith('$')) { + results[key] = 1; + } + return results; + }, + { _id: 1, __v: 1 }, + ); + + // Retrieve the full documents from MongoDB. + const hitsFromMongoose = await this.find(query, projection).lean(); + + // Merge the MongoDB documents with the search hits. const populatedHits = data.hits.map(function (hit) { const query = {}; query[primaryKey] = hit[primaryKey]; @@ -176,51 +219,80 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { return data; } + /** + * Preprocesses the current document for indexing. + * + * This method: + * - Picks only the defined attributes to index. + * - Omits any keys starting with '$'. + * - Replaces pipe characters ('|') in `conversationId` with '--'. + * - Extracts and concatenates text from an array of content items. + * + * @returns {Object} The preprocessed object ready for indexing. + */ preprocessObjectForIndex() { - const object = _.pick(this.toJSON(), attributesToIndex); - // NOTE: MeiliSearch does not allow | in primary key, so we replace it with - for Bing convoIds - // object.conversationId = object.conversationId.replace(/\|/g, '-'); + const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => + k.startsWith('$'), + ); if (object.conversationId && object.conversationId.includes('|')) { object.conversationId = object.conversationId.replace(/\|/g, '--'); } if (object.content && Array.isArray(object.content)) { - object.text = object.content - .filter((item) => item.type === 'text' && item.text && item.text.value) - .map((item) => item.text.value) - .join(' '); + object.text = parseTextParts(object.content); delete object.content; } return object; } - // Push new document to Meili + /** + * Adds the current document to the MeiliSearch index. + * + * The method preprocesses the document, adds it to MeiliSearch, and then updates + * the MongoDB document's `_meiliIndex` flag to true. 
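+ * Note: indexing failures are caught and logged rather than thrown, so a failed
+ * index write does not block the underlying MongoDB save.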
+ * + * @returns {Promise} + */ async addObjectToMeili() { const object = this.preprocessObjectForIndex(); try { - // logger.debug('Adding document to Meili', object); await index.addDocuments([object]); } catch (error) { - // logger.debug('Error adding document to Meili'); - // logger.error(error); + // Error handling can be enhanced as needed. + logger.error('[addObjectToMeili] Error adding document to Meili', error); } await this.collection.updateMany({ _id: this._id }, { $set: { _meiliIndex: true } }); } - // Update an existing document in Meili + /** + * Updates the current document in the MeiliSearch index. + * + * @returns {Promise} + */ async updateObjectToMeili() { - const object = _.pick(this.toJSON(), attributesToIndex); + const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => + k.startsWith('$'), + ); await index.updateDocuments([object]); } - // Delete a document from Meili + /** + * Deletes the current document from the MeiliSearch index. + * + * @returns {Promise} + */ async deleteObjectFromMeili() { await index.deleteDocument(this._id); } - // * schema.post('save') + /** + * Post-save hook to synchronize the document with MeiliSearch. + * + * If the document is already indexed (i.e. `_meiliIndex` is true), it updates it; + * otherwise, it adds the document to the index. + */ postSaveHook() { if (this._meiliIndex) { this.updateObjectToMeili(); @@ -229,14 +301,24 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { } } - // * schema.post('update') + /** + * Post-update hook to update the document in MeiliSearch. + * + * This hook is triggered after a document update, ensuring that changes are + * propagated to the MeiliSearch index if the document is indexed. + */ postUpdateHook() { if (this._meiliIndex) { this.updateObjectToMeili(); } } - // * schema.post('remove') + /** + * Post-remove hook to delete the document from MeiliSearch. + * + * This hook is triggered after a document is removed, ensuring that the document + * is also removed from the MeiliSearch index if it was previously indexed. + */ postRemoveHook() { if (this._meiliIndex) { this.deleteObjectFromMeili(); @@ -247,11 +329,27 @@ const createMeiliMongooseModel = function ({ index, attributesToIndex }) { return MeiliMongooseModel; }; +/** + * Mongoose plugin to synchronize MongoDB collections with a MeiliSearch index. + * + * This plugin: + * - Validates the provided options. + * - Adds a `_meiliIndex` field to the schema to track indexing status. + * - Sets up a MeiliSearch client and creates an index if it doesn't already exist. + * - Loads class methods for syncing, searching, and managing documents in MeiliSearch. + * - Registers Mongoose hooks (post-save, post-update, post-remove, etc.) to maintain index consistency. + * + * @param {mongoose.Schema} schema - The Mongoose schema to which the plugin is applied. + * @param {Object} options - Configuration options. + * @param {string} options.host - The MeiliSearch host. + * @param {string} options.apiKey - The MeiliSearch API key. + * @param {string} options.indexName - The name of the MeiliSearch index. + * @param {string} options.primaryKey - The primary key field for indexing. + */ module.exports = function mongoMeili(schema, options) { - // Vaidate Options for mongoMeili validateOptions(options); - // Add meiliIndex to schema + // Add _meiliIndex field to the schema to track if a document has been indexed in MeiliSearch. 
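+ // Example registration (mirrors how convoSchema applies this plugin elsewhere in this change):
+ // schema.plugin(mongoMeili, { host, apiKey, indexName: 'convos', primaryKey: 'conversationId' });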
schema.add({ _meiliIndex: { type: Boolean, @@ -263,69 +361,77 @@ module.exports = function mongoMeili(schema, options) { const { host, apiKey, indexName, primaryKey } = options; - // Setup MeiliSearch Client + // Setup the MeiliSearch client. const client = new MeiliSearch({ host, apiKey }); - // Asynchronously create the index + // Create the index asynchronously if it doesn't exist. client.createIndex(indexName, { primaryKey }); - // Setup the index to search for this schema + // Setup the MeiliSearch index for this schema. const index = client.index(indexName); + // Collect attributes from the schema that should be indexed. const attributesToIndex = [ ..._.reduce( schema.obj, function (results, value, key) { return value.meiliIndex ? [...results, key] : results; - // }, []), '_id']; }, [], ), ]; + // Load the class methods into the schema. schema.loadClass(createMeiliMongooseModel({ index, indexName, client, attributesToIndex })); - // Register hooks + // Register Mongoose hooks to synchronize with MeiliSearch. + + // Post-save: synchronize after a document is saved. schema.post('save', function (doc) { doc.postSaveHook(); }); + + // Post-update: synchronize after a document is updated. schema.post('update', function (doc) { doc.postUpdateHook(); }); + + // Post-remove: synchronize after a document is removed. schema.post('remove', function (doc) { doc.postRemoveHook(); }); + // Pre-deleteMany hook: remove corresponding documents from MeiliSearch when multiple documents are deleted. schema.pre('deleteMany', async function (next) { if (!meiliEnabled) { - next(); + return next(); } try { + // Check if the schema has a "messages" field to determine if it's a conversation schema. if (Object.prototype.hasOwnProperty.call(schema.obj, 'messages')) { const convoIndex = client.index('convos'); const deletedConvos = await mongoose.model('Conversation').find(this._conditions).lean(); - let promises = []; - for (const convo of deletedConvos) { - promises.push(convoIndex.deleteDocument(convo.conversationId)); - } + const promises = deletedConvos.map((convo) => + convoIndex.deleteDocument(convo.conversationId), + ); await Promise.all(promises); } + // Check if the schema has a "messageId" field to determine if it's a message schema. if (Object.prototype.hasOwnProperty.call(schema.obj, 'messageId')) { const messageIndex = client.index('messages'); const deletedMessages = await mongoose.model('Message').find(this._conditions).lean(); - let promises = []; - for (const message of deletedMessages) { - promises.push(messageIndex.deleteDocument(message.messageId)); - } + const promises = deletedMessages.map((message) => + messageIndex.deleteDocument(message.messageId), + ); await Promise.all(promises); } return next(); } catch (error) { if (meiliEnabled) { logger.error( - '[MeiliMongooseModel.deleteMany] There was an issue deleting conversation indexes upon deletion, next startup may be slow due to syncing', + '[MeiliMongooseModel.deleteMany] There was an issue deleting conversation indexes upon deletion. Next startup may be slow due to syncing.', error, ); } @@ -333,17 +439,19 @@ module.exports = function mongoMeili(schema, options) { } }); + // Post-findOneAndUpdate hook: update MeiliSearch index after a document is updated via findOneAndUpdate. schema.post('findOneAndUpdate', async function (doc) { if (!meiliEnabled) { return; } + // If the document is unfinished, do not update the index. 
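+ // (An unfinished document is presumably still streaming and will be indexed on a later save.)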
if (doc.unfinished) { return; } let meiliDoc; - // Doc is a Conversation + // For conversation documents, try to fetch the document from the "convos" index. if (doc.messages) { try { meiliDoc = await client.index('convos').getDocument(doc.conversationId); @@ -356,10 +464,12 @@ module.exports = function mongoMeili(schema, options) { } } + // If the MeiliSearch document exists and the title is unchanged, do nothing. if (meiliDoc && meiliDoc.title === doc.title) { return; } + // Otherwise, trigger a post-save hook to synchronize the document. doc.postSaveHook(); }); }; diff --git a/api/models/schema/action.js b/api/models/schema/action.js deleted file mode 100644 index f86a9bfa2d..0000000000 --- a/api/models/schema/action.js +++ /dev/null @@ -1,60 +0,0 @@ -const mongoose = require('mongoose'); - -const { Schema } = mongoose; - -const AuthSchema = new Schema( - { - authorization_type: String, - custom_auth_header: String, - type: { - type: String, - enum: ['service_http', 'oauth', 'none'], - }, - authorization_content_type: String, - authorization_url: String, - client_url: String, - scope: String, - token_exchange_method: { - type: String, - enum: ['default_post', 'basic_auth_header', null], - }, - }, - { _id: false }, -); - -const actionSchema = new Schema({ - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - index: true, - required: true, - }, - action_id: { - type: String, - index: true, - required: true, - }, - type: { - type: String, - default: 'action_prototype', - }, - settings: Schema.Types.Mixed, - agent_id: String, - assistant_id: String, - metadata: { - api_key: String, // private, encrypted - auth: AuthSchema, - domain: { - type: String, - required: true, - }, - // json_schema: Schema.Types.Mixed, - privacy_policy_url: String, - raw_spec: String, - oauth_client_id: String, // private, encrypted - oauth_client_secret: String, // private, encrypted - }, -}); -// }, { minimize: false }); // Prevent removal of empty objects - -module.exports = actionSchema; diff --git a/api/models/schema/balance.js b/api/models/schema/balance.js deleted file mode 100644 index 8ca8116e09..0000000000 --- a/api/models/schema/balance.js +++ /dev/null @@ -1,17 +0,0 @@ -const mongoose = require('mongoose'); - -const balanceSchema = mongoose.Schema({ - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - index: true, - required: true, - }, - // 1000 tokenCredits = 1 mill ($0.001 USD) - tokenCredits: { - type: Number, - default: 0, - }, -}); - -module.exports = balanceSchema; diff --git a/api/models/schema/categories.js b/api/models/schema/categories.js deleted file mode 100644 index 3167685667..0000000000 --- a/api/models/schema/categories.js +++ /dev/null @@ -1,19 +0,0 @@ -const mongoose = require('mongoose'); -const Schema = mongoose.Schema; - -const categoriesSchema = new Schema({ - label: { - type: String, - required: true, - unique: true, - }, - value: { - type: String, - required: true, - unique: true, - }, -}); - -const categories = mongoose.model('categories', categoriesSchema); - -module.exports = { Categories: categories }; diff --git a/api/models/schema/conversationTagSchema.js b/api/models/schema/conversationTagSchema.js deleted file mode 100644 index 9b2a98c6d8..0000000000 --- a/api/models/schema/conversationTagSchema.js +++ /dev/null @@ -1,32 +0,0 @@ -const mongoose = require('mongoose'); - -const conversationTagSchema = mongoose.Schema( - { - tag: { - type: String, - index: true, - }, - user: { - type: String, - index: true, - }, - description: { - type: String, - 
index: true, - }, - count: { - type: Number, - default: 0, - }, - position: { - type: Number, - default: 0, - index: true, - }, - }, - { timestamps: true }, -); - -conversationTagSchema.index({ tag: 1, user: 1 }, { unique: true }); - -module.exports = mongoose.model('ConversationTag', conversationTagSchema); diff --git a/api/models/schema/convoSchema.js b/api/models/schema/convoSchema.js index 7d8beed6a6..89cb9c80b5 100644 --- a/api/models/schema/convoSchema.js +++ b/api/models/schema/convoSchema.js @@ -1,63 +1,18 @@ const mongoose = require('mongoose'); const mongoMeili = require('../plugins/mongoMeili'); -const { conversationPreset } = require('./defaults'); -const convoSchema = mongoose.Schema( - { - conversationId: { - type: String, - unique: true, - required: true, - index: true, - meiliIndex: true, - }, - title: { - type: String, - default: 'New Chat', - meiliIndex: true, - }, - user: { - type: String, - index: true, - }, - messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }], - // google only - examples: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - agentOptions: { - type: mongoose.Schema.Types.Mixed, - }, - ...conversationPreset, - agent_id: { - type: String, - }, - tags: { - type: [String], - default: [], - meiliIndex: true, - }, - files: { - type: [String], - }, - expiredAt: { - type: Date, - }, - }, - { timestamps: true }, -); + +const { convoSchema } = require('@librechat/data-schemas'); if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { convoSchema.plugin(mongoMeili, { host: process.env.MEILI_HOST, apiKey: process.env.MEILI_MASTER_KEY, - indexName: 'convos', // Will get created automatically if it doesn't exist already + /** Note: Will get created automatically if it doesn't exist already */ + indexName: 'convos', primaryKey: 'conversationId', }); } -// Create TTL index -convoSchema.index({ expiredAt: 1 }, { expireAfterSeconds: 0 }); -convoSchema.index({ createdAt: 1, updatedAt: 1 }); -convoSchema.index({ conversationId: 1, user: 1 }, { unique: true }); - const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema); module.exports = Conversation; diff --git a/api/models/schema/fileSchema.js b/api/models/schema/fileSchema.js deleted file mode 100644 index 77c6ff94d4..0000000000 --- a/api/models/schema/fileSchema.js +++ /dev/null @@ -1,111 +0,0 @@ -const { FileSources } = require('librechat-data-provider'); -const mongoose = require('mongoose'); - -/** - * @typedef {Object} MongoFile - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {number} [__v] - MongoDB Version Key - * @property {ObjectId} user - User ID - * @property {string} [conversationId] - Optional conversation ID - * @property {string} file_id - File identifier - * @property {string} [temp_file_id] - Temporary File identifier - * @property {number} bytes - Size of the file in bytes - * @property {string} filename - Name of the file - * @property {string} filepath - Location of the file - * @property {'file'} object - Type of object, always 'file' - * @property {string} type - Type of file - * @property {number} [usage=0] - Number of uses of the file - * @property {string} [context] - Context of the file origin - * @property {boolean} [embedded=false] - Whether or not the file is embedded in vector db - * @property {string} [model] - The model to identify the group region of the file (for Azure OpenAI hosting) - * @property {string} [source] - The source of the file (e.g., from FileSources) - * @property {number} 
[width] - Optional width of the file - * @property {number} [height] - Optional height of the file - * @property {Object} [metadata] - Metadata related to the file - * @property {string} [metadata.fileIdentifier] - Unique identifier for the file in metadata - * @property {Date} [expiresAt] - Optional expiration date of the file - * @property {Date} [createdAt] - Date when the file was created - * @property {Date} [updatedAt] - Date when the file was updated - */ - -/** @type {MongooseSchema} */ -const fileSchema = mongoose.Schema( - { - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - index: true, - required: true, - }, - conversationId: { - type: String, - ref: 'Conversation', - index: true, - }, - file_id: { - type: String, - // required: true, - index: true, - }, - temp_file_id: { - type: String, - // required: true, - }, - bytes: { - type: Number, - required: true, - }, - filename: { - type: String, - required: true, - }, - filepath: { - type: String, - required: true, - }, - object: { - type: String, - required: true, - default: 'file', - }, - embedded: { - type: Boolean, - }, - type: { - type: String, - required: true, - }, - context: { - type: String, - // required: true, - }, - usage: { - type: Number, - required: true, - default: 0, - }, - source: { - type: String, - default: FileSources.local, - }, - model: { - type: String, - }, - width: Number, - height: Number, - metadata: { - fileIdentifier: String, - }, - expiresAt: { - type: Date, - expires: 3600, // 1 hour in seconds - }, - }, - { - timestamps: true, - }, -); - -fileSchema.index({ createdAt: 1, updatedAt: 1 }); - -module.exports = fileSchema; diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index be71155295..cf97b84eea 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -1,145 +1,6 @@ const mongoose = require('mongoose'); const mongoMeili = require('~/models/plugins/mongoMeili'); -const messageSchema = mongoose.Schema( - { - messageId: { - type: String, - unique: true, - required: true, - index: true, - meiliIndex: true, - }, - conversationId: { - type: String, - index: true, - required: true, - meiliIndex: true, - }, - user: { - type: String, - index: true, - required: true, - default: null, - }, - model: { - type: String, - default: null, - }, - endpoint: { - type: String, - }, - conversationSignature: { - type: String, - }, - clientId: { - type: String, - }, - invocationId: { - type: Number, - }, - parentMessageId: { - type: String, - }, - tokenCount: { - type: Number, - }, - summaryTokenCount: { - type: Number, - }, - sender: { - type: String, - meiliIndex: true, - }, - text: { - type: String, - meiliIndex: true, - }, - summary: { - type: String, - }, - isCreatedByUser: { - type: Boolean, - required: true, - default: false, - }, - unfinished: { - type: Boolean, - default: false, - }, - error: { - type: Boolean, - default: false, - }, - finish_reason: { - type: String, - }, - _meiliIndex: { - type: Boolean, - required: false, - select: false, - default: false, - }, - files: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - plugin: { - type: { - latest: { - type: String, - required: false, - }, - inputs: { - type: [mongoose.Schema.Types.Mixed], - required: false, - default: undefined, - }, - outputs: { - type: String, - required: false, - }, - }, - default: undefined, - }, - plugins: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - content: { - type: [{ type: 
mongoose.Schema.Types.Mixed }], - default: undefined, - meiliIndex: true, - }, - thread_id: { - type: String, - }, - /* frontend components */ - iconURL: { - type: String, - }, - attachments: { type: [{ type: mongoose.Schema.Types.Mixed }], default: undefined }, - /* - attachments: { - type: [ - { - file_id: String, - filename: String, - filepath: String, - expiresAt: Date, - width: Number, - height: Number, - type: String, - conversationId: String, - messageId: { - type: String, - required: true, - }, - toolCallId: String, - }, - ], - default: undefined, - }, - */ - expiredAt: { - type: Date, - }, - }, - { timestamps: true }, -); +const { messageSchema } = require('@librechat/data-schemas'); if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { messageSchema.plugin(mongoMeili, { @@ -149,11 +10,7 @@ if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { primaryKey: 'messageId', }); } -messageSchema.index({ expiredAt: 1 }, { expireAfterSeconds: 0 }); -messageSchema.index({ createdAt: 1 }); -messageSchema.index({ messageId: 1, user: 1 }, { unique: true }); -/** @type {mongoose.Model} */ const Message = mongoose.models.Message || mongoose.model('Message', messageSchema); module.exports = Message; diff --git a/api/models/schema/pluginAuthSchema.js b/api/models/schema/pluginAuthSchema.js index 4b4251dda3..2066eda4c4 100644 --- a/api/models/schema/pluginAuthSchema.js +++ b/api/models/schema/pluginAuthSchema.js @@ -1,25 +1,5 @@ const mongoose = require('mongoose'); - -const pluginAuthSchema = mongoose.Schema( - { - authField: { - type: String, - required: true, - }, - value: { - type: String, - required: true, - }, - userId: { - type: String, - required: true, - }, - pluginKey: { - type: String, - }, - }, - { timestamps: true }, -); +const { pluginAuthSchema } = require('@librechat/data-schemas'); const PluginAuth = mongoose.models.Plugin || mongoose.model('PluginAuth', pluginAuthSchema); diff --git a/api/models/schema/presetSchema.js b/api/models/schema/presetSchema.js index e1c92ab9c0..6d03803ace 100644 --- a/api/models/schema/presetSchema.js +++ b/api/models/schema/presetSchema.js @@ -1,38 +1,5 @@ const mongoose = require('mongoose'); -const { conversationPreset } = require('./defaults'); -const presetSchema = mongoose.Schema( - { - presetId: { - type: String, - unique: true, - required: true, - index: true, - }, - title: { - type: String, - default: 'New Chat', - meiliIndex: true, - }, - user: { - type: String, - default: null, - }, - defaultPreset: { - type: Boolean, - }, - order: { - type: Number, - }, - // google only - examples: [{ type: mongoose.Schema.Types.Mixed }], - ...conversationPreset, - agentOptions: { - type: mongoose.Schema.Types.Mixed, - default: null, - }, - }, - { timestamps: true }, -); +const { presetSchema } = require('@librechat/data-schemas'); const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema); diff --git a/api/models/schema/projectSchema.js b/api/models/schema/projectSchema.js deleted file mode 100644 index dfa68a06c2..0000000000 --- a/api/models/schema/projectSchema.js +++ /dev/null @@ -1,35 +0,0 @@ -const { Schema } = require('mongoose'); - -/** - * @typedef {Object} MongoProject - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {string} name - The name of the project - * @property {ObjectId[]} promptGroupIds - Array of PromptGroup IDs associated with the project - * @property {Date} [createdAt] - Date when the project was created (added by timestamps) - * @property {Date} [updatedAt] - Date when the 
project was last updated (added by timestamps) - */ - -const projectSchema = new Schema( - { - name: { - type: String, - required: true, - index: true, - }, - promptGroupIds: { - type: [Schema.Types.ObjectId], - ref: 'PromptGroup', - default: [], - }, - agentIds: { - type: [String], - ref: 'Agent', - default: [], - }, - }, - { - timestamps: true, - }, -); - -module.exports = projectSchema; diff --git a/api/models/schema/promptSchema.js b/api/models/schema/promptSchema.js deleted file mode 100644 index 5464caf639..0000000000 --- a/api/models/schema/promptSchema.js +++ /dev/null @@ -1,118 +0,0 @@ -const mongoose = require('mongoose'); -const { Constants } = require('librechat-data-provider'); -const Schema = mongoose.Schema; - -/** - * @typedef {Object} MongoPromptGroup - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {string} name - The name of the prompt group - * @property {ObjectId} author - The author of the prompt group - * @property {ObjectId} [projectId=null] - The project ID of the prompt group - * @property {ObjectId} [productionId=null] - The project ID of the prompt group - * @property {string} authorName - The name of the author of the prompt group - * @property {number} [numberOfGenerations=0] - Number of generations the prompt group has - * @property {string} [oneliner=''] - Oneliner description of the prompt group - * @property {string} [category=''] - Category of the prompt group - * @property {string} [command] - Command for the prompt group - * @property {Date} [createdAt] - Date when the prompt group was created (added by timestamps) - * @property {Date} [updatedAt] - Date when the prompt group was last updated (added by timestamps) - */ - -const promptGroupSchema = new Schema( - { - name: { - type: String, - required: true, - index: true, - }, - numberOfGenerations: { - type: Number, - default: 0, - }, - oneliner: { - type: String, - default: '', - }, - category: { - type: String, - default: '', - index: true, - }, - projectIds: { - type: [Schema.Types.ObjectId], - ref: 'Project', - index: true, - }, - productionId: { - type: Schema.Types.ObjectId, - ref: 'Prompt', - required: true, - index: true, - }, - author: { - type: Schema.Types.ObjectId, - ref: 'User', - required: true, - index: true, - }, - authorName: { - type: String, - required: true, - }, - command: { - type: String, - index: true, - validate: { - validator: function (v) { - return v === undefined || v === null || v === '' || /^[a-z0-9-]+$/.test(v); - }, - message: (props) => - `${props.value} is not a valid command. 
Only lowercase alphanumeric characters and highfins (') are allowed.`, - }, - maxlength: [ - Constants.COMMANDS_MAX_LENGTH, - `Command cannot be longer than ${Constants.COMMANDS_MAX_LENGTH} characters`, - ], - }, - }, - { - timestamps: true, - }, -); - -const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema); - -const promptSchema = new Schema( - { - groupId: { - type: Schema.Types.ObjectId, - ref: 'PromptGroup', - required: true, - index: true, - }, - author: { - type: Schema.Types.ObjectId, - ref: 'User', - required: true, - }, - prompt: { - type: String, - required: true, - }, - type: { - type: String, - enum: ['text', 'chat'], - required: true, - }, - }, - { - timestamps: true, - }, -); - -const Prompt = mongoose.model('Prompt', promptSchema); - -promptSchema.index({ createdAt: 1, updatedAt: 1 }); -promptGroupSchema.index({ createdAt: 1, updatedAt: 1 }); - -module.exports = { Prompt, PromptGroup }; diff --git a/api/models/schema/roleSchema.js b/api/models/schema/roleSchema.js deleted file mode 100644 index 36e9d3f7b6..0000000000 --- a/api/models/schema/roleSchema.js +++ /dev/null @@ -1,55 +0,0 @@ -const { PermissionTypes, Permissions } = require('librechat-data-provider'); -const mongoose = require('mongoose'); - -const roleSchema = new mongoose.Schema({ - name: { - type: String, - required: true, - unique: true, - index: true, - }, - [PermissionTypes.BOOKMARKS]: { - [Permissions.USE]: { - type: Boolean, - default: true, - }, - }, - [PermissionTypes.PROMPTS]: { - [Permissions.SHARED_GLOBAL]: { - type: Boolean, - default: false, - }, - [Permissions.USE]: { - type: Boolean, - default: true, - }, - [Permissions.CREATE]: { - type: Boolean, - default: true, - }, - }, - [PermissionTypes.AGENTS]: { - [Permissions.SHARED_GLOBAL]: { - type: Boolean, - default: false, - }, - [Permissions.USE]: { - type: Boolean, - default: true, - }, - [Permissions.CREATE]: { - type: Boolean, - default: true, - }, - }, - [PermissionTypes.MULTI_CONVO]: { - [Permissions.USE]: { - type: Boolean, - default: true, - }, - }, -}); - -const Role = mongoose.model('Role', roleSchema); - -module.exports = Role; diff --git a/api/models/schema/session.js b/api/models/schema/session.js deleted file mode 100644 index ccda43573d..0000000000 --- a/api/models/schema/session.js +++ /dev/null @@ -1,20 +0,0 @@ -const mongoose = require('mongoose'); - -const sessionSchema = mongoose.Schema({ - refreshTokenHash: { - type: String, - required: true, - }, - expiration: { - type: Date, - required: true, - expires: 0, - }, - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - required: true, - }, -}); - -module.exports = sessionSchema; diff --git a/api/models/schema/toolCallSchema.js b/api/models/schema/toolCallSchema.js deleted file mode 100644 index 2af4c67c1b..0000000000 --- a/api/models/schema/toolCallSchema.js +++ /dev/null @@ -1,54 +0,0 @@ -const mongoose = require('mongoose'); - -/** - * @typedef {Object} ToolCallData - * @property {string} conversationId - The ID of the conversation - * @property {string} messageId - The ID of the message - * @property {string} toolId - The ID of the tool - * @property {string | ObjectId} user - The user's ObjectId - * @property {unknown} [result] - Optional result data - * @property {TAttachment[]} [attachments] - Optional attachments data - * @property {number} [blockIndex] - Optional code block index - * @property {number} [partIndex] - Optional part index - */ - -/** @type {MongooseSchema} */ -const toolCallSchema = mongoose.Schema( - { - conversationId: { - type: 
String, - required: true, - }, - messageId: { - type: String, - required: true, - }, - toolId: { - type: String, - required: true, - }, - user: { - type: mongoose.Schema.Types.ObjectId, - ref: 'User', - required: true, - }, - result: { - type: mongoose.Schema.Types.Mixed, - }, - attachments: { - type: mongoose.Schema.Types.Mixed, - }, - blockIndex: { - type: Number, - }, - partIndex: { - type: Number, - }, - }, - { timestamps: true }, -); - -toolCallSchema.index({ messageId: 1, user: 1 }); -toolCallSchema.index({ conversationId: 1, user: 1 }); - -module.exports = mongoose.model('ToolCall', toolCallSchema); diff --git a/api/models/schema/userSchema.js b/api/models/schema/userSchema.js deleted file mode 100644 index f586553367..0000000000 --- a/api/models/schema/userSchema.js +++ /dev/null @@ -1,140 +0,0 @@ -const mongoose = require('mongoose'); -const { SystemRoles } = require('librechat-data-provider'); - -/** - * @typedef {Object} MongoSession - * @property {string} [refreshToken] - The refresh token - */ - -/** - * @typedef {Object} MongoUser - * @property {ObjectId} [_id] - MongoDB Document ID - * @property {string} [name] - The user's name - * @property {string} [username] - The user's username, in lowercase - * @property {string} email - The user's email address - * @property {boolean} emailVerified - Whether the user's email is verified - * @property {string} [password] - The user's password, trimmed with 8-128 characters - * @property {string} [avatar] - The URL of the user's avatar - * @property {string} provider - The provider of the user's account (e.g., 'local', 'google') - * @property {string} [role='USER'] - The role of the user - * @property {string} [googleId] - Optional Google ID for the user - * @property {string} [facebookId] - Optional Facebook ID for the user - * @property {string} [openidId] - Optional OpenID ID for the user - * @property {string} [ldapId] - Optional LDAP ID for the user - * @property {string} [githubId] - Optional GitHub ID for the user - * @property {string} [discordId] - Optional Discord ID for the user - * @property {string} [appleId] - Optional Apple ID for the user - * @property {Array} [plugins=[]] - List of plugins used by the user - * @property {Array.} [refreshToken] - List of sessions with refresh tokens - * @property {Date} [expiresAt] - Optional expiration date of the file - * @property {Date} [createdAt] - Date when the user was created (added by timestamps) - * @property {Date} [updatedAt] - Date when the user was last updated (added by timestamps) - */ - -/** @type {MongooseSchema} */ -const Session = mongoose.Schema({ - refreshToken: { - type: String, - default: '', - }, -}); - -/** @type {MongooseSchema} */ -const userSchema = mongoose.Schema( - { - name: { - type: String, - }, - username: { - type: String, - lowercase: true, - default: '', - }, - email: { - type: String, - required: [true, 'can\'t be blank'], - lowercase: true, - unique: true, - match: [/\S+@\S+\.\S+/, 'is invalid'], - index: true, - }, - emailVerified: { - type: Boolean, - required: true, - default: false, - }, - password: { - type: String, - trim: true, - minlength: 8, - maxlength: 128, - }, - avatar: { - type: String, - required: false, - }, - provider: { - type: String, - required: true, - default: 'local', - }, - role: { - type: String, - default: SystemRoles.USER, - }, - googleId: { - type: String, - unique: true, - sparse: true, - }, - facebookId: { - type: String, - unique: true, - sparse: true, - }, - openidId: { - type: String, - unique: true, - sparse: 
true, - }, - ldapId: { - type: String, - unique: true, - sparse: true, - }, - githubId: { - type: String, - unique: true, - sparse: true, - }, - discordId: { - type: String, - unique: true, - sparse: true, - }, - appleId: { - type: String, - unique: true, - sparse: true, - }, - plugins: { - type: Array, - default: [], - }, - refreshToken: { - type: [Session], - }, - expiresAt: { - type: Date, - expires: 604800, // 7 days in seconds - }, - termsAccepted: { - type: Boolean, - default: false, - }, - }, - - { timestamps: true }, -); - -module.exports = userSchema; diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index f91b2bb9cd..36b71ca9fc 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -36,7 +36,7 @@ const spendTokens = async (txData, tokenUsage) => { prompt = await Transaction.create({ ...txData, tokenType: 'prompt', - rawAmount: -Math.max(promptTokens, 0), + rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0), }); } @@ -44,7 +44,7 @@ const spendTokens = async (txData, tokenUsage) => { completion = await Transaction.create({ ...txData, tokenType: 'completion', - rawAmount: -Math.max(completionTokens, 0), + rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0), }); } diff --git a/api/models/spendTokens.spec.js b/api/models/spendTokens.spec.js index 91056bb54c..eacf420330 100644 --- a/api/models/spendTokens.spec.js +++ b/api/models/spendTokens.spec.js @@ -1,17 +1,10 @@ const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { Transaction } = require('./Transaction'); +const Balance = require('./Balance'); +const { spendTokens, spendStructuredTokens } = require('./spendTokens'); -jest.mock('./Transaction', () => ({ - Transaction: { - create: jest.fn(), - createStructured: jest.fn(), - }, -})); - -jest.mock('./Balance', () => ({ - findOne: jest.fn(), - findOneAndUpdate: jest.fn(), -})); - +// Mock the logger to prevent console output during tests jest.mock('~/config', () => ({ logger: { debug: jest.fn(), @@ -19,19 +12,46 @@ jest.mock('~/config', () => ({ }, })); -// Import after mocking -const { spendTokens, spendStructuredTokens } = require('./spendTokens'); -const { Transaction } = require('./Transaction'); -const Balance = require('./Balance'); +// Mock the Config service +const { getBalanceConfig } = require('~/server/services/Config'); +jest.mock('~/server/services/Config'); + describe('spendTokens', () => { - beforeEach(() => { - jest.clearAllMocks(); - process.env.CHECK_BALANCE = 'true'; + let mongoServer; + let userId; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + // Clear collections before each test + await Transaction.deleteMany({}); + await Balance.deleteMany({}); + + // Create a new user ID for each test + userId = new mongoose.Types.ObjectId(); + + // Mock the balance config to be enabled by default + getBalanceConfig.mockResolvedValue({ enabled: true }); }); it('should create transactions for both prompt and completion tokens', async () => { + // Create a balance for the user + await Balance.create({ + user: userId, + tokenCredits: 10000, + }); + const txData = { - user: new mongoose.Types.ObjectId(), + user: userId, conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', @@ -41,31 +61,35 @@ 
describe('spendTokens', () => { completionTokens: 50, }; - Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 }); - Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 }); - Balance.findOne.mockResolvedValue({ tokenCredits: 10000 }); - Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 }); - await spendTokens(txData, tokenUsage); - expect(Transaction.create).toHaveBeenCalledTimes(2); - expect(Transaction.create).toHaveBeenCalledWith( - expect.objectContaining({ - tokenType: 'prompt', - rawAmount: -100, - }), - ); - expect(Transaction.create).toHaveBeenCalledWith( - expect.objectContaining({ - tokenType: 'completion', - rawAmount: -50, - }), - ); + // Verify transactions were created + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + expect(transactions).toHaveLength(2); + + // Check completion transaction + expect(transactions[0].tokenType).toBe('completion'); + expect(transactions[0].rawAmount).toBe(-50); + + // Check prompt transaction + expect(transactions[1].tokenType).toBe('prompt'); + expect(transactions[1].rawAmount).toBe(-100); + + // Verify balance was updated + const balance = await Balance.findOne({ user: userId }); + expect(balance).toBeDefined(); + expect(balance.tokenCredits).toBeLessThan(10000); // Balance should be reduced }); it('should handle zero completion tokens', async () => { + // Create a balance for the user + await Balance.create({ + user: userId, + tokenCredits: 10000, + }); + const txData = { - user: new mongoose.Types.ObjectId(), + user: userId, conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', @@ -75,31 +99,26 @@ describe('spendTokens', () => { completionTokens: 0, }; - Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 }); - Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -0 }); - Balance.findOne.mockResolvedValue({ tokenCredits: 10000 }); - Balance.findOneAndUpdate.mockResolvedValue({ tokenCredits: 9850 }); - await spendTokens(txData, tokenUsage); - expect(Transaction.create).toHaveBeenCalledTimes(2); - expect(Transaction.create).toHaveBeenCalledWith( - expect.objectContaining({ - tokenType: 'prompt', - rawAmount: -100, - }), - ); - expect(Transaction.create).toHaveBeenCalledWith( - expect.objectContaining({ - tokenType: 'completion', - rawAmount: -0, // Changed from 0 to -0 - }), - ); + // Verify transactions were created + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + expect(transactions).toHaveLength(2); + + // Check completion transaction + expect(transactions[0].tokenType).toBe('completion'); + // In JavaScript -0 and 0 are different but functionally equivalent + // Use Math.abs to handle both 0 and -0 + expect(Math.abs(transactions[0].rawAmount)).toBe(0); + + // Check prompt transaction + expect(transactions[1].tokenType).toBe('prompt'); + expect(transactions[1].rawAmount).toBe(-100); }); it('should handle undefined token counts', async () => { const txData = { - user: new mongoose.Types.ObjectId(), + user: userId, conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', @@ -108,13 +127,22 @@ describe('spendTokens', () => { await spendTokens(txData, tokenUsage); - expect(Transaction.create).not.toHaveBeenCalled(); + // Verify no transactions were created + const transactions = await Transaction.find({ user: userId }); + expect(transactions).toHaveLength(0); }); - it('should not update balance when 
CHECK_BALANCE is false', async () => { - process.env.CHECK_BALANCE = 'false'; + it('should not update balance when the balance feature is disabled', async () => { + // Override configuration: disable balance updates + getBalanceConfig.mockResolvedValue({ enabled: false }); + // Create a balance for the user + await Balance.create({ + user: userId, + tokenCredits: 10000, + }); + const txData = { - user: new mongoose.Types.ObjectId(), + user: userId, conversationId: 'test-convo', model: 'gpt-3.5-turbo', context: 'test', @@ -124,19 +152,529 @@ describe('spendTokens', () => { completionTokens: 50, }; - Transaction.create.mockResolvedValueOnce({ tokenType: 'prompt', rawAmount: -100 }); - Transaction.create.mockResolvedValueOnce({ tokenType: 'completion', rawAmount: -50 }); + await spendTokens(txData, tokenUsage); + + // Verify transactions were created + const transactions = await Transaction.find({ user: userId }); + expect(transactions).toHaveLength(2); + + // Verify balance was not updated (should still be 10000) + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(10000); + }); + + it('should not allow balance to go below zero when spending tokens', async () => { + // Create a balance with a low amount + await Balance.create({ + user: userId, + tokenCredits: 5000, + }); + + const txData = { + user: userId, + conversationId: 'test-convo', + model: 'gpt-4', // Using a more expensive model + context: 'test', + }; + + // Spending more tokens than the user has balance for + const tokenUsage = { + promptTokens: 1000, + completionTokens: 500, + }; await spendTokens(txData, tokenUsage); - expect(Transaction.create).toHaveBeenCalledTimes(2); - expect(Balance.findOne).not.toHaveBeenCalled(); - expect(Balance.findOneAndUpdate).not.toHaveBeenCalled(); + // Verify transactions were created + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + expect(transactions).toHaveLength(2); + + // Verify balance was reduced to exactly 0, not negative + const balance = await Balance.findOne({ user: userId }); + expect(balance).toBeDefined(); + expect(balance.tokenCredits).toBe(0); + + // Check that the transaction records show the adjusted values + const transactionResults = await Promise.all( + transactions.map((t) => + Transaction.create({ + ...txData, + tokenType: t.tokenType, + rawAmount: t.rawAmount, + }), + ), + ); + + // The second transaction should have an adjusted value since balance is already 0 + expect(transactionResults[1]).toEqual( + expect.objectContaining({ + balance: 0, + }), + ); + }); + + it('should handle multiple transactions in sequence with low balance and not increase balance', async () => { + // This test is specifically checking for the issue reported in production + // where the balance increases after a transaction when it should remain at 0 + // Create a balance with a very low amount + await Balance.create({ + user: userId, + tokenCredits: 100, + }); + + // First transaction - should reduce balance to 0 + const txData1 = { + user: userId, + conversationId: 'test-convo-1', + model: 'gpt-4', + context: 'test', + }; + + const tokenUsage1 = { + promptTokens: 100, + completionTokens: 50, + }; + + await spendTokens(txData1, tokenUsage1); + + // Check balance after first transaction + let balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(0); + + // Second transaction - should keep balance at 0, not make it negative or increase it + const txData2 = { + user: userId, + 
+
+  it('should handle multiple transactions in sequence with low balance and not increase balance', async () => {
+    // This test is specifically checking for the issue reported in production
+    // where the balance increases after a transaction when it should remain at 0
+    // Create a balance with a very low amount
+    await Balance.create({
+      user: userId,
+      tokenCredits: 100,
+    });
+
+    // First transaction - should reduce balance to 0
+    const txData1 = {
+      user: userId,
+      conversationId: 'test-convo-1',
+      model: 'gpt-4',
+      context: 'test',
+    };
+
+    const tokenUsage1 = {
+      promptTokens: 100,
+      completionTokens: 50,
+    };
+
+    await spendTokens(txData1, tokenUsage1);
+
+    // Check balance after first transaction
+    let balance = await Balance.findOne({ user: userId });
+    expect(balance.tokenCredits).toBe(0);
+
+    // Second transaction - should keep balance at 0, not make it negative or increase it
+    const txData2 = {
+      user: userId,
+      conversationId: 'test-convo-2',
+      model: 'gpt-4',
+      context: 'test',
+    };
+
+    const tokenUsage2 = {
+      promptTokens: 200,
+      completionTokens: 100,
+    };
+
+    await spendTokens(txData2, tokenUsage2);
+
+    // Check balance after second transaction - should still be 0
+    balance = await Balance.findOne({ user: userId });
+    expect(balance.tokenCredits).toBe(0);
+
+    // Verify all transactions were created
+    const transactions = await Transaction.find({ user: userId });
+    expect(transactions).toHaveLength(4); // 2 transactions (prompt+completion) for each call
+
+    // Let's examine the actual transaction records to see what's happening
+    const transactionDetails = await Transaction.find({ user: userId }).sort({ createdAt: 1 });
+
+    // Log the transaction details for debugging
+    console.log('Transaction details:');
+    transactionDetails.forEach((tx, i) => {
+      console.log(`Transaction ${i + 1}:`, {
+        tokenType: tx.tokenType,
+        rawAmount: tx.rawAmount,
+        tokenValue: tx.tokenValue,
+        model: tx.model,
+      });
+    });
+
+    // Check the return values from Transaction.create directly
+    // This is to verify that the incrementValue is not becoming positive
+    const directResult = await Transaction.create({
+      user: userId,
+      conversationId: 'test-convo-3',
+      model: 'gpt-4',
+      tokenType: 'completion',
+      rawAmount: -100,
+      context: 'test',
+    });
+
+    console.log('Direct Transaction.create result:', directResult);
+
+    // The completion value should never be positive
+    expect(directResult.completion).not.toBeGreaterThan(0);
+  });
+
+  it('should ensure tokenValue is always negative for spending tokens', async () => {
+    // Create a balance for the user
+    await Balance.create({
+      user: userId,
+      tokenCredits: 10000,
+    });
+
+    // Test with various models to check multiplier calculations
+    const models = ['gpt-3.5-turbo', 'gpt-4', 'claude-3-5-sonnet'];
+
+    for (const model of models) {
+      const txData = {
+        user: userId,
+        conversationId: `test-convo-${model}`,
+        model,
+        context: 'test',
+      };
+
+      const tokenUsage = {
+        promptTokens: 100,
+        completionTokens: 50,
+      };
+
+      await spendTokens(txData, tokenUsage);
+
+      // Get the transactions for this model
+      const transactions = await Transaction.find({
+        user: userId,
+        model,
+      });
+
+      // Verify tokenValue is negative for all transactions
+      transactions.forEach((tx) => {
+        console.log(`Model ${model}, Type ${tx.tokenType}: tokenValue = ${tx.tokenValue}`);
+        expect(tx.tokenValue).toBeLessThan(0);
+      });
+    }
+  });
+
+  it('should handle structured transactions in sequence with low balance', async () => {
+    // Create a balance with a very low amount
+    await Balance.create({
+      user: userId,
+      tokenCredits: 100,
+    });
+
+    // First transaction - should reduce balance to 0
+    const txData1 = {
+      user: userId,
+      conversationId: 'test-convo-1',
+      model: 'claude-3-5-sonnet',
+      context: 'test',
+    };
+
+    const tokenUsage1 = {
+      promptTokens: {
+        input: 10,
+        write: 100,
+        read: 5,
+      },
+      completionTokens: 50,
+    };
+
+    await spendStructuredTokens(txData1, tokenUsage1);
+
+    // Check balance after first transaction
+    let balance = await Balance.findOne({ user: userId });
+    expect(balance.tokenCredits).toBe(0);
+
+    // Second transaction - should keep balance at 0, not make it negative or increase it
+    const txData2 = {
+      user: userId,
+      conversationId: 'test-convo-2',
+      model: 'claude-3-5-sonnet',
+      context: 'test',
+    };
+
+    const tokenUsage2 = {
+      promptTokens: {
+        input: 20,
+        write: 200,
+        read: 10,
+      },
+      completionTokens: 100,
+    };
+
+    await spendStructuredTokens(txData2, tokenUsage2);
+
+    //
Check balance after second transaction - should still be 0 + balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBe(0); + + // Verify all transactions were created + const transactions = await Transaction.find({ user: userId }); + expect(transactions).toHaveLength(4); // 2 transactions (prompt+completion) for each call + + // Let's examine the actual transaction records to see what's happening + const transactionDetails = await Transaction.find({ user: userId }).sort({ createdAt: 1 }); + + // Log the transaction details for debugging + console.log('Structured transaction details:'); + transactionDetails.forEach((tx, i) => { + console.log(`Transaction ${i + 1}:`, { + tokenType: tx.tokenType, + rawAmount: tx.rawAmount, + tokenValue: tx.tokenValue, + inputTokens: tx.inputTokens, + writeTokens: tx.writeTokens, + readTokens: tx.readTokens, + model: tx.model, + }); + }); + }); + + it('should not allow balance to go below zero when spending structured tokens', async () => { + // Create a balance with a low amount + await Balance.create({ + user: userId, + tokenCredits: 5000, + }); + + const txData = { + user: userId, + conversationId: 'test-convo', + model: 'claude-3-5-sonnet', // Using a model that supports structured tokens + context: 'test', + }; + + // Spending more tokens than the user has balance for + const tokenUsage = { + promptTokens: { + input: 100, + write: 1000, + read: 50, + }, + completionTokens: 500, + }; + + const result = await spendStructuredTokens(txData, tokenUsage); + + // Verify transactions were created + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + expect(transactions).toHaveLength(2); + + // Verify balance was reduced to exactly 0, not negative + const balance = await Balance.findOne({ user: userId }); + expect(balance).toBeDefined(); + expect(balance.tokenCredits).toBe(0); + + // The result should show the adjusted values + expect(result).toEqual({ + prompt: expect.objectContaining({ + user: userId.toString(), + balance: expect.any(Number), + }), + completion: expect.objectContaining({ + user: userId.toString(), + balance: 0, // Final balance should be 0 + }), + }); + }); + + it('should handle multiple concurrent transactions correctly with a high balance', async () => { + // Create a balance with a high amount + const initialBalance = 10000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + // Simulate the recordCollectedUsage function from the production code + const conversationId = 'test-concurrent-convo'; + const context = 'message'; + const model = 'gpt-4'; + + const amount = 50; + // Create `amount` of usage records to simulate multiple transactions + const collectedUsage = Array.from({ length: amount }, (_, i) => ({ + model, + input_tokens: 100 + i * 10, // Increasing input tokens + output_tokens: 50 + i * 5, // Increasing output tokens + input_token_details: { + cache_creation: i % 2 === 0 ? 20 : 0, // Some have cache creation + cache_read: i % 3 === 0 ? 
10 : 0, // Some have cache read
+      },
+    }));
+
+    // Process all transactions concurrently to simulate race conditions
+    const promises = [];
+    let expectedTotalSpend = 0;
+
+    for (let i = 0; i < collectedUsage.length; i++) {
+      const usage = collectedUsage[i];
+      if (!usage) {
+        continue;
+      }
+
+      const cache_creation = Number(usage.input_token_details?.cache_creation) || 0;
+      const cache_read = Number(usage.input_token_details?.cache_read) || 0;
+
+      const txMetadata = {
+        context,
+        conversationId,
+        user: userId,
+        model: usage.model,
+      };
+
+      // Calculate expected spend for this transaction
+      const promptTokens = usage.input_tokens;
+      const completionTokens = usage.output_tokens;
+
+      // For regular transactions
+      if (cache_creation === 0 && cache_read === 0) {
+        // Add to expected spend using the correct multipliers from tx.js
+        // For gpt-4, the multipliers are: prompt=30, completion=60
+        expectedTotalSpend += promptTokens * 30; // gpt-4 prompt rate is 30
+        expectedTotalSpend += completionTokens * 60; // gpt-4 completion rate is 60
+
+        promises.push(
+          spendTokens(txMetadata, {
+            promptTokens,
+            completionTokens,
+          }),
+        );
+      } else {
+        // For structured transactions with cache operations
+        // The multipliers for claude models with cache operations are different
+        // But since we're using gpt-4 in the test, we need to use appropriate values
+        expectedTotalSpend += promptTokens * 30; // Base prompt rate for gpt-4
+        // Since gpt-4 doesn't have cache multipliers defined, we'll use the prompt rate
+        expectedTotalSpend += cache_creation * 30; // Write rate (using prompt rate as fallback)
+        expectedTotalSpend += cache_read * 30; // Read rate (using prompt rate as fallback)
+        expectedTotalSpend += completionTokens * 60; // Completion rate for gpt-4
+
+        promises.push(
+          spendStructuredTokens(txMetadata, {
+            promptTokens: {
+              input: promptTokens,
+              write: cache_creation,
+              read: cache_read,
+            },
+            completionTokens,
+          }),
+        );
+      }
+    }
+
+    // Wait for all transactions to complete
+    await Promise.all(promises);
+
+    // Verify final balance
+    const finalBalance = await Balance.findOne({ user: userId });
+    expect(finalBalance).toBeDefined();
+
+    // The final balance should be the initial balance minus the expected total spend
+    const expectedFinalBalance = initialBalance - expectedTotalSpend;
+
+    console.log('Initial balance:', initialBalance);
+    console.log('Expected total spend:', expectedTotalSpend);
+    console.log('Expected final balance:', expectedFinalBalance);
+    console.log('Actual final balance:', finalBalance.tokenCredits);
+
+    // Allow for small rounding differences
+    expect(finalBalance.tokenCredits).toBeCloseTo(expectedFinalBalance, 0);
+
+    // Verify all transactions were created
+    const transactions = await Transaction.find({
+      user: userId,
+      conversationId,
+    });
+
+    // We should have 2 transactions (prompt + completion) for each usage record
+    // Some might be structured, some regular
+    expect(transactions.length).toBeGreaterThanOrEqual(collectedUsage.length);
+
+    // Log transaction details for debugging
+    console.log('Transaction summary:');
+    let totalTokenValue = 0;
+    transactions.forEach((tx) => {
+      console.log(`${tx.tokenType}: rawAmount=${tx.rawAmount}, tokenValue=${tx.tokenValue}`);
+      totalTokenValue += tx.tokenValue;
+    });
+    console.log('Total token value from transactions:', totalTokenValue);
+
+    // Beyond the expected-balance assertion above, cross-check the ledger
+    // against the actual balance movement
+    const actualSpend = initialBalance - finalBalance.tokenCredits;
+    console.log('Actual spend:', actualSpend);
+
+    // In addition to the exact balance check, verify that:
+    // 1. The balance was reduced (tokens were spent)
+    expect(finalBalance.tokenCredits).toBeLessThan(initialBalance);
+    // 2. The total token value from transactions matches the actual spend
+    expect(Math.abs(totalTokenValue)).toBeCloseTo(actualSpend, -3); // negative digits tolerate rounding to the nearest thousand
+  });
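The expected-spend arithmetic above is simple enough to sanity-check by hand. A worked illustration (editor's sketch, using the gpt-4 rates the test itself cites: prompt = 30, completion = 60 per token):

// Spend is tokenCount * multiplier, summed per token type.
// For the first usage record (100 input tokens, 50 output tokens, no cache ops):
const promptSpend = 100 * 30; // 3000 token credits
const completionSpend = 50 * 60; // 3000 token credits
const total = promptSpend + completionSpend; // 6000 credits deducted from the balance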
+
+  it('should handle multiple concurrent balance increases correctly', async () => {
+    // Start with zero balance
+    const initialBalance = 0;
+    await Balance.create({
+      user: userId,
+      tokenCredits: initialBalance,
+    });
+
+    const numberOfRefills = 25;
+    const refillAmount = 1000;
+
+    const promises = [];
+    for (let i = 0; i < numberOfRefills; i++) {
+      promises.push(
+        Transaction.createAutoRefillTransaction({
+          user: userId,
+          tokenType: 'credits',
+          context: 'concurrent-refill-test',
+          rawAmount: refillAmount,
+        }),
+      );
+    }
+
+    // Wait for all refill transactions to complete
+    const results = await Promise.all(promises);
+
+    // Verify final balance
+    const finalBalance = await Balance.findOne({ user: userId });
+    expect(finalBalance).toBeDefined();
+
+    // The final balance should be the initial balance plus the sum of all refills
+    const expectedFinalBalance = initialBalance + numberOfRefills * refillAmount;
+
+    console.log('Initial balance (Increase Test):', initialBalance);
+    console.log(`Performed ${numberOfRefills} refills of ${refillAmount} each.`);
+    console.log('Expected final balance (Increase Test):', expectedFinalBalance);
+    console.log('Actual final balance (Increase Test):', finalBalance.tokenCredits);
+
+    // Use toBeCloseTo for safety, though toBe should work for integer math
+    expect(finalBalance.tokenCredits).toBeCloseTo(expectedFinalBalance, 0);
+
+    // Verify all transactions were created
+    const transactions = await Transaction.find({
+      user: userId,
+      context: 'concurrent-refill-test',
+    });
+
+    // We should have one transaction for each refill attempt
+    expect(transactions.length).toBe(numberOfRefills);
+
+    // Optional: Verify the sum of increments from the results matches the balance change
+    const totalIncrementReported = results.reduce((sum, result) => {
+      // Assuming createAutoRefillTransaction returns an object with the increment amount
+      // Adjust this based on the actual return structure.
+      // Let's assume it returns { balance: newBalance, transaction: { rawAmount: ...
} } + // Or perhaps we check the transaction.rawAmount directly + return sum + (result?.transaction?.rawAmount || 0); + }, 0); + console.log('Total increment reported by results:', totalIncrementReported); + expect(totalIncrementReported).toBe(expectedFinalBalance - initialBalance); + + // Optional: Check the sum of tokenValue from saved transactions + let totalTokenValueFromDb = 0; + transactions.forEach((tx) => { + // For refills, rawAmount is positive, and tokenValue might be calculated based on it + // Let's assume tokenValue directly reflects the increment for simplicity here + // If calculation is involved, adjust accordingly + totalTokenValueFromDb += tx.rawAmount; // Or tx.tokenValue if that holds the increment + }); + console.log('Total rawAmount from DB transactions:', totalTokenValueFromDb); + expect(totalTokenValueFromDb).toBeCloseTo(expectedFinalBalance - initialBalance, 0); }); it('should create structured transactions for both prompt and completion tokens', async () => { + // Create a balance for the user + await Balance.create({ + user: userId, + tokenCredits: 10000, + }); + const txData = { - user: new mongoose.Types.ObjectId(), + user: userId, conversationId: 'test-convo', model: 'claude-3-5-sonnet', context: 'test', @@ -150,48 +688,37 @@ describe('spendTokens', () => { completionTokens: 50, }; - Transaction.createStructured.mockResolvedValueOnce({ - rate: 3.75, - user: txData.user.toString(), - balance: 9570, - prompt: -430, - }); - Transaction.create.mockResolvedValueOnce({ - rate: 15, - user: txData.user.toString(), - balance: 8820, - completion: -750, - }); - const result = await spendStructuredTokens(txData, tokenUsage); - expect(Transaction.createStructured).toHaveBeenCalledWith( - expect.objectContaining({ - tokenType: 'prompt', - inputTokens: -10, - writeTokens: -100, - readTokens: -5, - }), - ); - expect(Transaction.create).toHaveBeenCalledWith( - expect.objectContaining({ - tokenType: 'completion', - rawAmount: -50, - }), - ); + // Verify transactions were created + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + expect(transactions).toHaveLength(2); + + // Check completion transaction + expect(transactions[0].tokenType).toBe('completion'); + expect(transactions[0].rawAmount).toBe(-50); + + // Check prompt transaction + expect(transactions[1].tokenType).toBe('prompt'); + expect(transactions[1].inputTokens).toBe(-10); + expect(transactions[1].writeTokens).toBe(-100); + expect(transactions[1].readTokens).toBe(-5); + + // Verify result contains transaction info expect(result).toEqual({ prompt: expect.objectContaining({ - rate: 3.75, - user: txData.user.toString(), - balance: 9570, - prompt: -430, + user: userId.toString(), + prompt: expect.any(Number), }), completion: expect.objectContaining({ - rate: 15, - user: txData.user.toString(), - balance: 8820, - completion: -750, + user: userId.toString(), + completion: expect.any(Number), }), }); + + // Verify balance was updated + const balance = await Balance.findOne({ user: userId }); + expect(balance).toBeDefined(); + expect(balance.tokenCredits).toBeLessThan(10000); // Balance should be reduced }); }); diff --git a/api/models/tx.js b/api/models/tx.js index 05412430c7..df88390b17 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -61,6 +61,7 @@ const bedrockValues = { 'amazon.nova-micro-v1:0': { prompt: 0.035, completion: 0.14 }, 'amazon.nova-lite-v1:0': { prompt: 0.06, completion: 0.24 }, 'amazon.nova-pro-v1:0': { prompt: 0.8, completion: 3.2 }, + 'deepseek.r1': { prompt: 
1.35, completion: 5.4 }, }; /** @@ -75,10 +76,16 @@ const tokenValues = Object.assign( '4k': { prompt: 1.5, completion: 2 }, '16k': { prompt: 3, completion: 4 }, 'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 }, + 'o4-mini': { prompt: 1.1, completion: 4.4 }, 'o3-mini': { prompt: 1.1, completion: 4.4 }, + o3: { prompt: 10, completion: 40 }, 'o1-mini': { prompt: 1.1, completion: 4.4 }, 'o1-preview': { prompt: 15, completion: 60 }, o1: { prompt: 15, completion: 60 }, + 'gpt-4.1-nano': { prompt: 0.1, completion: 0.4 }, + 'gpt-4.1-mini': { prompt: 0.4, completion: 1.6 }, + 'gpt-4.1': { prompt: 2, completion: 8 }, + 'gpt-4.5': { prompt: 75, completion: 150 }, 'gpt-4o-mini': { prompt: 0.15, completion: 0.6 }, 'gpt-4o': { prompt: 2.5, completion: 10 }, 'gpt-4o-2024-05-13': { prompt: 5, completion: 15 }, @@ -88,6 +95,8 @@ const tokenValues = Object.assign( 'claude-3-sonnet': { prompt: 3, completion: 15 }, 'claude-3-5-sonnet': { prompt: 3, completion: 15 }, 'claude-3.5-sonnet': { prompt: 3, completion: 15 }, + 'claude-3-7-sonnet': { prompt: 3, completion: 15 }, + 'claude-3.7-sonnet': { prompt: 3, completion: 15 }, 'claude-3-5-haiku': { prompt: 0.8, completion: 4 }, 'claude-3.5-haiku': { prompt: 0.8, completion: 4 }, 'claude-3-haiku': { prompt: 0.25, completion: 1.25 }, @@ -102,14 +111,39 @@ const tokenValues = Object.assign( /* cohere doesn't have rates for the older command models, so this was from https://artificialanalysis.ai/models/command-light/providers */ command: { prompt: 0.38, completion: 0.38 }, + gemma: { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing + 'gemma-2': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing + 'gemma-3': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing + 'gemma-3-27b': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing 'gemini-2.0-flash-lite': { prompt: 0.075, completion: 0.3 }, - 'gemini-2.0-flash': { prompt: 0.1, completion: 0.7 }, + 'gemini-2.0-flash': { prompt: 0.1, completion: 0.4 }, 'gemini-2.0': { prompt: 0, completion: 0 }, // https://ai.google.dev/pricing + 'gemini-2.5-pro': { prompt: 1.25, completion: 10 }, + 'gemini-2.5-flash': { prompt: 0.15, completion: 3.5 }, + 'gemini-2.5': { prompt: 0, completion: 0 }, // Free for a period of time 'gemini-1.5-flash-8b': { prompt: 0.075, completion: 0.3 }, 'gemini-1.5-flash': { prompt: 0.15, completion: 0.6 }, 'gemini-1.5': { prompt: 2.5, completion: 10 }, 'gemini-pro-vision': { prompt: 0.5, completion: 1.5 }, gemini: { prompt: 0.5, completion: 1.5 }, + 'grok-2-vision-1212': { prompt: 2.0, completion: 10.0 }, + 'grok-2-vision-latest': { prompt: 2.0, completion: 10.0 }, + 'grok-2-vision': { prompt: 2.0, completion: 10.0 }, + 'grok-vision-beta': { prompt: 5.0, completion: 15.0 }, + 'grok-2-1212': { prompt: 2.0, completion: 10.0 }, + 'grok-2-latest': { prompt: 2.0, completion: 10.0 }, + 'grok-2': { prompt: 2.0, completion: 10.0 }, + 'grok-3-mini-fast': { prompt: 0.4, completion: 4 }, + 'grok-3-mini': { prompt: 0.3, completion: 0.5 }, + 'grok-3-fast': { prompt: 5.0, completion: 25.0 }, + 'grok-3': { prompt: 3.0, completion: 15.0 }, + 'grok-beta': { prompt: 5.0, completion: 15.0 }, + 'mistral-large': { prompt: 2.0, completion: 6.0 }, + 'pixtral-large': { prompt: 2.0, completion: 6.0 }, + 'mistral-saba': { prompt: 0.2, completion: 0.6 }, + codestral: { prompt: 0.3, completion: 0.9 }, + 'ministral-8b': { prompt: 0.1, completion: 0.1 }, + 'ministral-3b': { prompt: 0.04, completion: 0.04 }, }, bedrockValues, ); @@ -121,6 +155,8 @@ const tokenValues = Object.assign( * @type 
{Object.<string, { write: number, read: number }>}
  */
 const cacheTokenValues = {
+  'claude-3.7-sonnet': { write: 3.75, read: 0.3 },
+  'claude-3-7-sonnet': { write: 3.75, read: 0.3 },
   'claude-3.5-sonnet': { write: 3.75, read: 0.3 },
   'claude-3-5-sonnet': { write: 3.75, read: 0.3 },
   'claude-3.5-haiku': { write: 1, read: 0.08 },
@@ -149,12 +185,28 @@ const getValueKey = (model, endpoint) => {
     return 'gpt-3.5-turbo-1106';
   } else if (modelName.includes('gpt-3.5')) {
     return '4k';
+  } else if (modelName.includes('o4-mini')) {
+    return 'o4-mini';
+  } else if (modelName.includes('o4')) {
+    return 'o4';
+  } else if (modelName.includes('o3-mini')) {
+    return 'o3-mini';
+  } else if (modelName.includes('o3')) {
+    return 'o3';
   } else if (modelName.includes('o1-preview')) {
     return 'o1-preview';
   } else if (modelName.includes('o1-mini')) {
     return 'o1-mini';
   } else if (modelName.includes('o1')) {
     return 'o1';
+  } else if (modelName.includes('gpt-4.5')) {
+    return 'gpt-4.5';
+  } else if (modelName.includes('gpt-4.1-nano')) {
+    return 'gpt-4.1-nano';
+  } else if (modelName.includes('gpt-4.1-mini')) {
+    return 'gpt-4.1-mini';
+  } else if (modelName.includes('gpt-4.1')) {
+    return 'gpt-4.1';
   } else if (modelName.includes('gpt-4o-2024-05-13')) {
     return 'gpt-4o-2024-05-13';
   } else if (modelName.includes('gpt-4o-mini')) {
diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js
index d77973a7f5..97a730232d 100644
--- a/api/models/tx.spec.js
+++ b/api/models/tx.spec.js
@@ -50,6 +50,40 @@ describe('getValueKey', () => {
     expect(getValueKey('gpt-4-0125')).toBe('gpt-4-1106');
   });
 
+  it('should return "gpt-4.5" for model type of "gpt-4.5"', () => {
+    expect(getValueKey('gpt-4.5-preview')).toBe('gpt-4.5');
+    expect(getValueKey('gpt-4.5-2024-08-06')).toBe('gpt-4.5');
+    expect(getValueKey('gpt-4.5-2024-08-06-0718')).toBe('gpt-4.5');
+    expect(getValueKey('openai/gpt-4.5')).toBe('gpt-4.5');
+    expect(getValueKey('openai/gpt-4.5-2024-08-06')).toBe('gpt-4.5');
+    expect(getValueKey('gpt-4.5-turbo')).toBe('gpt-4.5');
+    expect(getValueKey('gpt-4.5-0125')).toBe('gpt-4.5');
+  });
+
+  it('should return "gpt-4.1" for model type of "gpt-4.1"', () => {
+    expect(getValueKey('gpt-4.1-preview')).toBe('gpt-4.1');
+    expect(getValueKey('gpt-4.1-2024-08-06')).toBe('gpt-4.1');
+    expect(getValueKey('gpt-4.1-2024-08-06-0718')).toBe('gpt-4.1');
+    expect(getValueKey('openai/gpt-4.1')).toBe('gpt-4.1');
+    expect(getValueKey('openai/gpt-4.1-2024-08-06')).toBe('gpt-4.1');
+    expect(getValueKey('gpt-4.1-turbo')).toBe('gpt-4.1');
+    expect(getValueKey('gpt-4.1-0125')).toBe('gpt-4.1');
+  });
+
+  it('should return "gpt-4.1-mini" for model type of "gpt-4.1-mini"', () => {
+    expect(getValueKey('gpt-4.1-mini-preview')).toBe('gpt-4.1-mini');
+    expect(getValueKey('gpt-4.1-mini-2024-08-06')).toBe('gpt-4.1-mini');
+    expect(getValueKey('openai/gpt-4.1-mini')).toBe('gpt-4.1-mini');
+    expect(getValueKey('gpt-4.1-mini-0125')).toBe('gpt-4.1-mini');
+  });
+
+  it('should return "gpt-4.1-nano" for model type of "gpt-4.1-nano"', () => {
+    expect(getValueKey('gpt-4.1-nano-preview')).toBe('gpt-4.1-nano');
+    expect(getValueKey('gpt-4.1-nano-2024-08-06')).toBe('gpt-4.1-nano');
+    expect(getValueKey('openai/gpt-4.1-nano')).toBe('gpt-4.1-nano');
+    expect(getValueKey('gpt-4.1-nano-0125')).toBe('gpt-4.1-nano');
+  });
+
   it('should return "gpt-4o" for model type of "gpt-4o"', () => {
     expect(getValueKey('gpt-4o-2024-08-06')).toBe('gpt-4o');
     expect(getValueKey('gpt-4o-2024-08-06-0718')).toBe('gpt-4o');
@@ -80,6 +114,20 @@ describe('getValueKey', () => {
     expect(getValueKey('chatgpt-4o-latest-0718')).toBe('gpt-4o');
   });
 
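A note on the getValueKey chain above: the includes() checks must run from most specific to least specific, since a bare substring match on 'gpt-4.1' would also swallow the mini and nano variants, and 'o4' would swallow 'o4-mini'. A quick illustration (editor's sketch, using the function as declared above; the model strings are hypothetical):

// Why branch order matters in the includes() chain:
getValueKey('openai/gpt-4.1-mini-2025'); // 'gpt-4.1-mini' — matched before 'gpt-4.1'
getValueKey('gpt-4.1-0125'); // 'gpt-4.1'
getValueKey('o4-mini-high'); // 'o4-mini' — matched before 'o4'

+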
it('should return "claude-3-7-sonnet" for model type of "claude-3-7-sonnet-"', () => { + expect(getValueKey('claude-3-7-sonnet-20240620')).toBe('claude-3-7-sonnet'); + expect(getValueKey('anthropic/claude-3-7-sonnet')).toBe('claude-3-7-sonnet'); + expect(getValueKey('claude-3-7-sonnet-turbo')).toBe('claude-3-7-sonnet'); + expect(getValueKey('claude-3-7-sonnet-0125')).toBe('claude-3-7-sonnet'); + }); + + it('should return "claude-3.7-sonnet" for model type of "claude-3.7-sonnet-"', () => { + expect(getValueKey('claude-3.7-sonnet-20240620')).toBe('claude-3.7-sonnet'); + expect(getValueKey('anthropic/claude-3.7-sonnet')).toBe('claude-3.7-sonnet'); + expect(getValueKey('claude-3.7-sonnet-turbo')).toBe('claude-3.7-sonnet'); + expect(getValueKey('claude-3.7-sonnet-0125')).toBe('claude-3.7-sonnet'); + }); + it('should return "claude-3-5-sonnet" for model type of "claude-3-5-sonnet-"', () => { expect(getValueKey('claude-3-5-sonnet-20240620')).toBe('claude-3-5-sonnet'); expect(getValueKey('anthropic/claude-3-5-sonnet')).toBe('claude-3-5-sonnet'); @@ -117,6 +165,15 @@ describe('getMultiplier', () => { ); }); + it('should return correct multipliers for o4-mini and o3', () => { + ['o4-mini', 'o3'].forEach((model) => { + const prompt = getMultiplier({ model, tokenType: 'prompt' }); + const completion = getMultiplier({ model, tokenType: 'completion' }); + expect(prompt).toBe(tokenValues[model].prompt); + expect(completion).toBe(tokenValues[model].completion); + }); + }); + it('should return defaultRate if tokenType is provided but not found in tokenValues', () => { expect(getMultiplier({ valueKey: '8k', tokenType: 'unknownType' })).toBe(defaultRate); }); @@ -161,6 +218,52 @@ describe('getMultiplier', () => { ); }); + it('should return the correct multiplier for gpt-4.1', () => { + const valueKey = getValueKey('gpt-4.1-2024-08-06'); + expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe(tokenValues['gpt-4.1'].prompt); + expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe( + tokenValues['gpt-4.1'].completion, + ); + expect(getMultiplier({ model: 'gpt-4.1-preview', tokenType: 'prompt' })).toBe( + tokenValues['gpt-4.1'].prompt, + ); + expect(getMultiplier({ model: 'openai/gpt-4.1', tokenType: 'completion' })).toBe( + tokenValues['gpt-4.1'].completion, + ); + }); + + it('should return the correct multiplier for gpt-4.1-mini', () => { + const valueKey = getValueKey('gpt-4.1-mini-2024-08-06'); + expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe( + tokenValues['gpt-4.1-mini'].prompt, + ); + expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe( + tokenValues['gpt-4.1-mini'].completion, + ); + expect(getMultiplier({ model: 'gpt-4.1-mini-preview', tokenType: 'prompt' })).toBe( + tokenValues['gpt-4.1-mini'].prompt, + ); + expect(getMultiplier({ model: 'openai/gpt-4.1-mini', tokenType: 'completion' })).toBe( + tokenValues['gpt-4.1-mini'].completion, + ); + }); + + it('should return the correct multiplier for gpt-4.1-nano', () => { + const valueKey = getValueKey('gpt-4.1-nano-2024-08-06'); + expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe( + tokenValues['gpt-4.1-nano'].prompt, + ); + expect(getMultiplier({ valueKey, tokenType: 'completion' })).toBe( + tokenValues['gpt-4.1-nano'].completion, + ); + expect(getMultiplier({ model: 'gpt-4.1-nano-preview', tokenType: 'prompt' })).toBe( + tokenValues['gpt-4.1-nano'].prompt, + ); + expect(getMultiplier({ model: 'openai/gpt-4.1-nano', tokenType: 'completion' })).toBe( + tokenValues['gpt-4.1-nano'].completion, + ); 
+ }); + it('should return the correct multiplier for gpt-4o-mini', () => { const valueKey = getValueKey('gpt-4o-mini-2024-07-18'); expect(getMultiplier({ valueKey, tokenType: 'prompt' })).toBe( @@ -264,7 +367,7 @@ describe('AWS Bedrock Model Tests', () => { }); describe('Deepseek Model Tests', () => { - const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner']; + const deepseekModels = ['deepseek-chat', 'deepseek-coder', 'deepseek-reasoner', 'deepseek.r1']; it('should return the correct prompt multipliers for all models', () => { const results = deepseekModels.map((model) => { @@ -324,9 +427,11 @@ describe('getCacheMultiplier', () => { it('should derive the valueKey from the model if not provided', () => { expect(getCacheMultiplier({ cacheType: 'write', model: 'claude-3-5-sonnet-20240620' })).toBe( - 3.75, + cacheTokenValues['claude-3-5-sonnet'].write, + ); + expect(getCacheMultiplier({ cacheType: 'read', model: 'claude-3-haiku-20240307' })).toBe( + cacheTokenValues['claude-3-haiku'].read, ); - expect(getCacheMultiplier({ cacheType: 'read', model: 'claude-3-haiku-20240307' })).toBe(0.03); }); it('should return null if only model or cacheType is missing', () => { @@ -347,10 +452,10 @@ describe('getCacheMultiplier', () => { }; expect( getCacheMultiplier({ model: 'custom-model', cacheType: 'write', endpointTokenConfig }), - ).toBe(5); + ).toBe(endpointTokenConfig['custom-model'].write); expect( getCacheMultiplier({ model: 'custom-model', cacheType: 'read', endpointTokenConfig }), - ).toBe(1); + ).toBe(endpointTokenConfig['custom-model'].read); }); it('should return null if model is not found in endpointTokenConfig', () => { @@ -371,18 +476,21 @@ describe('getCacheMultiplier', () => { model: 'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0', cacheType: 'write', }), - ).toBe(3.75); + ).toBe(cacheTokenValues['claude-3-5-sonnet'].write); expect( getCacheMultiplier({ model: 'bedrock/anthropic.claude-3-haiku-20240307-v1:0', cacheType: 'read', }), - ).toBe(0.03); + ).toBe(cacheTokenValues['claude-3-haiku'].read); }); }); describe('Google Model Tests', () => { const googleModels = [ + 'gemini-2.5-pro-preview-05-06', + 'gemini-2.5-flash-preview-04-17', + 'gemini-2.5-exp', 'gemini-2.0-flash-lite-preview-02-05', 'gemini-2.0-flash-001', 'gemini-2.0-flash-exp', @@ -420,6 +528,9 @@ describe('Google Model Tests', () => { it('should map to the correct model keys', () => { const expected = { + 'gemini-2.5-pro-preview-05-06': 'gemini-2.5-pro', + 'gemini-2.5-flash-preview-04-17': 'gemini-2.5-flash', + 'gemini-2.5-exp': 'gemini-2.5', 'gemini-2.0-flash-lite-preview-02-05': 'gemini-2.0-flash-lite', 'gemini-2.0-flash-001': 'gemini-2.0-flash', 'gemini-2.0-flash-exp': 'gemini-2.0-flash', @@ -458,3 +569,98 @@ describe('Google Model Tests', () => { }); }); }); + +describe('Grok Model Tests - Pricing', () => { + describe('getMultiplier', () => { + test('should return correct prompt and completion rates for Grok vision models', () => { + const models = ['grok-2-vision-1212', 'grok-2-vision', 'grok-2-vision-latest']; + models.forEach((model) => { + expect(getMultiplier({ model, tokenType: 'prompt' })).toBe( + tokenValues['grok-2-vision'].prompt, + ); + expect(getMultiplier({ model, tokenType: 'completion' })).toBe( + tokenValues['grok-2-vision'].completion, + ); + }); + }); + + test('should return correct prompt and completion rates for Grok text models', () => { + const models = ['grok-2-1212', 'grok-2', 'grok-2-latest']; + models.forEach((model) => { + expect(getMultiplier({ model, tokenType: 
'prompt' })).toBe(tokenValues['grok-2'].prompt); + expect(getMultiplier({ model, tokenType: 'completion' })).toBe( + tokenValues['grok-2'].completion, + ); + }); + }); + + test('should return correct prompt and completion rates for Grok beta models', () => { + expect(getMultiplier({ model: 'grok-vision-beta', tokenType: 'prompt' })).toBe( + tokenValues['grok-vision-beta'].prompt, + ); + expect(getMultiplier({ model: 'grok-vision-beta', tokenType: 'completion' })).toBe( + tokenValues['grok-vision-beta'].completion, + ); + expect(getMultiplier({ model: 'grok-beta', tokenType: 'prompt' })).toBe( + tokenValues['grok-beta'].prompt, + ); + expect(getMultiplier({ model: 'grok-beta', tokenType: 'completion' })).toBe( + tokenValues['grok-beta'].completion, + ); + }); + + test('should return correct prompt and completion rates for Grok 3 models', () => { + expect(getMultiplier({ model: 'grok-3', tokenType: 'prompt' })).toBe( + tokenValues['grok-3'].prompt, + ); + expect(getMultiplier({ model: 'grok-3', tokenType: 'completion' })).toBe( + tokenValues['grok-3'].completion, + ); + expect(getMultiplier({ model: 'grok-3-fast', tokenType: 'prompt' })).toBe( + tokenValues['grok-3-fast'].prompt, + ); + expect(getMultiplier({ model: 'grok-3-fast', tokenType: 'completion' })).toBe( + tokenValues['grok-3-fast'].completion, + ); + expect(getMultiplier({ model: 'grok-3-mini', tokenType: 'prompt' })).toBe( + tokenValues['grok-3-mini'].prompt, + ); + expect(getMultiplier({ model: 'grok-3-mini', tokenType: 'completion' })).toBe( + tokenValues['grok-3-mini'].completion, + ); + expect(getMultiplier({ model: 'grok-3-mini-fast', tokenType: 'prompt' })).toBe( + tokenValues['grok-3-mini-fast'].prompt, + ); + expect(getMultiplier({ model: 'grok-3-mini-fast', tokenType: 'completion' })).toBe( + tokenValues['grok-3-mini-fast'].completion, + ); + }); + + test('should return correct prompt and completion rates for Grok 3 models with prefixes', () => { + expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'prompt' })).toBe( + tokenValues['grok-3'].prompt, + ); + expect(getMultiplier({ model: 'xai/grok-3', tokenType: 'completion' })).toBe( + tokenValues['grok-3'].completion, + ); + expect(getMultiplier({ model: 'xai/grok-3-fast', tokenType: 'prompt' })).toBe( + tokenValues['grok-3-fast'].prompt, + ); + expect(getMultiplier({ model: 'xai/grok-3-fast', tokenType: 'completion' })).toBe( + tokenValues['grok-3-fast'].completion, + ); + expect(getMultiplier({ model: 'xai/grok-3-mini', tokenType: 'prompt' })).toBe( + tokenValues['grok-3-mini'].prompt, + ); + expect(getMultiplier({ model: 'xai/grok-3-mini', tokenType: 'completion' })).toBe( + tokenValues['grok-3-mini'].completion, + ); + expect(getMultiplier({ model: 'xai/grok-3-mini-fast', tokenType: 'prompt' })).toBe( + tokenValues['grok-3-mini-fast'].prompt, + ); + expect(getMultiplier({ model: 'xai/grok-3-mini-fast', tokenType: 'completion' })).toBe( + tokenValues['grok-3-mini-fast'].completion, + ); + }); + }); +}); diff --git a/api/models/userMethods.js b/api/models/userMethods.js index 63b25edd3a..fbcd33aba8 100644 --- a/api/models/userMethods.js +++ b/api/models/userMethods.js @@ -1,6 +1,6 @@ const bcrypt = require('bcryptjs'); +const { getBalanceConfig } = require('~/server/services/Config'); const signPayload = require('~/server/services/signPayload'); -const { isEnabled } = require('~/server/utils/handleText'); const Balance = require('./Balance'); const User = require('./User'); @@ -13,11 +13,9 @@ const User = require('./User'); */ const getUserById = async function 
(userId, fieldsToSelect = null) { const query = User.findById(userId); - if (fieldsToSelect) { query.select(fieldsToSelect); } - return await query.lean(); }; @@ -32,7 +30,6 @@ const findUser = async function (searchCriteria, fieldsToSelect = null) { if (fieldsToSelect) { query.select(fieldsToSelect); } - return await query.lean(); }; @@ -58,11 +55,12 @@ const updateUser = async function (userId, updateData) { * Creates a new user, optionally with a TTL of 1 week. * @param {MongoUser} data - The user data to be created, must contain user_id. * @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`. - * @param {boolean} [returnUser=false] - Whether to disable the TTL. Defaults to `true`. - * @returns {Promise} A promise that resolves to the created user document ID. + * @param {boolean} [returnUser=false] - Whether to return the created user object. + * @returns {Promise} A promise that resolves to the created user document ID or user object. * @throws {Error} If a user with the same user_id already exists. */ const createUser = async (data, disableTTL = true, returnUser = false) => { + const balance = await getBalanceConfig(); const userData = { ...data, expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds @@ -74,13 +72,27 @@ const createUser = async (data, disableTTL = true, returnUser = false) => { const user = await User.create(userData); - if (isEnabled(process.env.CHECK_BALANCE) && process.env.START_BALANCE) { - let incrementValue = parseInt(process.env.START_BALANCE); - await Balance.findOneAndUpdate( - { user: user._id }, - { $inc: { tokenCredits: incrementValue } }, - { upsert: true, new: true }, - ).lean(); + // If balance is enabled, create or update a balance record for the user using global.interfaceConfig.balance + if (balance?.enabled && balance?.startBalance) { + const update = { + $inc: { tokenCredits: balance.startBalance }, + }; + + if ( + balance.autoRefillEnabled && + balance.refillIntervalValue != null && + balance.refillIntervalUnit != null && + balance.refillAmount != null + ) { + update.$set = { + autoRefillEnabled: true, + refillIntervalValue: balance.refillIntervalValue, + refillIntervalUnit: balance.refillIntervalUnit, + refillAmount: balance.refillAmount, + }; + } + + await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean(); } if (returnUser) { @@ -123,7 +135,7 @@ const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15; /** * Generates a JWT token for a given user. * - * @param {MongoUser} user - ID of the user for whom the token is being generated. + * @param {MongoUser} user - The user for whom the token is being generated. * @returns {Promise} A promise that resolves to a JWT token. */ const generateToken = async (user) => { @@ -146,7 +158,7 @@ const generateToken = async (user) => { /** * Compares the provided password with the user's password. * - * @param {MongoUser} user - the user to compare password for. + * @param {MongoUser} user - The user to compare the password for. * @param {string} candidatePassword - The password to test against the user's password. * @returns {Promise} A promise that resolves to a boolean indicating if the password matches. 
*/ diff --git a/api/package.json b/api/package.json index 8d5a997e6e..bcf94a6cad 100644 --- a/api/package.json +++ b/api/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/backend", - "version": "v0.7.7-rc1", + "version": "v0.7.8", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", @@ -34,20 +34,24 @@ }, "homepage": "https://librechat.ai", "dependencies": { - "@anthropic-ai/sdk": "^0.32.1", + "@anthropic-ai/sdk": "^0.37.0", + "@aws-sdk/client-s3": "^3.758.0", + "@aws-sdk/s3-request-presigner": "^3.758.0", + "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", - "@google/generative-ai": "^0.21.0", + "@azure/storage-blob": "^12.27.0", + "@google/generative-ai": "^0.23.0", "@googleapis/youtube": "^20.0.0", - "@keyv/mongo": "^2.1.8", - "@keyv/redis": "^2.8.1", - "@langchain/community": "^0.3.14", - "@langchain/core": "^0.3.37", - "@langchain/google-genai": "^0.1.7", - "@langchain/google-vertexai": "^0.1.8", + "@keyv/redis": "^4.3.3", + "@langchain/community": "^0.3.42", + "@langchain/core": "^0.3.55", + "@langchain/google-genai": "^0.2.8", + "@langchain/google-vertexai": "^0.2.8", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.0.4", + "@librechat/agents": "^2.4.317", + "@librechat/data-schemas": "*", "@waylaidwanderer/fetch-event-source": "^3.0.1", - "axios": "1.7.8", + "axios": "^1.8.2", "bcryptjs": "^2.4.3", "cohere-ai": "^7.9.1", "compression": "^1.7.4", @@ -57,21 +61,23 @@ "cors": "^2.8.5", "dedent": "^1.5.3", "dotenv": "^16.0.3", + "eventsource": "^3.0.2", "express": "^4.21.2", "express-mongo-sanitize": "^2.2.0", "express-rate-limit": "^7.4.1", "express-session": "^1.18.1", + "express-static-gzip": "^2.2.0", "file-type": "^18.7.0", "firebase": "^11.0.2", "googleapis": "^126.0.1", "handlebars": "^4.7.7", + "https-proxy-agent": "^7.0.6", "ioredis": "^5.3.2", "js-yaml": "^4.1.0", "jsonwebtoken": "^9.0.0", - "keyv": "^4.5.4", - "keyv-file": "^0.2.0", + "keyv": "^5.3.2", + "keyv-file": "^5.1.2", "klona": "^2.0.6", - "langchain": "^0.2.19", "librechat-data-provider": "*", "librechat-mcp": "*", "lodash": "^4.17.21", @@ -79,12 +85,12 @@ "memorystore": "^1.6.7", "mime": "^3.0.0", "module-alias": "^2.2.3", - "mongoose": "^8.9.5", + "mongoose": "^8.12.1", "multer": "^1.4.5-lts.1", "nanoid": "^3.3.7", "nodemailer": "^6.9.15", "ollama": "^0.5.0", - "openai": "^4.47.1", + "openai": "^4.96.2", "openai-chat-tokens": "^0.2.8", "openid-client": "^5.4.2", "passport": "^0.6.0", @@ -96,7 +102,8 @@ "passport-jwt": "^4.0.1", "passport-ldapauth": "^3.0.1", "passport-local": "^1.0.0", - "sharp": "^0.32.6", + "rate-limit-redis": "^4.2.0", + "sharp": "^0.33.5", "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", @@ -109,6 +116,6 @@ "jest": "^29.7.0", "mongodb-memory-server": "^10.1.3", "nodemon": "^3.0.3", - "supertest": "^7.0.0" + "supertest": "^7.1.0" } } diff --git a/api/server/cleanup.js b/api/server/cleanup.js new file mode 100644 index 0000000000..6d5b77196a --- /dev/null +++ b/api/server/cleanup.js @@ -0,0 +1,387 @@ +const { logger } = require('~/config'); + +// WeakMap to hold temporary data associated with requests +const requestDataMap = new WeakMap(); + +const FinalizationRegistry = global.FinalizationRegistry || null; + +/** + * FinalizationRegistry to clean up client objects when they are garbage collected. + * This is used to prevent memory leaks and ensure that client objects are + * properly disposed of when they are no longer needed. 
+ * The registry holds a weak reference to the client object and a cleanup + * callback that is called when the client object is garbage collected. + * The callback can be used to perform any necessary cleanup operations, + * such as removing event listeners or freeing up resources. + */ +const clientRegistry = FinalizationRegistry + ? new FinalizationRegistry((heldValue) => { + try { + // This will run when the client is garbage collected + if (heldValue && heldValue.userId) { + logger.debug(`[FinalizationRegistry] Cleaning up client for user ${heldValue.userId}`); + } else { + logger.debug('[FinalizationRegistry] Cleaning up client'); + } + } catch (e) { + // Ignore errors + } + }) + : null; + +/** + * Cleans up the client object by removing references to its properties. + * This is useful for preventing memory leaks and ensuring that the client + * and its properties can be garbage collected when it is no longer needed. + */ +function disposeClient(client) { + if (!client) { + return; + } + + try { + if (client.user) { + client.user = null; + } + if (client.apiKey) { + client.apiKey = null; + } + if (client.azure) { + client.azure = null; + } + if (client.conversationId) { + client.conversationId = null; + } + if (client.responseMessageId) { + client.responseMessageId = null; + } + if (client.message_file_map) { + client.message_file_map = null; + } + if (client.clientName) { + client.clientName = null; + } + if (client.sender) { + client.sender = null; + } + if (client.model) { + client.model = null; + } + if (client.maxContextTokens) { + client.maxContextTokens = null; + } + if (client.contextStrategy) { + client.contextStrategy = null; + } + if (client.currentDateString) { + client.currentDateString = null; + } + if (client.inputTokensKey) { + client.inputTokensKey = null; + } + if (client.outputTokensKey) { + client.outputTokensKey = null; + } + if (client.skipSaveUserMessage !== undefined) { + client.skipSaveUserMessage = null; + } + if (client.visionMode) { + client.visionMode = null; + } + if (client.continued !== undefined) { + client.continued = null; + } + if (client.fetchedConvo !== undefined) { + client.fetchedConvo = null; + } + if (client.previous_summary) { + client.previous_summary = null; + } + if (client.metadata) { + client.metadata = null; + } + if (client.isVisionModel) { + client.isVisionModel = null; + } + if (client.isChatCompletion !== undefined) { + client.isChatCompletion = null; + } + if (client.contextHandlers) { + client.contextHandlers = null; + } + if (client.augmentedPrompt) { + client.augmentedPrompt = null; + } + if (client.systemMessage) { + client.systemMessage = null; + } + if (client.azureEndpoint) { + client.azureEndpoint = null; + } + if (client.langchainProxy) { + client.langchainProxy = null; + } + if (client.isOmni !== undefined) { + client.isOmni = null; + } + if (client.runManager) { + client.runManager = null; + } + // Properties specific to AnthropicClient + if (client.message_start) { + client.message_start = null; + } + if (client.message_delta) { + client.message_delta = null; + } + if (client.isClaude3 !== undefined) { + client.isClaude3 = null; + } + if (client.useMessages !== undefined) { + client.useMessages = null; + } + if (client.isLegacyOutput !== undefined) { + client.isLegacyOutput = null; + } + if (client.supportsCacheControl !== undefined) { + client.supportsCacheControl = null; + } + // Properties specific to GoogleClient + if (client.serviceKey) { + client.serviceKey = null; + } + if (client.project_id) { + 
client.project_id = null; + } + if (client.client_email) { + client.client_email = null; + } + if (client.private_key) { + client.private_key = null; + } + if (client.access_token) { + client.access_token = null; + } + if (client.reverseProxyUrl) { + client.reverseProxyUrl = null; + } + if (client.authHeader) { + client.authHeader = null; + } + if (client.isGenerativeModel !== undefined) { + client.isGenerativeModel = null; + } + // Properties specific to OpenAIClient + if (client.ChatGPTClient) { + client.ChatGPTClient = null; + } + if (client.completionsUrl) { + client.completionsUrl = null; + } + if (client.shouldSummarize !== undefined) { + client.shouldSummarize = null; + } + if (client.isOllama !== undefined) { + client.isOllama = null; + } + if (client.FORCE_PROMPT !== undefined) { + client.FORCE_PROMPT = null; + } + if (client.isChatGptModel !== undefined) { + client.isChatGptModel = null; + } + if (client.isUnofficialChatGptModel !== undefined) { + client.isUnofficialChatGptModel = null; + } + if (client.useOpenRouter !== undefined) { + client.useOpenRouter = null; + } + if (client.startToken) { + client.startToken = null; + } + if (client.endToken) { + client.endToken = null; + } + if (client.userLabel) { + client.userLabel = null; + } + if (client.chatGptLabel) { + client.chatGptLabel = null; + } + if (client.modelLabel) { + client.modelLabel = null; + } + if (client.modelOptions) { + client.modelOptions = null; + } + if (client.defaultVisionModel) { + client.defaultVisionModel = null; + } + if (client.maxPromptTokens) { + client.maxPromptTokens = null; + } + if (client.maxResponseTokens) { + client.maxResponseTokens = null; + } + if (client.run) { + // Break circular references in run + if (client.run.Graph) { + client.run.Graph.resetValues(); + client.run.Graph.handlerRegistry = null; + client.run.Graph.runId = null; + client.run.Graph.tools = null; + client.run.Graph.signal = null; + client.run.Graph.config = null; + client.run.Graph.toolEnd = null; + client.run.Graph.toolMap = null; + client.run.Graph.provider = null; + client.run.Graph.streamBuffer = null; + client.run.Graph.clientOptions = null; + client.run.Graph.graphState = null; + if (client.run.Graph.boundModel?.client) { + client.run.Graph.boundModel.client = null; + } + client.run.Graph.boundModel = null; + client.run.Graph.systemMessage = null; + client.run.Graph.reasoningKey = null; + client.run.Graph.messages = null; + client.run.Graph.contentData = null; + client.run.Graph.stepKeyIds = null; + client.run.Graph.contentIndexMap = null; + client.run.Graph.toolCallStepIds = null; + client.run.Graph.messageIdsByStepKey = null; + client.run.Graph.messageStepHasToolCalls = null; + client.run.Graph.prelimMessageIdsByStepKey = null; + client.run.Graph.currentTokenType = null; + client.run.Graph.lastToken = null; + client.run.Graph.tokenTypeSwitch = null; + client.run.Graph.indexTokenCountMap = null; + client.run.Graph.currentUsage = null; + client.run.Graph.tokenCounter = null; + client.run.Graph.maxContextTokens = null; + client.run.Graph.pruneMessages = null; + client.run.Graph.lastStreamCall = null; + client.run.Graph.startIndex = null; + client.run.Graph = null; + } + if (client.run.handlerRegistry) { + client.run.handlerRegistry = null; + } + if (client.run.graphRunnable) { + if (client.run.graphRunnable.channels) { + client.run.graphRunnable.channels = null; + } + if (client.run.graphRunnable.nodes) { + client.run.graphRunnable.nodes = null; + } + if (client.run.graphRunnable.lc_kwargs) { + 
client.run.graphRunnable.lc_kwargs = null; + } + if (client.run.graphRunnable.builder?.nodes) { + client.run.graphRunnable.builder.nodes = null; + client.run.graphRunnable.builder = null; + } + client.run.graphRunnable = null; + } + client.run = null; + } + if (client.sendMessage) { + client.sendMessage = null; + } + if (client.savedMessageIds) { + client.savedMessageIds.clear(); + client.savedMessageIds = null; + } + if (client.currentMessages) { + client.currentMessages = null; + } + if (client.streamHandler) { + client.streamHandler = null; + } + if (client.contentParts) { + client.contentParts = null; + } + if (client.abortController) { + client.abortController = null; + } + if (client.collectedUsage) { + client.collectedUsage = null; + } + if (client.indexTokenCountMap) { + client.indexTokenCountMap = null; + } + if (client.agentConfigs) { + client.agentConfigs = null; + } + if (client.artifactPromises) { + client.artifactPromises = null; + } + if (client.usage) { + client.usage = null; + } + if (typeof client.dispose === 'function') { + client.dispose(); + } + if (client.options) { + if (client.options.req) { + client.options.req = null; + } + if (client.options.res) { + client.options.res = null; + } + if (client.options.attachments) { + client.options.attachments = null; + } + if (client.options.agent) { + client.options.agent = null; + } + } + client.options = null; + } catch (e) { + // Ignore errors during disposal + } +} + +function processReqData(data = {}, context) { + let { + abortKey, + userMessage, + userMessagePromise, + responseMessageId, + promptTokens, + conversationId, + userMessageId, + } = context; + for (const key in data) { + if (key === 'userMessage') { + userMessage = data[key]; + userMessageId = data[key].messageId; + } else if (key === 'userMessagePromise') { + userMessagePromise = data[key]; + } else if (key === 'responseMessageId') { + responseMessageId = data[key]; + } else if (key === 'promptTokens') { + promptTokens = data[key]; + } else if (key === 'abortKey') { + abortKey = data[key]; + } else if (!conversationId && key === 'conversationId') { + conversationId = data[key]; + } + } + return { + abortKey, + userMessage, + userMessagePromise, + responseMessageId, + promptTokens, + conversationId, + userMessageId, + }; +} + +module.exports = { + disposeClient, + requestDataMap, + clientRegistry, + processReqData, +}; diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index 55fe2fa717..40b209ef35 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -1,5 +1,15 @@ const { getResponseSender, Constants } = require('librechat-data-provider'); -const { createAbortController, handleAbortError } = require('~/server/middleware'); +const { + handleAbortError, + createAbortController, + cleanupAbortController, +} = require('~/server/middleware'); +const { + disposeClient, + processReqData, + clientRegistry, + requestDataMap, +} = require('~/server/cleanup'); const { sendMessage, createOnProgress } = require('~/server/utils'); const { saveMessage } = require('~/models'); const { logger } = require('~/config'); @@ -14,90 +24,162 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { overrideParentMessageId = null, } = req.body; + let client = null; + let abortKey = null; + let cleanupHandlers = []; + let clientRef = null; + logger.debug('[AskController]', { text, conversationId, ...endpointOption, - modelsConfig: endpointOption.modelsConfig ? 
'exists' : '', + modelsConfig: endpointOption?.modelsConfig ? 'exists' : '', }); - let userMessage; - let userMessagePromise; - let promptTokens; - let userMessageId; - let responseMessageId; + let userMessage = null; + let userMessagePromise = null; + let promptTokens = null; + let userMessageId = null; + let responseMessageId = null; + let getAbortData = null; + const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model, modelDisplayLabel, }); - const newConvo = !conversationId; - const user = req.user.id; + const initialConversationId = conversationId; + const newConvo = !initialConversationId; + const userId = req.user.id; - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - userMessageId = data[key].messageId; - } else if (key === 'userMessagePromise') { - userMessagePromise = data[key]; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } else if (!conversationId && key === 'conversationId') { - conversationId = data[key]; - } - } + let reqDataContext = { + userMessage, + userMessagePromise, + responseMessageId, + promptTokens, + conversationId, + userMessageId, }; - let getText; + const updateReqData = (data = {}) => { + reqDataContext = processReqData(data, reqDataContext); + abortKey = reqDataContext.abortKey; + userMessage = reqDataContext.userMessage; + userMessagePromise = reqDataContext.userMessagePromise; + responseMessageId = reqDataContext.responseMessageId; + promptTokens = reqDataContext.promptTokens; + conversationId = reqDataContext.conversationId; + userMessageId = reqDataContext.userMessageId; + }; + + let { onProgress: progressCallback, getPartialText } = createOnProgress(); + + const performCleanup = () => { + logger.debug('[AskController] Performing cleanup'); + if (Array.isArray(cleanupHandlers)) { + for (const handler of cleanupHandlers) { + try { + if (typeof handler === 'function') { + handler(); + } + } catch (e) { + // Ignore + } + } + } + + if (abortKey) { + logger.debug('[AskController] Cleaning up abort controller'); + cleanupAbortController(abortKey); + abortKey = null; + } + + if (client) { + disposeClient(client); + client = null; + } + + reqDataContext = null; + userMessage = null; + userMessagePromise = null; + promptTokens = null; + getAbortData = null; + progressCallback = null; + endpointOption = null; + cleanupHandlers = null; + addTitle = null; + + if (requestDataMap.has(req)) { + requestDataMap.delete(req); + } + logger.debug('[AskController] Cleanup completed'); + }; try { - const { client } = await initializeClient({ req, res, endpointOption }); - const { onProgress: progressCallback, getPartialText } = createOnProgress(); + ({ client } = await initializeClient({ req, res, endpointOption })); + if (clientRegistry && client) { + clientRegistry.register(client, { userId }, client); + } - getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText; + if (client) { + requestDataMap.set(req, { client }); + } - const getAbortData = () => ({ - sender, - conversationId, - userMessagePromise, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? 
userMessageId,
-      text: getText(),
-      userMessage,
-      promptTokens,
-    });
+    clientRef = new WeakRef(client);
 
-    const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
+    getAbortData = () => {
+      const currentClient = clientRef?.deref();
+      const currentText =
+        currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
 
-    res.on('close', () => {
+      return {
+        sender,
+        conversationId,
+        messageId: reqDataContext.responseMessageId,
+        parentMessageId: overrideParentMessageId ?? userMessageId,
+        text: currentText,
+        userMessage: userMessage,
+        userMessagePromise: userMessagePromise,
+        promptTokens: reqDataContext.promptTokens,
+      };
+    };
+
+    const { onStart, abortController } = createAbortController(
+      req,
+      res,
+      getAbortData,
+      updateReqData,
+    );
+
+    const closeHandler = () => {
       logger.debug('[AskController] Request closed');
-      if (!abortController) {
-        return;
-      } else if (abortController.signal.aborted) {
-        return;
-      } else if (abortController.requestCompleted) {
+      if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
         return;
       }
       abortController.abort();
       logger.debug('[AskController] Request aborted on close');
+    };
+
+    res.on('close', closeHandler);
+    cleanupHandlers.push(() => {
+      try {
+        res.removeListener('close', closeHandler);
+      } catch (e) {
+        // Ignore
+      }
     });
 
     const messageOptions = {
-      user,
+      user: userId,
       parentMessageId,
-      conversationId,
+      conversationId: reqDataContext.conversationId,
       overrideParentMessageId,
-      getReqData,
+      getReqData: updateReqData,
       onStart,
       abortController,
       progressCallback,
       progressOptions: {
         res,
-        // parentMessageId: overrideParentMessageId || userMessageId,
       },
     };
 
@@ -105,59 +187,95 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
 
     let response = await client.sendMessage(text, messageOptions);
     response.endpoint = endpointOption.endpoint;
 
-    const { conversation = {} } = await client.responsePromise;
+    const databasePromise = response.databasePromise;
+    delete response.databasePromise;
+
+    const { conversation: convoData = {} } = await databasePromise;
+    const conversation = { ...convoData };
     conversation.title = conversation && !conversation.title ?
 
-    if (client.options.attachments) {
-      userMessage.files = client.options.attachments;
-      conversation.model = endpointOption.modelOptions.model;
-      delete userMessage.image_urls;
+    const latestUserMessage = reqDataContext.userMessage;
+
+    if (client?.options?.attachments && latestUserMessage) {
+      latestUserMessage.files = client.options.attachments;
+      if (endpointOption?.modelOptions?.model) {
+        conversation.model = endpointOption.modelOptions.model;
+      }
+      delete latestUserMessage.image_urls;
     }
 
     if (!abortController.signal.aborted) {
+      const finalResponseMessage = { ...response };
+
       sendMessage(res, {
         final: true,
         conversation,
         title: conversation.title,
-        requestMessage: userMessage,
-        responseMessage: response,
+        requestMessage: latestUserMessage,
+        responseMessage: finalResponseMessage,
       });
       res.end();
 
-      if (!client.savedMessageIds.has(response.messageId)) {
+      if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
         await saveMessage(
           req,
-          { ...response, user },
+          { ...finalResponseMessage, user: userId },
           { context: 'api/server/controllers/AskController.js - response end' },
         );
       }
     }
 
-    if (!client.skipSaveUserMessage) {
-      await saveMessage(req, userMessage, {
-        context: 'api/server/controllers/AskController.js - don\'t skip saving user message',
+    if (!client?.skipSaveUserMessage && latestUserMessage) {
+      await saveMessage(req, latestUserMessage, {
+        context: "api/server/controllers/AskController.js - don't skip saving user message",
       });
     }
 
-    if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
+    if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
       addTitle(req, {
         text,
-        response,
+        response: { ...response },
         client,
-      });
+      })
+        .then(() => {
+          logger.debug('[AskController] Title generation started');
+        })
+        .catch((err) => {
+          logger.error('[AskController] Error in title generation', err);
+        })
+        .finally(() => {
+          logger.debug('[AskController] Title generation completed');
+          performCleanup();
+        });
+    } else {
+      performCleanup();
     }
   } catch (error) {
-    const partialText = getText && getText();
+    logger.error('[AskController] Error handling request', error);
+    let partialText = '';
+    try {
+      const currentClient = clientRef?.deref();
+      partialText =
+        currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
+    } catch (getTextError) {
+      logger.error('[AskController] Error calling getText() during error handling', getTextError);
+    }
+
     handleAbortError(res, req, error, {
-      partialText,
-      conversationId,
       sender,
-      messageId: responseMessageId,
-      parentMessageId: userMessageId ?? parentMessageId,
-    }).catch((err) => {
-      logger.error('[AskController] Error in `handleAbortError`', err);
-    });
+      partialText,
+      conversationId: reqDataContext.conversationId,
+      messageId: reqDataContext.responseMessageId,
+      parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
+      userMessageId: reqDataContext.userMessageId,
+    })
+      .catch((err) => {
+        logger.error('[AskController] Error in `handleAbortError` during catch block', err);
+      })
+      .finally(() => {
+        performCleanup();
+      });
   }
 };
diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js
index 71551ea867..7cdfaa9aaf 100644
--- a/api/server/controllers/AuthController.js
+++ b/api/server/controllers/AuthController.js
@@ -61,7 +61,7 @@ const refreshController = async (req, res) => {
 
   try {
     const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
-    const user = await getUserById(payload.id, '-password -__v');
+    const user = await getUserById(payload.id, '-password -__v -totpSecret');
     if (!user) {
       return res.status(401).redirect('/login');
     }
diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js
index 2a2f8c28de..d142d474df 100644
--- a/api/server/controllers/EditController.js
+++ b/api/server/controllers/EditController.js
@@ -1,5 +1,15 @@
 const { getResponseSender } = require('librechat-data-provider');
-const { createAbortController, handleAbortError } = require('~/server/middleware');
+const {
+  handleAbortError,
+  createAbortController,
+  cleanupAbortController,
+} = require('~/server/middleware');
+const {
+  disposeClient,
+  processReqData,
+  clientRegistry,
+  requestDataMap,
+} = require('~/server/cleanup');
 const { sendMessage, createOnProgress } = require('~/server/utils');
 const { saveMessage } = require('~/models');
 const { logger } = require('~/config');
@@ -17,6 +27,11 @@ const EditController = async (req, res, next, initializeClient) => {
     overrideParentMessageId = null,
   } = req.body;
 
+  let client = null;
+  let abortKey = null;
+  let cleanupHandlers = [];
+  let clientRef = null; // Will hold a WeakRef to the client once initialized
+
   logger.debug('[EditController]', {
     text,
     generation,
@@ -26,123 +41,205 @@ const EditController = async (req, res, next, initializeClient) => {
     modelsConfig: endpointOption.modelsConfig ? 'exists' : '',
   });
 
-  let userMessage;
-  let userMessagePromise;
-  let promptTokens;
+  let userMessage = null;
+  let userMessagePromise = null;
+  let promptTokens = null;
+  let getAbortData = null;
+
   const sender = getResponseSender({
     ...endpointOption,
     model: endpointOption.modelOptions.model,
     modelDisplayLabel,
   });
   const userMessageId = parentMessageId;
-  const user = req.user.id;
+  const userId = req.user.id;
 
-  const getReqData = (data = {}) => {
-    for (let key in data) {
-      if (key === 'userMessage') {
-        userMessage = data[key];
-      } else if (key === 'userMessagePromise') {
-        userMessagePromise = data[key];
-      } else if (key === 'responseMessageId') {
-        responseMessageId = data[key];
-      } else if (key === 'promptTokens') {
-        promptTokens = data[key];
-      }
-    }
+  let reqDataContext = { userMessage, userMessagePromise, responseMessageId, promptTokens };
+
+  const updateReqData = (data = {}) => {
+    reqDataContext = processReqData(data, reqDataContext);
+    abortKey = reqDataContext.abortKey;
+    userMessage = reqDataContext.userMessage;
+    userMessagePromise = reqDataContext.userMessagePromise;
+    responseMessageId = reqDataContext.responseMessageId;
+    promptTokens = reqDataContext.promptTokens;
   };
 
-  const { onProgress: progressCallback, getPartialText } = createOnProgress({
+  let { onProgress: progressCallback, getPartialText } = createOnProgress({
     generation,
   });
 
-  let getText;
+  const performCleanup = () => {
+    logger.debug('[EditController] Performing cleanup');
+    if (Array.isArray(cleanupHandlers)) {
+      for (const handler of cleanupHandlers) {
+        try {
+          if (typeof handler === 'function') {
+            handler();
+          }
+        } catch (e) {
+          // Ignore
+        }
+      }
+    }
+
+    if (abortKey) {
+      logger.debug('[EditController] Cleaning up abort controller');
+      cleanupAbortController(abortKey);
+      abortKey = null;
+    }
+
+    if (client) {
+      disposeClient(client);
+      client = null;
+    }
+
+    reqDataContext = null;
+    userMessage = null;
+    userMessagePromise = null;
+    promptTokens = null;
+    getAbortData = null;
+    progressCallback = null;
+    endpointOption = null;
+    cleanupHandlers = null;
+
+    if (requestDataMap.has(req)) {
+      requestDataMap.delete(req);
+    }
+    logger.debug('[EditController] Cleanup completed');
+  };
 
   try {
-    const { client } = await initializeClient({ req, res, endpointOption });
+    ({ client } = await initializeClient({ req, res, endpointOption }));
 
-    getText = client.getStreamText != null ? client.getStreamText.bind(client) : getPartialText;
+    if (clientRegistry && client) {
+      clientRegistry.register(client, { userId }, client);
+    }
 
-    const getAbortData = () => ({
-      conversationId,
-      userMessagePromise,
-      messageId: responseMessageId,
-      sender,
-      parentMessageId: overrideParentMessageId ?? userMessageId,
-      text: getText(),
-      userMessage,
-      promptTokens,
-    });
+    if (client) {
+      requestDataMap.set(req, { client });
+    }
 
-    const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
+    clientRef = new WeakRef(client);
 
-    res.on('close', () => {
+    getAbortData = () => {
+      const currentClient = clientRef?.deref();
+      const currentText =
+        currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
+
+      return {
+        sender,
+        conversationId,
+        messageId: reqDataContext.responseMessageId,
+        parentMessageId: overrideParentMessageId ?? userMessageId,
+        text: currentText,
+        userMessage: userMessage,
+        userMessagePromise: userMessagePromise,
+        promptTokens: reqDataContext.promptTokens,
+      };
+    };
+
+    const { onStart, abortController } = createAbortController(
+      req,
+      res,
+      getAbortData,
+      updateReqData,
+    );
+
+    const closeHandler = () => {
       logger.debug('[EditController] Request closed');
-      if (!abortController) {
-        return;
-      } else if (abortController.signal.aborted) {
-        return;
-      } else if (abortController.requestCompleted) {
+      if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
         return;
       }
       abortController.abort();
       logger.debug('[EditController] Request aborted on close');
+    };
+
+    res.on('close', closeHandler);
+    cleanupHandlers.push(() => {
+      try {
+        res.removeListener('close', closeHandler);
+      } catch (e) {
+        // Ignore
+      }
     });
 
     let response = await client.sendMessage(text, {
-      user,
+      user: userId,
       generation,
       isContinued,
       isEdited: true,
       conversationId,
       parentMessageId,
-      responseMessageId,
+      responseMessageId: reqDataContext.responseMessageId,
       overrideParentMessageId,
-      getReqData,
+      getReqData: updateReqData,
       onStart,
       abortController,
       progressCallback,
       progressOptions: {
         res,
-        // parentMessageId: overrideParentMessageId || userMessageId,
       },
     });
 
-    const { conversation = {} } = await client.responsePromise;
+    const databasePromise = response.databasePromise;
+    delete response.databasePromise;
+
+    const { conversation: convoData = {} } = await databasePromise;
+    const conversation = { ...convoData };
     conversation.title =
       conversation && !conversation.title ? null : conversation?.title || 'New Chat';
 
-    if (client.options.attachments) {
+    if (client?.options?.attachments && endpointOption?.modelOptions?.model) {
       conversation.model = endpointOption.modelOptions.model;
     }
 
     if (!abortController.signal.aborted) {
+      const finalUserMessage = reqDataContext.userMessage;
+      const finalResponseMessage = { ...response };
+
       sendMessage(res, {
         final: true,
         conversation,
         title: conversation.title,
-        requestMessage: userMessage,
-        responseMessage: response,
+        requestMessage: finalUserMessage,
+        responseMessage: finalResponseMessage,
       });
       res.end();
 
       await saveMessage(
         req,
-        { ...response, user },
+        { ...finalResponseMessage, user: userId },
         { context: 'api/server/controllers/EditController.js - response end' },
       );
     }
+
+    performCleanup();
   } catch (error) {
-    const partialText = getText();
+    logger.error('[EditController] Error handling request', error);
+    let partialText = '';
+    try {
+      const currentClient = clientRef?.deref();
+      partialText =
+        currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
+    } catch (getTextError) {
+      logger.error('[EditController] Error calling getText() during error handling', getTextError);
+    }
+
     handleAbortError(res, req, error, {
+      sender,
       partialText,
       conversationId,
-      sender,
-      messageId: responseMessageId,
-      parentMessageId: userMessageId ?? parentMessageId,
-    }).catch((err) => {
-      logger.error('[EditController] Error in `handleAbortError`', err);
-    });
+      messageId: reqDataContext.responseMessageId,
+      parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
+      userMessageId,
+    })
+      .catch((err) => {
+        logger.error('[EditController] Error in `handleAbortError` during catch block', err);
+      })
+      .finally(() => {
+        performCleanup();
+      });
   }
 };
diff --git a/api/server/controllers/ModelController.js b/api/server/controllers/ModelController.js
index 79dc81d6b0..ad120c2c83 100644
--- a/api/server/controllers/ModelController.js
+++ b/api/server/controllers/ModelController.js
@@ -1,6 +1,7 @@
 const { CacheKeys } = require('librechat-data-provider');
 const { loadDefaultModels, loadConfigModels } = require('~/server/services/Config');
 const { getLogStores } = require('~/cache');
+const { logger } = require('~/config');
 
 /**
  * @param {ServerRequest} req
@@ -36,8 +37,13 @@ async function loadModels(req) {
 }
 
 async function modelController(req, res) {
-  const modelConfig = await loadModels(req);
-  res.send(modelConfig);
+  try {
+    const modelConfig = await loadModels(req);
+    res.send(modelConfig);
+  } catch (error) {
+    logger.error('Error fetching models:', error);
+    res.status(500).send({ error: error.message });
+  }
 }
 
 module.exports = { modelController, loadModels, getModelsConfig };
diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js
index 9e87b46289..674e36002a 100644
--- a/api/server/controllers/PluginController.js
+++ b/api/server/controllers/PluginController.js
@@ -1,5 +1,5 @@
 const { CacheKeys, AuthType } = require('librechat-data-provider');
-const { addOpenAPISpecs } = require('~/app/clients/tools/util/addOpenAPISpecs');
+const { getToolkitKey } = require('~/server/services/ToolService');
 const { getCustomConfig } = require('~/server/services/Config');
 const { availableTools } = require('~/app/clients/tools');
 const { getMCPManager } = require('~/config');
@@ -69,7 +69,7 @@ const getAvailablePluginsController = async (req, res) => {
       );
     }
 
-    let plugins = await addOpenAPISpecs(authenticatedPlugins);
+    let plugins = authenticatedPlugins;
 
     if (includedTools.length > 0) {
       plugins = plugins.filter((plugin) => includedTools.includes(plugin.pluginKey));
@@ -105,11 +105,11 @@ const getAvailableTools = async (req, res) => {
       return;
     }
 
-    const pluginManifest = availableTools;
+    let pluginManifest = availableTools;
 
    const customConfig = await getCustomConfig();
    if (customConfig?.mcpServers != null) {
-      const mcpManager = await getMCPManager();
-      await mcpManager.loadManifestTools(pluginManifest);
+      const mcpManager = getMCPManager();
+      pluginManifest = await mcpManager.loadManifestTools(pluginManifest);
    }
 
    /** @type {TPlugin[]} */
@@ -128,7 +128,7 @@ const getAvailableTools = async (req, res) => {
       (plugin) =>
         toolDefinitions[plugin.pluginKey] !== undefined ||
         (plugin.toolkit === true &&
-          Object.keys(toolDefinitions).some((key) => key.startsWith(`${plugin.pluginKey}_`))),
+          Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey)),
     );
 
     await cache.set(CacheKeys.TOOLS, tools);
diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js
new file mode 100644
index 0000000000..f5783f45ad
--- /dev/null
+++ b/api/server/controllers/TwoFactorController.js
@@ -0,0 +1,138 @@
+const {
+  generateTOTPSecret,
+  generateBackupCodes,
+  verifyTOTP,
+  verifyBackupCode,
+  getTOTPSecret,
+} = require('~/server/services/twoFactorService');
+const { updateUser, getUserById } = require('~/models');
+const { logger } = require('~/config');
+const { encryptV3 } = require('~/server/utils/crypto');
+
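+// Whitespace is stripped so the title is safe to use as the otpauth issuer/label.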
+const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, '');
+
+/**
+ * Enable 2FA for the user by generating a new TOTP secret and backup codes.
+ * The secret is encrypted and stored, and 2FA is marked as disabled until confirmed.
+ */
+const enable2FA = async (req, res) => {
+  try {
+    const userId = req.user.id;
+    const secret = generateTOTPSecret();
+    const { plainCodes, codeObjects } = await generateBackupCodes();
+
+    // Encrypt the secret with v3 encryption before saving.
+    const encryptedSecret = encryptV3(secret);
+
+    // Update the user record: store the secret & backup codes and set twoFactorEnabled to false.
+    const user = await updateUser(userId, {
+      totpSecret: encryptedSecret,
+      backupCodes: codeObjects,
+      twoFactorEnabled: false,
+    });
+
+    const otpauthUrl = `otpauth://totp/${safeAppTitle}:${user.email}?secret=${secret}&issuer=${safeAppTitle}`;
+
+    return res.status(200).json({ otpauthUrl, backupCodes: plainCodes });
+  } catch (err) {
+    logger.error('[enable2FA]', err);
+    return res.status(500).json({ message: err.message });
+  }
+};
+
+/**
+ * Verify a 2FA code (either TOTP or backup code) during setup.
+ */
+const verify2FA = async (req, res) => {
+  try {
+    const userId = req.user.id;
+    const { token, backupCode } = req.body;
+    const user = await getUserById(userId);
+
+    if (!user || !user.totpSecret) {
+      return res.status(400).json({ message: '2FA not initiated' });
+    }
+
+    const secret = await getTOTPSecret(user.totpSecret);
+    let isVerified = false;
+
+    if (token) {
+      isVerified = await verifyTOTP(secret, token);
+    } else if (backupCode) {
+      isVerified = await verifyBackupCode({ user, backupCode });
+    }
+
+    if (isVerified) {
+      return res.status(200).json();
+    }
+    return res.status(400).json({ message: 'Invalid token or backup code.' });
+  } catch (err) {
+    logger.error('[verify2FA]', err);
+    return res.status(500).json({ message: err.message });
+  }
+};
+
+/**
+ * Confirm and enable 2FA after a successful verification.
+ */
+const confirm2FA = async (req, res) => {
+  try {
+    const userId = req.user.id;
+    const { token } = req.body;
+    const user = await getUserById(userId);
+
+    if (!user || !user.totpSecret) {
+      return res.status(400).json({ message: '2FA not initiated' });
+    }
+
+    const secret = await getTOTPSecret(user.totpSecret);
+    if (await verifyTOTP(secret, token)) {
+      await updateUser(userId, { twoFactorEnabled: true });
+      return res.status(200).json();
+    }
+    return res.status(400).json({ message: 'Invalid token.' });
+  } catch (err) {
+    logger.error('[confirm2FA]', err);
+    return res.status(500).json({ message: err.message });
+  }
+};
+
+/**
+ * Disable 2FA by clearing the stored secret and backup codes.
+ */
+const disable2FA = async (req, res) => {
+  try {
+    const userId = req.user.id;
+    await updateUser(userId, { totpSecret: null, backupCodes: [], twoFactorEnabled: false });
+    return res.status(200).json();
+  } catch (err) {
+    logger.error('[disable2FA]', err);
+    return res.status(500).json({ message: err.message });
+  }
+};
+
+/**
+ * Regenerate backup codes for the user.
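+ * Persists only the hashed code objects; the plaintext codes are returned to the client once.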
+ */
+const regenerateBackupCodes = async (req, res) => {
+  try {
+    const userId = req.user.id;
+    const { plainCodes, codeObjects } = await generateBackupCodes();
+    await updateUser(userId, { backupCodes: codeObjects });
+    return res.status(200).json({
+      backupCodes: plainCodes,
+      backupCodesHash: codeObjects,
+    });
+  } catch (err) {
+    logger.error('[regenerateBackupCodes]', err);
+    return res.status(500).json({ message: err.message });
+  }
+};
+
+module.exports = {
+  enable2FA,
+  verify2FA,
+  confirm2FA,
+  disable2FA,
+  regenerateBackupCodes,
+};
diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js
index 17089e8fdc..1ed2c4741d 100644
--- a/api/server/controllers/UserController.js
+++ b/api/server/controllers/UserController.js
@@ -1,6 +1,8 @@
+const { FileSources } = require('librechat-data-provider');
 const {
   Balance,
   getFiles,
+  updateUser,
   deleteFiles,
   deleteConvos,
   deletePresets,
@@ -12,6 +14,7 @@ const User = require('~/models/User');
 const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService');
 const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService');
 const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService');
+const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud');
 const { processDeleteRequest } = require('~/server/services/Files/process');
 const { deleteAllSharedLinks } = require('~/models/Share');
 const { deleteToolCalls } = require('~/models/ToolCall');
@@ -19,7 +22,24 @@ const { Transaction } = require('~/models/Transaction');
 const { logger } = require('~/config');
 
 const getUserController = async (req, res) => {
-  res.status(200).send(req.user);
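+  // Send a sanitized plain copy of the user; never expose the TOTP secret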
+  /** @type {MongoUser} */
+  const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
+  delete userData.totpSecret;
+  if (req.app.locals.fileStrategy === FileSources.s3 && userData.avatar) {
+    const avatarNeedsRefresh = needsRefresh(userData.avatar, 3600);
+    if (!avatarNeedsRefresh) {
+      return res.status(200).send(userData);
+    }
+    const originalAvatar = userData.avatar;
+    try {
+      userData.avatar = await getNewS3URL(userData.avatar);
+      await updateUser(userData.id, { avatar: userData.avatar });
+    } catch (error) {
+      userData.avatar = originalAvatar;
+      logger.error('Error getting new S3 URL for avatar:', error);
+    }
+  }
+  res.status(200).send(userData);
 };
 
 const getTermsStatusController = async (req, res) => {
diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js
index 33fe585f42..3f507f7d0b 100644
--- a/api/server/controllers/agents/callbacks.js
+++ b/api/server/controllers/agents/callbacks.js
@@ -1,4 +1,5 @@
-const { Tools, StepTypes, imageGenTools, FileContext } = require('librechat-data-provider');
+const { nanoid } = require('nanoid');
+const { Tools, StepTypes, FileContext } = require('librechat-data-provider');
 const {
   EnvVar,
   Providers,
@@ -9,19 +10,10 @@ const {
   ChatModelStreamHandler,
 } = require('@librechat/agents');
 const { processCodeOutput } = require('~/server/services/Files/Code/process');
+const { loadAuthValues } = require('~/server/services/Tools/credentials');
 const { saveBase64Image } = require('~/server/services/Files/process');
-const { loadAuthValues } = require('~/app/clients/tools/util');
 const { logger, sendEvent } = require('~/config');
 
-/** @typedef {import('@librechat/agents').Graph} Graph */
-/** @typedef {import('@librechat/agents').EventHandler} EventHandler */
-/** @typedef {import('@librechat/agents').ModelEndData} ModelEndData */
-/** @typedef {import('@librechat/agents').ToolEndData} ToolEndData */
-/** @typedef {import('@librechat/agents').ToolEndCallback} ToolEndCallback */
-/** @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler */
-/** @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator */
-/** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */
-
 class ModelEndHandler {
   /**
   * @param {Array<UsageMetadata>} collectedUsage
   */
@@ -37,7 +29,7 @@ class ModelEndHandler {
   * @param {string} event
   * @param {ModelEndData | undefined} data
   * @param {Record<string, unknown> | undefined} metadata
-   * @param {Graph} graph
+   * @param {StandardGraph} graph
   * @returns
   */
  handle(event, data, metadata, graph) {
@@ -60,7 +52,10 @@ class ModelEndHandler {
       }
 
       this.collectedUsage.push(usage);
-      if (!graph.clientOptions?.disableStreaming) {
+      const streamingDisabled = !!(
+        graph.clientOptions?.disableStreaming || graph?.boundModel?.disableStreaming
+      );
+      if (!streamingDisabled) {
         return;
       }
       if (!data.output.content) {
@@ -199,6 +194,22 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU
         aggregateContent({ event, data });
       },
     },
+    [GraphEvents.ON_REASONING_DELTA]: {
+      /**
+       * Handle ON_REASONING_DELTA event.
+       * @param {string} event - The event name.
+       * @param {StreamEventData} data - The event data.
+       * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata.
+       */
+      handle: (event, data, metadata) => {
+        if (metadata?.last_agent_index === metadata?.agent_index) {
+          sendEvent(res, { event, data });
+        } else if (!metadata?.hide_sequential_outputs) {
+          sendEvent(res, { event, data });
+        }
+        aggregateContent({ event, data });
+      },
+    },
   };
 
   return handlers;
 }
@@ -226,45 +237,25 @@ function createToolEndCallback({ req, res, artifactPromises }) {
         return;
       }
 
-      if (imageGenTools.has(output.name)) {
-        artifactPromises.push(
-          (async () => {
-            const fileMetadata = Object.assign(output.artifact, {
-              messageId: metadata.run_id,
-              toolCallId: output.tool_call_id,
-              conversationId: metadata.thread_id,
-            });
-            if (!res.headersSent) {
-              return fileMetadata;
-            }
-
-            if (!fileMetadata) {
-              return null;
-            }
-
-            res.write(`event: attachment\ndata: ${JSON.stringify(fileMetadata)}\n\n`);
-            return fileMetadata;
-          })().catch((error) => {
-            logger.error('Error processing code output:', error);
-            return null;
-          }),
-        );
-        return;
-      }
-
       if (output.artifact.content) {
         /** @type {FormattedContent[]} */
         const content = output.artifact.content;
-        for (const part of content) {
+        for (let i = 0; i < content.length; i++) {
+          const part = content[i];
+          if (!part) {
+            continue;
+          }
           if (part.type !== 'image_url') {
             continue;
           }
           const { url } = part.image_url;
           artifactPromises.push(
             (async () => {
-              const filename = `${output.tool_call_id}-image-${new Date().getTime()}`;
+              const filename = `${output.name}_${output.tool_call_id}_img_${nanoid()}`;
+              const file_id = output.artifact.file_ids?.[i];
               const file = await saveBase64Image(url, {
                 req,
+                file_id,
                 filename,
                 endpoint: metadata.provider,
                 context: FileContext.image_generation,
diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js
index a8e9ad82f7..a3484f6505 100644
--- a/api/server/controllers/agents/client.js
+++ b/api/server/controllers/agents/client.js
@@ -7,56 +7,78 @@
 //   validateVisionModel,
 //   mapModelToAzureConfig,
 // } = require('librechat-data-provider');
-const { Callback, createMetadataAggregator } = require('@librechat/agents');
-const {
-  Constants,
-  VisionModes,
-  openAISchema,
-  ContentTypes,
-  EModelEndpoint,
-  KnownEndpoints,
-  anthropicSchema,
-  isAgentsEndpoint,
-  bedrockOutputParser,
-  removeNullishValues,
-} = require('librechat-data-provider');
-const {
-  extractBaseURL,
-  // constructAzureURL,
-  // genAzureChatCompletion,
-} = require('~/utils');
+require('events').EventEmitter.defaultMaxListeners = 100;
 const {
+  Callback,
+  GraphEvents,
   formatMessage,
   formatAgentMessages,
   formatContentStrings,
-  createContextHandlers,
-} = require('~/app/clients/prompts');
-const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+  getTokenCountForMessage,
+  createMetadataAggregator,
+} = require('@librechat/agents');
+const {
+  Constants,
+  VisionModes,
+  ContentTypes,
+  EModelEndpoint,
+  KnownEndpoints,
+  isAgentsEndpoint,
+  AgentCapabilities,
+  bedrockInputSchema,
+  removeNullishValues,
+} = require('librechat-data-provider');
+const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config');
+const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
+const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
 const { getBufferString, HumanMessage } = require('@langchain/core/messages');
+const { encodeAndFormat } = require('~/server/services/Files/images/encode');
+const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
 const Tokenizer = require('~/server/services/Tokenizer');
-const { spendTokens } = require('~/models/spendTokens');
 const BaseClient = require('~/app/clients/BaseClient');
+const { logger, sendEvent } = require('~/config');
 const { createRun } = require('./run');
-const { logger } = require('~/config');
 
 /** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */
 /** @typedef {import('@langchain/core/runnables').RunnableConfig} RunnableConfig */
 
-const providerParsers = {
-  [EModelEndpoint.openAI]: openAISchema,
-  [EModelEndpoint.azureOpenAI]: openAISchema,
-  [EModelEndpoint.anthropic]: anthropicSchema,
-  [EModelEndpoint.bedrock]: bedrockOutputParser,
+/**
+ * @param {ServerRequest} req
+ * @param {Agent} agent
+ * @param {string} endpoint
+ */
+const payloadParser = ({ req, agent, endpoint }) => {
+  if (isAgentsEndpoint(endpoint)) {
+    return { model: undefined };
+  } else if (endpoint === EModelEndpoint.bedrock) {
+    return bedrockInputSchema.parse(agent.model_parameters);
+  }
+  return req.body.endpointOption.model_parameters;
 };
 
 const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deepseek]);
 
-const noSystemModelRegex = [/\bo1\b/gi];
+const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi];
 
 // const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory');
 // const { getFormattedMemories } = require('~/models/Memory');
 // const { getCurrentDateTime } = require('~/utils');
 
+function createTokenCounter(encoding) {
+  return (message) => {
+    const countTokens = (text) => Tokenizer.getTokenCount(text, encoding);
+    return getTokenCountForMessage(message, countTokens);
+  };
+}
+
+function logToolError(graph, error, toolId) {
+  logger.error(
+    '[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
+    error,
+    toolId,
+  );
+}
+
 class AgentClient extends BaseClient {
   constructor(options = {}) {
     super(null, options);
@@ -102,6 +124,8 @@ class AgentClient extends BaseClient {
     this.outputTokensKey = 'output_tokens';
     /** @type {UsageMetadata} */
     this.usage;
+    /** @type {Record<string, number>} */
+    this.indexTokenCountMap = {};
   }
 
   /**
@@ -124,19 +148,13 @@ class AgentClient extends BaseClient {
   * @param {MongoFile[]} attachments
   */
  checkVisionRequest(attachments) {
-    logger.info(
-      '[api/server/controllers/agents/client.js #checkVisionRequest] not implemented',
-      attachments,
-    );
     // if (!attachments) {
     //   return;
     // }
-
     // const availableModels = this.options.modelsConfig?.[this.options.endpoint];
     // if (!availableModels) {
     //   return;
     // }
-
     // let visionRequestDetected = false;
     // for (const file of attachments) {
     //   if (file?.type?.includes('image')) {
@@ -147,13 +165,11 @@ class AgentClient extends BaseClient {
     // if (!visionRequestDetected) {
     //   return;
     // }
-
     // this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels });
     // if (this.isVisionModel) {
     //   delete this.modelOptions.stop;
     //   return;
     // }
-
     // for (const model of availableModels) {
     //   if (!validateVisionModel({ model, availableModels })) {
     //     continue;
@@ -163,35 +179,31 @@ class AgentClient extends BaseClient {
     // if (!availableModels.includes(this.defaultVisionModel)) {
     //   return;
     // }
     // if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) {
     //   return;
     // }
-
     // this.modelOptions.model = this.defaultVisionModel;
     // this.isVisionModel = true;
     // delete this.modelOptions.stop;
   }
 
   getSaveOptions() {
-    const parseOptions = providerParsers[this.options.endpoint];
-    let runOptions =
-      this.options.endpoint === EModelEndpoint.agents
-        ? {
-            model: undefined,
-            // TODO:
-            // would need to be override settings; otherwise, model needs to be undefined
-            // model: this.override.model,
-            // instructions: this.override.instructions,
-            // additional_instructions: this.override.additional_instructions,
-          }
-        : {};
-
-    if (parseOptions) {
-      runOptions = parseOptions(this.options.agent.model_parameters);
+    // TODO:
+    // would need to be override settings; otherwise, model needs to be undefined
+    // model: this.override.model,
+    // instructions: this.override.instructions,
+    // additional_instructions: this.override.additional_instructions,
+    let runOptions = {};
+    try {
+      runOptions = payloadParser(this.options);
+    } catch (error) {
+      logger.error(
+        '[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options',
+        error,
+      );
     }
 
     return removeNullishValues(
@@ -219,14 +231,23 @@ class AgentClient extends BaseClient {
     };
   }
 
+  /**
+   *
+   * @param {TMessage} message
+   * @param {Array<MongoFile>} attachments
+   * @returns {Promise<Array<MongoFile>>}
+   */
   async addImageURLs(message, attachments) {
-    const { files, image_urls } = await encodeAndFormat(
+    const { files, text, image_urls } = await encodeAndFormat(
       this.options.req,
       attachments,
       this.options.agent.provider,
       VisionModes.agents,
     );
     message.image_urls = image_urls.length ? image_urls : undefined;
+    if (text && text.length) {
+      message.ocr = text;
+    }
     return files;
   }
 
@@ -304,7 +325,21 @@ class AgentClient extends BaseClient {
         assistantName: this.options?.modelLabel,
       });
 
-      const needsTokenCount = this.contextStrategy && !orderedMessages[i].tokenCount;
+      if (message.ocr && i !== orderedMessages.length - 1) {
+        if (typeof formattedMessage.content === 'string') {
+          formattedMessage.content = message.ocr + '\n' + formattedMessage.content;
+        } else {
+          const textPart = formattedMessage.content.find((part) => part.type === 'text');
+          textPart
+            ? (textPart.text = message.ocr + '\n' + textPart.text)
+            : formattedMessage.content.unshift({ type: 'text', text: message.ocr });
+        }
+      } else if (message.ocr && i === orderedMessages.length - 1) {
+        systemContent = [systemContent, message.ocr].join('\n');
+      }
+
+      const needsTokenCount =
+        (this.contextStrategy && !orderedMessages[i].tokenCount) || message.ocr;
 
       /* If tokens were never counted, or, is a Vision request and the message has files, count again */
       if (needsTokenCount || (this.isVisionModel && (message.image_urls || message.files))) {
@@ -319,7 +354,9 @@ class AgentClient extends BaseClient {
             this.contextHandlers?.processFile(file);
             continue;
           }
-
+          if (file.metadata?.fileIdentifier) {
+            continue;
+          }
           // orderedMessages[i].tokenCount += this.calculateImageTokenCost({
           //   width: file.width,
           //   height: file.height,
@@ -350,6 +387,10 @@ class AgentClient extends BaseClient {
       }));
     }
 
+    for (let i = 0; i < messages.length; i++) {
+      this.indexTokenCountMap[i] = messages[i].tokenCount;
+    }
+
     const result = {
       tokenCountMap,
       prompt: payload,
@@ -384,15 +425,34 @@ class AgentClient extends BaseClient {
     if (!collectedUsage || !collectedUsage.length) {
       return;
     }
-    const input_tokens = collectedUsage[0]?.input_tokens || 0;
+    const input_tokens =
+      (collectedUsage[0]?.input_tokens || 0) +
+      (Number(collectedUsage[0]?.input_token_details?.cache_creation) || 0) +
+      (Number(collectedUsage[0]?.input_token_details?.cache_read) || 0);
 
     let output_tokens = 0;
     let previousTokens = input_tokens; // Start with original input
     for (let i = 0; i < collectedUsage.length; i++) {
       const usage = collectedUsage[i];
+      if (!usage) {
+        continue;
+      }
+
+      const cache_creation = Number(usage.input_token_details?.cache_creation) || 0;
+      const cache_read = Number(usage.input_token_details?.cache_read) || 0;
+
+      const txMetadata = {
+        context,
+        conversationId: this.conversationId,
+        user: this.user ?? this.options.req.user?.id,
+        endpointTokenConfig: this.options.endpointTokenConfig,
+        model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
+      };
+
       if (i > 0) {
         // Count new tokens generated (input_tokens minus previous accumulated tokens)
-        output_tokens += (Number(usage.input_tokens) || 0) - previousTokens;
+        output_tokens +=
+          (Number(usage.input_tokens) || 0) + cache_creation + cache_read - previousTokens;
       }
 
       // Add this message's output tokens
@@ -400,16 +460,27 @@ class AgentClient extends BaseClient {
       // Update previousTokens to include this message's output
       previousTokens += Number(usage.output_tokens) || 0;
 
-      spendTokens(
-        {
-          context,
-          conversationId: this.conversationId,
-          user: this.user ?? this.options.req.user?.id,
-          endpointTokenConfig: this.options.endpointTokenConfig,
-          model: usage.model ?? model ?? this.model ?? this.options.agent.model_parameters.model,
-        },
-        { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens },
-      ).catch((err) => {
+      if (cache_creation > 0 || cache_read > 0) {
+        spendStructuredTokens(txMetadata, {
+          promptTokens: {
+            input: usage.input_tokens,
+            write: cache_creation,
+            read: cache_read,
+          },
+          completionTokens: usage.output_tokens,
+        }).catch((err) => {
+          logger.error(
+            '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending structured tokens',
+            err,
+          );
+        });
+        continue;
+      }
+      spendTokens(txMetadata, {
+        promptTokens: usage.input_tokens,
+        completionTokens: usage.output_tokens,
+      }).catch((err) => {
         logger.error(
           '[api/server/controllers/agents/client.js #recordCollectedUsage] Error spending tokens',
           err,
@@ -472,24 +543,15 @@ class AgentClient extends BaseClient {
   }
 
   async chatCompletion({ payload, abortController = null }) {
+    /** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
+    let config;
+    /** @type {ReturnType<typeof createRun>} */
+    let run;
     try {
       if (!abortController) {
         abortController = new AbortController();
       }
 
-      const baseURL = extractBaseURL(this.completionsUrl);
-      logger.debug('[api/server/controllers/agents/client.js] chatCompletion', {
-        baseURL,
-        payload,
-      });
-
-      // if (this.useOpenRouter) {
-      //   opts.defaultHeaders = {
-      //     'HTTP-Referer': 'https://librechat.ai',
-      //     'X-Title': 'LibreChat',
-      //   };
-      // }
-
       // if (this.options.headers) {
       //   opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers };
       // }
@@ -579,39 +641,55 @@ class AgentClient extends BaseClient {
       //   });
       // }
 
-      /** @type {Partial<RunnableConfig> & { version: 'v1' | 'v2'; run_id?: string; streamMode: string }} */
-      const config = {
+      /** @type {TCustomConfig['endpoints']['agents']} */
+      const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents];
+
+      config = {
         configurable: {
           thread_id: this.conversationId,
           last_agent_index: this.agentConfigs?.size ?? 0,
+          user_id: this.user ?? this.options.req.user?.id,
          hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
        },
-        recursionLimit: this.options.req.app.locals[EModelEndpoint.agents]?.recursionLimit,
+        recursionLimit: agentsEConfig?.recursionLimit,
        signal: abortController.signal,
        streamMode: 'values',
        version: 'v2',
      };
 
-      const initialMessages = formatAgentMessages(payload);
-      if (legacyContentEndpoints.has(this.options.agent.endpoint)) {
-        formatContentStrings(initialMessages);
+      const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
+      let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
+        payload,
+        this.indexTokenCountMap,
+        toolSet,
+      );
+      if (legacyContentEndpoints.has(this.options.agent.endpoint?.toLowerCase())) {
+        initialMessages = formatContentStrings(initialMessages);
       }
 
-      /** @type {ReturnType<typeof createRun>} */
-      let run;
-
       /**
       *
       * @param {Agent} agent
-       * @param {BaseMessage[]} messages
+       * @param {BaseMessage[]} _messages
       * @param {number} [i]
       * @param {TMessageContentParts[]} [contentData]
+       * @param {Record<string, number>} [_currentIndexCountMap]
       */
-      const runAgent = async (agent, messages, i = 0, contentData = []) => {
+      const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => {
        config.configurable.model = agent.model_parameters.model;
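+        // Fall back to the top-level token-count map when no per-run map is provided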
+        const currentIndexCountMap = _currentIndexCountMap ?? indexTokenCountMap;
         if (i > 0) {
           this.model = agent.model_parameters.model;
         }
+        if (agent.recursion_limit && typeof agent.recursion_limit === 'number') {
+          config.recursionLimit = agent.recursion_limit;
+        }
+        if (
+          agentsEConfig?.maxRecursionLimit &&
+          config.recursionLimit > agentsEConfig?.maxRecursionLimit
+        ) {
+          config.recursionLimit = agentsEConfig?.maxRecursionLimit;
+        }
         config.configurable.agent_id = agent.id;
         config.configurable.name = agent.name;
         config.configurable.agent_index = i;
@@ -626,7 +704,7 @@ class AgentClient extends BaseClient {
         let systemContent = [
           systemMessage,
           agent.instructions ?? '',
-          i !== 0 ? agent.additional_instructions ?? '' : '',
+          i !== 0 ? (agent.additional_instructions ?? '') : '',
         ]
           .join('\n')
           .trim();
@@ -640,12 +718,23 @@ class AgentClient extends BaseClient {
         }
 
         if (noSystemMessages === true && systemContent?.length) {
-          let latestMessage = messages.pop().content;
-          if (typeof latestMessage !== 'string') {
-            latestMessage = latestMessage[0].text;
+          const latestMessageContent = _messages.pop().content;
+          if (typeof latestMessageContent !== 'string') {
+            latestMessageContent[0].text = [systemContent, latestMessageContent[0].text].join('\n');
+            _messages.push(new HumanMessage({ content: latestMessageContent }));
+          } else {
+            const text = [systemContent, latestMessageContent].join('\n');
+            _messages.push(new HumanMessage(text));
           }
-          latestMessage = [systemContent, latestMessage].join('\n');
-          messages.push(new HumanMessage(latestMessage));
+        }
+
+        let messages = _messages;
+        if (
+          agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes(
+            'prompt-caching',
+          )
+        ) {
+          messages = addCacheControl(messages);
         }
 
         run = await createRun({
@@ -665,27 +754,46 @@ class AgentClient extends BaseClient {
         }
 
         if (contentData.length) {
+          const agentUpdate = {
+            type: ContentTypes.AGENT_UPDATE,
+            [ContentTypes.AGENT_UPDATE]: {
+              index: contentData.length,
+              runId: this.responseMessageId,
+              agentId: agent.id,
+            },
+          };
+          const streamData = {
+            event: GraphEvents.ON_AGENT_UPDATE,
+            data: agentUpdate,
+          };
+          this.options.aggregateContent(streamData);
+          sendEvent(this.options.res, streamData);
+          contentData.push(agentUpdate);
           run.Graph.contentData = contentData;
         }
 
+        const encoding = this.getEncoding();
         await run.processStream({ messages }, config, {
           keepContent: i !== 0,
+          tokenCounter: createTokenCounter(encoding),
+          indexTokenCountMap: currentIndexCountMap,
+          maxContextTokens: agent.maxContextTokens,
          callbacks: {
-            [Callback.TOOL_ERROR]: (graph, error, toolId) => {
-              logger.error(
-                '[api/server/controllers/agents/client.js #chatCompletion] Tool Error',
-                error,
-                toolId,
-              );
-            },
+            [Callback.TOOL_ERROR]: logToolError,
          },
        });
+
+        config.signal = null;
      };
 
       await runAgent(this.options.agent, initialMessages);
 
       let finalContentStart = 0;
-      if (this.agentConfigs && this.agentConfigs.size > 0) {
+      if (
+        this.agentConfigs &&
+        this.agentConfigs.size > 0 &&
+        (await checkCapability(this.options.req, AgentCapabilities.chain))
+      ) {
+        const windowSize = 5;
        let latestMessage = initialMessages.pop().content;
        if (typeof latestMessage !== 'string') {
          latestMessage = latestMessage[0].text;
@@ -693,7 +801,18 @@ class AgentClient extends BaseClient {
 
         let i = 1;
         let runMessages = [];
-        const lastFiveMessages = initialMessages.slice(-5);
+        const windowIndexCountMap = {};
+        const windowMessages = initialMessages.slice(-windowSize);
+        let currentIndex = 4;
+        for (let i = initialMessages.length - 1; i >= 0; i--) {
+          windowIndexCountMap[currentIndex] = indexTokenCountMap[i];
+          currentIndex--;
+          if (currentIndex < 0) {
+            break;
+          }
+        }
+        const encoding = this.getEncoding();
+        const tokenCounter = createTokenCounter(encoding);
         for (const [agentId, agent] of this.agentConfigs) {
           if (abortController.signal.aborted === true) {
             break;
           }
@@ -728,7 +847,9 @@ class AgentClient extends BaseClient {
           }
           try {
             const contextMessages = [];
-            for (const message of lastFiveMessages) {
+            const runIndexCountMap = {};
+            for (let i = 0; i < windowMessages.length; i++) {
+              const message = windowMessages[i];
               const messageType = message._getType();
               if (
                 (!agent.tools || agent.tools.length === 0) &&
@@ -736,11 +857,13 @@ class AgentClient extends BaseClient {
               ) {
                 continue;
               }
-
+              runIndexCountMap[contextMessages.length] = windowIndexCountMap[i];
               contextMessages.push(message);
             }
-            const currentMessages = [...contextMessages, new HumanMessage(bufferString)];
-            await runAgent(agent, currentMessages, i, contentData);
+            const bufferMessage = new HumanMessage(bufferString);
+            runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage);
+            const currentMessages = [...contextMessages, bufferMessage];
+            await runAgent(agent, currentMessages, i, contentData, runIndexCountMap);
           } catch (err) {
             logger.error(
               `[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`,
@@ -751,6 +874,7 @@ class AgentClient extends BaseClient {
         }
       }
 
+      /** Note: not implemented */
       if (config.configurable.hide_sequential_outputs !== true) {
         finalContentStart = 0;
       }
@@ -774,18 +898,20 @@ class AgentClient extends BaseClient {
         );
       }
     } catch (err) {
+      logger.error(
+        '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
+        err,
+      );
       if (!abortController.signal.aborted) {
         logger.error(
           '[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type',
           err,
         );
-        throw err;
+        this.contentParts.push({
+          type: ContentTypes.ERROR,
+          [ContentTypes.ERROR]: `An error occurred while processing the request${err?.message ? `: ${err.message}` : ''}`,
+        });
       }
-
-      logger.warn(
-        '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted',
-        err,
-      );
     }
   }
 
@@ -795,19 +921,56 @@ class AgentClient extends BaseClient {
   * @param {string} params.text
   * @param {string} params.conversationId
   */
-  async titleConvo({ text }) {
+  async titleConvo({ text, abortController }) {
     if (!this.run) {
       throw new Error('Run not initialized');
     }
     const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
-    const clientOptions = {};
-    const providerConfig = this.options.req.app.locals[this.options.agent.provider];
+    const endpoint = this.options.agent.endpoint;
+    const { req, res } = this.options;
+    /** @type {import('@librechat/agents').ClientOptions} */
+    let clientOptions = {
+      maxTokens: 75,
+    };
+    let endpointConfig = req.app.locals[endpoint];
+    if (!endpointConfig) {
+      try {
+        endpointConfig = await getCustomEndpointConfig(endpoint);
+      } catch (err) {
+        logger.error(
+          '[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
+          err,
+        );
+      }
+    }
     if (
-      providerConfig &&
-      providerConfig.titleModel &&
-      providerConfig.titleModel !== Constants.CURRENT_MODEL
+      endpointConfig &&
+      endpointConfig.titleModel &&
+      endpointConfig.titleModel !== Constants.CURRENT_MODEL
     ) {
-      clientOptions.model = providerConfig.titleModel;
+      clientOptions.model = endpointConfig.titleModel;
+    }
+    if (
+      endpoint === EModelEndpoint.azureOpenAI &&
+      clientOptions.model &&
+      this.options.agent.model_parameters.model !== clientOptions.model
+    ) {
+      clientOptions =
+        (
+          await initOpenAI({
+            req,
+            res,
+            optionsOnly: true,
+            overrideModel: clientOptions.model,
+            overrideEndpoint: endpoint,
+            endpointOption: {
+              model_parameters: clientOptions,
+            },
+          })
+        )?.llmConfig ?? clientOptions;
+    }
+    if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
+      delete clientOptions.maxTokens;
     }
     try {
       const titleResult = await this.run.generateTitle({
@@ -815,6 +978,7 @@ class AgentClient extends BaseClient {
         contentParts: this.contentParts,
         clientOptions,
         chainOptions: {
+          signal: abortController.signal,
           callbacks: [
             {
               handleLLMEnd,
@@ -840,7 +1004,7 @@ class AgentClient extends BaseClient {
         };
       });
 
-      this.recordCollectedUsage({
+      await this.recordCollectedUsage({
         model: clientOptions.model,
         context: 'title',
         collectedUsage,
diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js
index 288ae8f37f..fcee62edc7 100644
--- a/api/server/controllers/agents/request.js
+++ b/api/server/controllers/agents/request.js
@@ -1,5 +1,10 @@
 const { Constants } = require('librechat-data-provider');
-const { createAbortController, handleAbortError } = require('~/server/middleware');
+const {
+  handleAbortError,
+  createAbortController,
+  cleanupAbortController,
+} = require('~/server/middleware');
+const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
 const { sendMessage } = require('~/server/utils');
 const { saveMessage } = require('~/models');
 const { logger } = require('~/config');
@@ -14,16 +19,22 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
   } = req.body;
 
   let sender;
+  let abortKey;
   let userMessage;
   let promptTokens;
   let userMessageId;
   let responseMessageId;
   let userMessagePromise;
+  let getAbortData;
+  let client = null;
+  // Initialize as an array
+  let cleanupHandlers = [];
 
   const newConvo = !conversationId;
-  const user = req.user.id;
+  const userId = req.user.id;
 
-  const getReqData = (data = {}) => {
+  // Create handler to avoid capturing the entire parent scope
+  let getReqData = (data = {}) => {
     for (let key in data) {
       if (key === 'userMessage') {
         userMessage = data[key];
@@ -36,30 +47,96 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
         promptTokens = data[key];
       } else if (key === 'sender') {
         sender = data[key];
+      } else if (key === 'abortKey') {
+        abortKey = data[key];
       } else if (!conversationId && key === 'conversationId') {
         conversationId = data[key];
       }
     }
   };
 
+  // Create a function to handle final cleanup
+  const performCleanup = () => {
+    logger.debug('[AgentController] Performing cleanup');
+    // Make sure cleanupHandlers is an array before iterating
+    if (Array.isArray(cleanupHandlers)) {
+      // Execute all cleanup handlers
+      for (const handler of cleanupHandlers) {
+        try {
+          if (typeof handler === 'function') {
+            handler();
+          }
+        } catch (e) {
+          // Ignore cleanup errors
+        }
+      }
+    }
+
+    // Clean up abort controller
+    if (abortKey) {
+      logger.debug('[AgentController] Cleaning up abort controller');
+      cleanupAbortController(abortKey);
+    }
+
+    // Dispose client properly
+    if (client) {
+      disposeClient(client);
+    }
+
+    // Clear all references
+    client = null;
+    getReqData = null;
+    userMessage = null;
+    getAbortData = null;
+    endpointOption.agent = null;
+    endpointOption = null;
+    cleanupHandlers = null;
+    userMessagePromise = null;
+
+    // Clear request data map
+    if (requestDataMap.has(req)) {
+      requestDataMap.delete(req);
+    }
+    logger.debug('[AgentController] Cleanup completed');
+  };
+
   try {
     /** @type {{ client: TAgentClient }} */
-    const { client } = await initializeClient({ req, res, endpointOption });
+    const result = await initializeClient({ req, res, endpointOption });
+    client = result.client;
 
-    const getAbortData = () => ({
-      sender,
-      userMessage,
-      promptTokens,
-      conversationId,
-      userMessagePromise,
-      messageId: responseMessageId,
-      content: client.getContentParts(),
-      parentMessageId: overrideParentMessageId ?? userMessageId,
-    });
+    // Register client with finalization registry if available
+    if (clientRegistry) {
+      clientRegistry.register(client, { userId }, client);
+    }
+
+    // Store request data in WeakMap keyed by req object
+    requestDataMap.set(req, { client });
+
+    // Use WeakRef to allow GC but still access content if it exists
+    const contentRef = new WeakRef(client.contentParts || []);
+
+    // Minimize closure scope - only capture small primitives and WeakRef
+    getAbortData = () => {
+      // Dereference WeakRef each time
+      const content = contentRef.deref();
+
+      return {
+        sender,
+        content: content || [],
+        userMessage,
+        promptTokens,
+        conversationId,
+        userMessagePromise,
+        messageId: responseMessageId,
+        parentMessageId: overrideParentMessageId ?? userMessageId,
+      };
+    };
 
     const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
 
-    res.on('close', () => {
+    // Simple handler to avoid capturing scope
+    const closeHandler = () => {
       logger.debug('[AgentController] Request closed');
       if (!abortController) {
         return;
@@ -71,10 +148,19 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
 
       abortController.abort();
       logger.debug('[AgentController] Request aborted on close');
+    };
+
+    res.on('close', closeHandler);
+    cleanupHandlers.push(() => {
+      try {
+        res.removeListener('close', closeHandler);
+      } catch (e) {
+        // Ignore
+      }
     });
 
     const messageOptions = {
-      user,
+      user: userId,
       onStart,
       getReqData,
       conversationId,
@@ -83,69 +169,104 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
       overrideParentMessageId,
       progressOptions: {
         res,
-        // parentMessageId: overrideParentMessageId || userMessageId,
       },
     };
 
     let response = await client.sendMessage(text, messageOptions);
-    response.endpoint = endpointOption.endpoint;
 
-    const { conversation = {} } = await client.responsePromise;
+    // Extract what we need and immediately break reference
+    const messageId = response.messageId;
+    const endpoint = endpointOption.endpoint;
+    response.endpoint = endpoint;
+
+    // Store database promise locally
+    const databasePromise = response.databasePromise;
+    delete response.databasePromise;
+
+    // Resolve database-related data
+    const { conversation: convoData = {} } = await databasePromise;
+    const conversation = { ...convoData };
     conversation.title =
       conversation && !conversation.title ? null : conversation?.title || 'New Chat';
 
-    if (req.body.files && client.options.attachments) {
+    // Process files if needed
+    if (req.body.files && client.options?.attachments) {
       userMessage.files = [];
       const messageFiles = new Set(req.body.files.map((file) => file.file_id));
       for (let attachment of client.options.attachments) {
         if (messageFiles.has(attachment.file_id)) {
-          userMessage.files.push(attachment);
+          userMessage.files.push({ ...attachment });
         }
       }
       delete userMessage.image_urls;
     }
 
+    // Only send if not aborted
     if (!abortController.signal.aborted) {
+      // Create a new response object with minimal copies
+      const finalResponse = { ...response };
+
       sendMessage(res, {
         final: true,
         conversation,
         title: conversation.title,
         requestMessage: userMessage,
-        responseMessage: response,
+        responseMessage: finalResponse,
       });
       res.end();
 
-      if (!client.savedMessageIds.has(response.messageId)) {
+      // Save the message if needed
+      if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) {
         await saveMessage(
           req,
-          { ...response, user },
+          { ...finalResponse, user: userId },
           { context: 'api/server/controllers/agents/request.js - response end' },
         );
       }
     }
 
+    // Save user message if needed
     if (!client.skipSaveUserMessage) {
       await saveMessage(req, userMessage, {
         context: 'api/server/controllers/agents/request.js - don\'t skip saving user message',
       });
     }
 
+    // Add title if needed - extract minimal data
    if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
       addTitle(req, {
         text,
-        response,
+        response: { ...response },
         client,
-      });
+      })
+        .then(() => {
+          logger.debug('[AgentController] Title generation started');
+        })
+        .catch((err) => {
+          logger.error('[AgentController] Error in title generation', err);
+        })
+        .finally(() => {
+          logger.debug('[AgentController] Title generation completed');
+          performCleanup();
+        });
+    } else {
+      performCleanup();
     }
   } catch (error) {
+    // Handle error without capturing much scope
     handleAbortError(res, req, error, {
       conversationId,
       sender,
       messageId: responseMessageId,
-      parentMessageId: userMessageId ?? parentMessageId,
-    }).catch((err) => {
-      logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err);
-    });
+      parentMessageId: overrideParentMessageId ?? userMessageId ?? parentMessageId,
+      userMessageId,
+    })
+      .catch((err) => {
+        logger.error('[api/server/controllers/agents/request] Error in `handleAbortError`', err);
+      })
+      .finally(() => {
+        performCleanup();
+      });
   }
 };
diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js
index 0fcc58a379..2452e66233 100644
--- a/api/server/controllers/agents/run.js
+++ b/api/server/controllers/agents/run.js
@@ -1,5 +1,5 @@
 const { Run, Providers } = require('@librechat/agents');
-const { providerEndpointMap } = require('librechat-data-provider');
+const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider');
 
 /**
  * @typedef {import('@librechat/agents').t} t
@@ -7,9 +7,17 @@ const { providerEndpointMap } = require('librechat-data-provider');
  * @typedef {import('@librechat/agents').StreamEventData} StreamEventData
  * @typedef {import('@librechat/agents').EventHandler} EventHandler
  * @typedef {import('@librechat/agents').GraphEvents} GraphEvents
+ * @typedef {import('@librechat/agents').LLMConfig} LLMConfig
  * @typedef {import('@librechat/agents').IState} IState
 */
 
+const customProviders = new Set([
+  Providers.XAI,
+  Providers.OLLAMA,
+  Providers.DEEPSEEK,
+  Providers.OPENROUTER,
+]);
+
 /**
  * Creates a new Run instance with custom handlers and configuration.
 *
@@ -32,6 +40,7 @@ async function createRun({
   streamUsage = true,
 }) {
   const provider = providerEndpointMap[agent.provider] ?? agent.provider;
+  /** @type {LLMConfig} */
   const llmConfig = Object.assign(
     {
       provider,
@@ -41,15 +50,29 @@ async function createRun({
     agent.model_parameters,
   );
 
-  if (/o1(?!-(?:mini|preview)).*$/.test(llmConfig.model)) {
-    llmConfig.streaming = false;
-    llmConfig.disableStreaming = true;
+  /** Resolves issues with new OpenAI usage field */
+  if (
+    customProviders.has(agent.provider) ||
+    (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
+  ) {
+    llmConfig.streamUsage = false;
+    llmConfig.usage = true;
+  }
+
+  /** @type {'reasoning_content' | 'reasoning'} */
+  let reasoningKey;
+  if (
+    llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
+    (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
+  ) {
+    reasoningKey = 'reasoning';
   }
 
   /** @type {StandardGraphConfig} */
   const graphConfig = {
     signal,
     llmConfig,
+    reasoningKey,
     tools: agent.tools,
     instructions: agent.instructions,
     additional_instructions: agent.additional_instructions,
@@ -57,7 +80,7 @@ async function createRun({
   };
 
   // TEMPORARY FOR TESTING
-  if (agent.provider === Providers.ANTHROPIC) {
+  if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
     graphConfig.streamBuffer = 2000;
   }
 
diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js
index 08327ec61c..e0f27a13fc 100644
--- a/api/server/controllers/agents/v1.js
+++ b/api/server/controllers/agents/v1.js
@@ -1,10 +1,12 @@
 const fs = require('fs').promises;
 const { nanoid } = require('nanoid');
 const {
-  FileContext,
-  Constants,
   Tools,
+  Constants,
+  FileContext,
+  FileSources,
   SystemRoles,
+  EToolResources,
   actionDelimiter,
 } = require('librechat-data-provider');
 const {
@@ -16,9 +18,10 @@ const {
 } = require('~/models/Agent');
 const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
 const { getStrategyFunctions } = require('~/server/services/Files/strategies');
+const { refreshS3Url } = require('~/server/services/Files/S3/crud');
 const { updateAction, getActions } = require('~/models/Action');
-const { getProjectByName } = require('~/models/Project');
diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js
index 08327ec61c..e0f27a13fc 100644
--- a/api/server/controllers/agents/v1.js
+++ b/api/server/controllers/agents/v1.js
@@ -1,10 +1,12 @@
 const fs = require('fs').promises;
 const { nanoid } = require('nanoid');
 const {
-  FileContext,
-  Constants,
   Tools,
+  Constants,
+  FileContext,
+  FileSources,
   SystemRoles,
+  EToolResources,
   actionDelimiter,
 } = require('librechat-data-provider');
 const {
@@ -16,9 +18,10 @@ const {
 } = require('~/models/Agent');
 const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process');
 const { getStrategyFunctions } = require('~/server/services/Files/strategies');
+const { refreshS3Url } = require('~/server/services/Files/S3/crud');
 const { updateAction, getActions } = require('~/models/Action');
-const { getProjectByName } = require('~/models/Project');
 const { updateAgentProjects } = require('~/models/Agent');
+const { getProjectByName } = require('~/models/Project');
 const { deleteFileByFilter } = require('~/models/File');
 const { logger } = require('~/config');
 
@@ -101,6 +104,14 @@ const getAgentHandler = async (req, res) => {
       return res.status(404).json({ error: 'Agent not found' });
     }
 
+    if (agent.avatar && agent.avatar?.source === FileSources.s3) {
+      const originalUrl = agent.avatar.filepath;
+      agent.avatar.filepath = await refreshS3Url(agent.avatar);
+      if (originalUrl !== agent.avatar.filepath) {
+        await updateAgent({ id }, { avatar: agent.avatar });
+      }
+    }
+
     agent.author = agent.author.toString();
     agent.isCollaborative = !!agent.isCollaborative;
 
@@ -203,13 +214,25 @@ const duplicateAgentHandler = async (req, res) => {
     }
 
     const {
-      _id: __id,
       id: _id,
+      _id: __id,
       author: _author,
       createdAt: _createdAt,
       updatedAt: _updatedAt,
+      tool_resources: _tool_resources = {},
       ...cloneData
     } = agent;
 
+    cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', {
+      dateStyle: 'short',
+      timeStyle: 'short',
+      hour12: false,
+    })})`;
+
+    if (_tool_resources?.[EToolResources.ocr]) {
+      cloneData.tool_resources = {
+        [EToolResources.ocr]: _tool_resources[EToolResources.ocr],
+      };
+    }
 
     const newAgentId = `agent_${nanoid()}`;
     const newAgentData = Object.assign(cloneData, {
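For reference, the duplicate handler's clone pattern in isolation: rest destructuring strips the identity fields, and only the OCR tool resource survives the copy. A runnable sketch with illustrative values ('ocr' stands in for `EToolResources.ocr`):

const agent = {
  id: 'agent_abc',
  _id: '507f1f77bcf86cd799439011',
  author: 'user_1',
  name: 'Research Agent',
  createdAt: '2025-01-01',
  updatedAt: '2025-01-02',
  tool_resources: { ocr: { file_ids: ['file_1'] }, execute_code: { file_ids: ['file_2'] } },
};

const {
  id: _id,
  _id: __id,
  author: _author,
  createdAt: _createdAt,
  updatedAt: _updatedAt,
  tool_resources: _tool_resources = {},
  ...cloneData
} = agent;

// Only OCR resources carry over; other tool state stays with the original.
if (_tool_resources.ocr) {
  cloneData.tool_resources = { ocr: _tool_resources.ocr };
}

console.log(cloneData);
// { name: 'Research Agent', tool_resources: { ocr: { file_ids: ['file_1'] } } }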
diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js
index 8461941e05..5fa10e9e37 100644
--- a/api/server/controllers/assistants/chatV1.js
+++ b/api/server/controllers/assistants/chatV1.js
@@ -19,7 +19,7 @@ const {
   addThreadMetadata,
   saveAssistantMessage,
 } = require('~/server/services/Threads');
-const { sendResponse, sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
+const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
 const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
 const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
 const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
@@ -27,7 +27,7 @@ const { createRun, StreamRunManager } = require('~/server/services/Runs');
 const { addTitle } = require('~/server/services/Endpoints/assistants');
 const { createRunBody } = require('~/server/services/createRunBody');
 const { getTransactions } = require('~/models/Transaction');
-const checkBalance = require('~/models/checkBalance');
+const { checkBalance } = require('~/models/balanceMethods');
 const { getConvo } = require('~/models/Conversation');
 const getLogStores = require('~/cache/getLogStores');
 const { getModelMaxTokens } = require('~/utils');
@@ -119,7 +119,7 @@ const chatV1 = async (req, res) => {
       } else if (/Files.*are invalid/.test(error.message)) {
         const errorMessage = `Files are invalid, or may not have uploaded yet.${
           endpoint === EModelEndpoint.azureAssistants
-            ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
+            ? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
             : ''
         }`;
         return sendResponse(req, res, messageData, errorMessage);
@@ -248,7 +248,8 @@ const chatV1 = async (req, res) => {
     }
 
     const checkBalanceBeforeRun = async () => {
-      if (!isEnabled(process.env.CHECK_BALANCE)) {
+      const balance = req.app?.locals?.balance;
+      if (!balance?.enabled) {
         return;
       }
       const transactions =
@@ -378,8 +379,8 @@ const chatV1 = async (req, res) => {
         body.additional_instructions ? `${body.additional_instructions}\n` : ''
       }The user has uploaded ${imageCount} image${pluralized}. Use the \`${ImageVisionTool.function.name}\` tool to retrieve ${
-        plural ? '' : 'a '
-}detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`;
+        plural ? '' : 'a '
+      }detailed text description${pluralized} for ${plural ? 'each' : 'the'} image${pluralized}.`;
 
       return files;
     };
@@ -575,6 +576,8 @@ const chatV1 = async (req, res) => {
       thread_id,
       model: assistant_id,
       endpoint,
+      spec: endpointOption.spec,
+      iconURL: endpointOption.iconURL,
     };
 
     sendMessage(res, {
diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js
index 24a8e38fa4..309e5a86c4 100644
--- a/api/server/controllers/assistants/chatV2.js
+++ b/api/server/controllers/assistants/chatV2.js
@@ -18,14 +18,14 @@ const {
   saveAssistantMessage,
 } = require('~/server/services/Threads');
 const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
-const { sendMessage, sleep, isEnabled, countTokens } = require('~/server/utils');
 const { createErrorHandler } = require('~/server/controllers/assistants/errors');
 const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
 const { createRun, StreamRunManager } = require('~/server/services/Runs');
 const { addTitle } = require('~/server/services/Endpoints/assistants');
+const { sendMessage, sleep, countTokens } = require('~/server/utils');
 const { createRunBody } = require('~/server/services/createRunBody');
 const { getTransactions } = require('~/models/Transaction');
-const checkBalance = require('~/models/checkBalance');
+const { checkBalance } = require('~/models/balanceMethods');
 const { getConvo } = require('~/models/Conversation');
 const getLogStores = require('~/cache/getLogStores');
 const { getModelMaxTokens } = require('~/utils');
@@ -124,7 +124,8 @@ const chatV2 = async (req, res) => {
     }
 
     const checkBalanceBeforeRun = async () => {
-      if (!isEnabled(process.env.CHECK_BALANCE)) {
+      const balance = req.app?.locals?.balance;
+      if (!balance?.enabled) {
        return;
       }
       const transactions =
@@ -427,6 +428,8 @@ const chatV2 = async (req, res) => {
       thread_id,
       model: assistant_id,
       endpoint,
+      spec: endpointOption.spec,
+      iconURL: endpointOption.iconURL,
     };
 
     sendMessage(res, {
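Both chat controllers now gate the balance check on startup configuration rather than the `CHECK_BALANCE` env var. A minimal sketch of the new guard, assuming `req.app.locals.balance` is populated at startup from the configured balance settings:

const checkBalanceBeforeRun = async (req) => {
  const balance = req.app?.locals?.balance;
  if (!balance?.enabled) {
    return; // balance tracking disabled in config; skip the check entirely
  }
  // ...otherwise query transactions and compare against the user's token credits
};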
diff --git a/api/server/controllers/auth/LoginController.js b/api/server/controllers/auth/LoginController.js
index 1b543e9baf..226b5605cc 100644
--- a/api/server/controllers/auth/LoginController.js
+++ b/api/server/controllers/auth/LoginController.js
@@ -1,3 +1,4 @@
+const { generate2FATempToken } = require('~/server/services/twoFactorService');
 const { setAuthTokens } = require('~/server/services/AuthService');
 const { logger } = require('~/config');
 
@@ -7,7 +8,12 @@ const loginController = async (req, res) => {
       return res.status(400).json({ message: 'Invalid credentials' });
     }
 
-    const { password: _, __v, ...user } = req.user;
+    if (req.user.twoFactorEnabled) {
+      const tempToken = generate2FATempToken(req.user._id);
+      return res.status(200).json({ twoFAPending: true, tempToken });
+    }
+
+    const { password: _p, totpSecret: _t, __v, ...user } = req.user;
     user.id = user._id.toString();
 
     const token = await setAuthTokens(req.user._id, res);
diff --git a/api/server/controllers/auth/TwoFactorAuthController.js b/api/server/controllers/auth/TwoFactorAuthController.js
new file mode 100644
index 0000000000..15cde8122a
--- /dev/null
+++ b/api/server/controllers/auth/TwoFactorAuthController.js
@@ -0,0 +1,60 @@
+const jwt = require('jsonwebtoken');
+const {
+  verifyTOTP,
+  verifyBackupCode,
+  getTOTPSecret,
+} = require('~/server/services/twoFactorService');
+const { setAuthTokens } = require('~/server/services/AuthService');
+const { getUserById } = require('~/models/userMethods');
+const { logger } = require('~/config');
+
+/**
+ * Verifies the 2FA code during login using a temporary token.
+ */
+const verify2FAWithTempToken = async (req, res) => {
+  try {
+    const { tempToken, token, backupCode } = req.body;
+    if (!tempToken) {
+      return res.status(400).json({ message: 'Missing temporary token' });
+    }
+
+    let payload;
+    try {
+      payload = jwt.verify(tempToken, process.env.JWT_SECRET);
+    } catch (err) {
+      return res.status(401).json({ message: 'Invalid or expired temporary token' });
+    }
+
+    const user = await getUserById(payload.userId);
+    if (!user || !user.twoFactorEnabled) {
+      return res.status(400).json({ message: '2FA is not enabled for this user' });
+    }
+
+    const secret = await getTOTPSecret(user.totpSecret);
+    let isVerified = false;
+    if (token) {
+      isVerified = await verifyTOTP(secret, token);
+    } else if (backupCode) {
+      isVerified = await verifyBackupCode({ user, backupCode });
+    }
+
+    if (!isVerified) {
+      return res.status(401).json({ message: 'Invalid 2FA code or backup code' });
+    }
+
+    // Prepare user data to return (omit sensitive fields).
+    const userData = user.toObject ? user.toObject() : { ...user };
+    delete userData.password;
+    delete userData.__v;
+    delete userData.totpSecret;
+    userData.id = user._id.toString();
+
+    const authToken = await setAuthTokens(user._id, res);
+    return res.status(200).json({ token: authToken, user: userData });
+  } catch (err) {
+    logger.error('[verify2FAWithTempToken]', err);
+    return res.status(500).json({ message: 'Something went wrong' });
+  }
+};
+
+module.exports = { verify2FAWithTempToken };
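Login becomes a two-step exchange when 2FA is enabled. An illustrative client-side flow; the verify route path ('/api/auth/2fa/verify') is an assumption for this sketch, while the payload shapes ({ twoFAPending, tempToken } and { token, user }) come from the controllers above:

async function loginWith2FA(email, password, promptForCode) {
  const res = await fetch('/api/auth/login', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ email, password }),
  });
  const data = await res.json();

  if (!data.twoFAPending) {
    return data; // { token, user } -- 2FA not enabled for this account
  }

  // Exchange the short-lived temp token plus a TOTP code for real auth tokens
  const code = await promptForCode();
  const verifyRes = await fetch('/api/auth/2fa/verify', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ tempToken: data.tempToken, token: code }),
  });
  if (!verifyRes.ok) {
    throw new Error('Invalid 2FA code or backup code');
  }
  return verifyRes.json(); // { token, user }
}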
diff --git a/api/server/controllers/tools.js b/api/server/controllers/tools.js
index 9460e66136..b37b6fcb8c 100644
--- a/api/server/controllers/tools.js
+++ b/api/server/controllers/tools.js
@@ -1,10 +1,18 @@
 const { nanoid } = require('nanoid');
 const { EnvVar } = require('@librechat/agents');
-const { Tools, AuthType, ToolCallTypes } = require('librechat-data-provider');
+const {
+  Tools,
+  AuthType,
+  Permissions,
+  ToolCallTypes,
+  PermissionTypes,
+} = require('librechat-data-provider');
 const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
 const { processCodeOutput } = require('~/server/services/Files/Code/process');
-const { loadAuthValues, loadTools } = require('~/app/clients/tools/util');
 const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
+const { loadAuthValues } = require('~/server/services/Tools/credentials');
+const { loadTools } = require('~/app/clients/tools/util');
+const { checkAccess } = require('~/server/middleware');
 const { getMessage } = require('~/models/Message');
 const { logger } = require('~/config');
 
@@ -12,6 +20,10 @@ const fieldsMap = {
   [Tools.execute_code]: [EnvVar.CODE_API_KEY],
 };
 
+const toolAccessPermType = {
+  [Tools.execute_code]: PermissionTypes.RUN_CODE,
+};
+
 /**
  * @param {ServerRequest} req - The request object, containing information about the HTTP request.
  * @param {ServerResponse} res - The response object, used to send back the desired HTTP response.
@@ -58,6 +70,7 @@ const verifyToolAuth = async (req, res) => {
 /**
  * @param {ServerRequest} req - The request object, containing information about the HTTP request.
  * @param {ServerResponse} res - The response object, used to send back the desired HTTP response.
+ * @param {NextFunction} next - The next middleware function to call.
  * @returns {Promise} A promise that resolves when the function has completed.
  */
 const callTool = async (req, res) => {
@@ -83,6 +96,16 @@ const callTool = async (req, res) => {
       return;
     }
     logger.debug(`[${toolId}/call] User: ${req.user.id}`);
+    let hasAccess = true;
+    if (toolAccessPermType[toolId]) {
+      hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
+    }
+    if (!hasAccess) {
+      logger.warn(
+        `[${toolAccessPermType[toolId]}] Forbidden: Insufficient permissions for User ${req.user.id}: ${Permissions.USE}`,
+      );
+      return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
+    }
     const { loadedTools } = await loadTools({
       user: req.user.id,
       tools: [toolId],
diff --git a/api/server/index.js b/api/server/index.js
index 30d36d9a9f..cd0bdd3f88 100644
--- a/api/server/index.js
+++ b/api/server/index.js
@@ -22,10 +22,11 @@ const staticCache = require('./utils/staticCache');
 const noIndex = require('./middleware/noIndex');
 const routes = require('./routes');
 
-const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION } = process.env ?? {};
+const { PORT, HOST, ALLOW_SOCIAL_LOGIN, DISABLE_COMPRESSION, TRUST_PROXY } = process.env ?? {};
 
 const port = Number(PORT) || 3080;
 const host = HOST || 'localhost';
+const trusted_proxy = Number(TRUST_PROXY) || 1; /* trust first proxy by default */
 
 const startServer = async () => {
   if (typeof Bun !== 'undefined') {
@@ -53,7 +54,7 @@ const startServer = async () => {
   app.use(staticCache(app.locals.paths.dist));
   app.use(staticCache(app.locals.paths.fonts));
   app.use(staticCache(app.locals.paths.assets));
-  app.set('trust proxy', 1); /* trust first proxy */
+  app.set('trust proxy', trusted_proxy);
   app.use(cors());
   app.use(cookieParser());
 
@@ -87,8 +88,8 @@ const startServer = async () => {
   app.use('/api/actions', routes.actions);
   app.use('/api/keys', routes.keys);
   app.use('/api/user', routes.user);
-  app.use('/api/search', routes.search);
   app.use('/api/ask', routes.ask);
+  app.use('/api/search', routes.search);
   app.use('/api/edit', routes.edit);
   app.use('/api/messages', routes.messages);
   app.use('/api/convos', routes.convos);
@@ -145,6 +146,18 @@ process.on('uncaughtException', (err) => {
     logger.error('There was an uncaught error:', err);
   }
 
+  if (err.message.includes('abort')) {
+    logger.warn('There was an uncatchable AbortController error.');
+    return;
+  }
+
+  if (err.message.includes('GoogleGenerativeAI')) {
+    logger.warn(
+      '\n\n`GoogleGenerativeAI` errors cannot be caught due to an upstream issue, see: https://github.com/google-gemini/generative-ai-js/issues/303',
+    );
+    return;
+  }
+
   if (err.message.includes('fetch failed')) {
     if (messageCount === 0) {
       logger.warn('Meilisearch error, search will be disabled');
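For context, what the configurable `trust proxy` value changes: with n trusted hops, Express derives `req.ip` by walking X-Forwarded-For from right to left; with 0, it is the raw socket address. A minimal standalone sketch (not part of the diff):

const express = require('express');

const app = express();
app.set('trust proxy', Number(process.env.TRUST_PROXY) || 1);

app.get('/ip', (req, res) => {
  // Behind one reverse proxy (the default), req.ip is taken from the
  // rightmost X-Forwarded-For entry; with TRUST_PROXY=0 it would be the
  // proxy's own socket address instead.
  res.json({ ip: req.ip, chain: req.ips });
});

app.listen(3080);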
diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js
index 2137523efe..bfc28f513d 100644
--- a/api/server/middleware/abortMiddleware.js
+++ b/api/server/middleware/abortMiddleware.js
@@ -1,3 +1,4 @@
+// abortMiddleware.js
 const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
 const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
 const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
@@ -8,6 +9,68 @@ const { saveMessage, getConvo } = require('~/models');
 const { abortRun } = require('./abortRun');
 const { logger } = require('~/config');
 
+const abortDataMap = new WeakMap();
+
+function cleanupAbortController(abortKey) {
+  if (!abortControllers.has(abortKey)) {
+    return false;
+  }
+
+  const { abortController } = abortControllers.get(abortKey);
+
+  if (!abortController) {
+    abortControllers.delete(abortKey);
+    return true;
+  }
+
+  // 1. Check if this controller has any composed signals and clean them up
+  try {
+    // This creates a temporary composed signal to use for cleanup
+    const composedSignal = AbortSignal.any([abortController.signal]);
+
+    // Get all event types - in practice, AbortSignal typically only uses 'abort'
+    const eventTypes = ['abort'];
+
+    // First, execute a dummy listener removal to handle potential composed signals
+    for (const eventType of eventTypes) {
+      const dummyHandler = () => {};
+      composedSignal.addEventListener(eventType, dummyHandler);
+      composedSignal.removeEventListener(eventType, dummyHandler);
+
+      const listeners = composedSignal.listeners?.(eventType) || [];
+      for (const listener of listeners) {
+        composedSignal.removeEventListener(eventType, listener);
+      }
+    }
+  } catch (e) {
+    logger.debug(`Error cleaning up composed signals: ${e}`);
+  }
+
+  // 2. Abort the controller if not already aborted
+  if (!abortController.signal.aborted) {
+    abortController.abort();
+  }
+
+  // 3. Remove from registry
+  abortControllers.delete(abortKey);
+
+  // 4. Clean up any data stored in the WeakMap
+  if (abortDataMap.has(abortController)) {
+    abortDataMap.delete(abortController);
+  }
+
+  // 5. Clean up function references on the controller
+  if (abortController.getAbortData) {
+    abortController.getAbortData = null;
+  }
+
+  if (abortController.abortCompletion) {
+    abortController.abortCompletion = null;
+  }
+
+  return true;
+}
+
 async function abortMessage(req, res) {
   let { abortKey, endpoint } = req.body;
 
@@ -29,24 +92,24 @@ async function abortMessage(req, res) {
   if (!abortController) {
     return res.status(204).send({ message: 'Request not found' });
   }
-  const finalEvent = await abortController.abortCompletion();
+
+  const finalEvent = await abortController.abortCompletion?.();
   logger.debug(
     `[abortMessage] ID: ${req.user.id} | ${req.user.email} | Aborted request: ` +
       JSON.stringify({ abortKey }),
   );
-  abortControllers.delete(abortKey);
+  cleanupAbortController(abortKey);
 
   if (res.headersSent && finalEvent) {
     return sendMessage(res, finalEvent);
   }
 
   res.setHeader('Content-Type', 'application/json');
-
   res.send(JSON.stringify(finalEvent));
 }
 
-const handleAbort = () => {
-  return async (req, res) => {
+const handleAbort = function () {
+  return async function (req, res) {
     try {
       if (isEnabled(process.env.LIMIT_CONCURRENT_MESSAGES)) {
         await clearPendingReq({ userId: req.user.id });
@@ -62,8 +125,48 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
   const abortController = new AbortController();
   const { endpointOption } = req.body;
 
+  // Store minimal data in WeakMap to avoid circular references
+  abortDataMap.set(abortController, {
+    getAbortDataFn: getAbortData,
+    userId: req.user.id,
+    endpoint: endpointOption.endpoint,
+    iconURL: endpointOption.iconURL,
+    model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
+  });
+
+  // Replace the direct function reference with a wrapper that uses WeakMap
   abortController.getAbortData = function () {
-    return getAbortData();
+    const data = abortDataMap.get(this);
+    if (!data || typeof data.getAbortDataFn !== 'function') {
+      return {};
+    }
+
+    try {
+      const result = data.getAbortDataFn();
+
+      // Create a copy without circular references
+      const cleanResult = { ...result };
+
+      // If userMessagePromise exists, break its reference to client
+      if (
+        cleanResult.userMessagePromise &&
+        typeof cleanResult.userMessagePromise.then === 'function'
+      ) {
+        // Create a new promise that fulfills with the same result but doesn't reference the original
+        const originalPromise = cleanResult.userMessagePromise;
+        cleanResult.userMessagePromise = new Promise((resolve, reject) => {
+          originalPromise.then(
+            (result) => resolve({ ...result }),
+            (error) => reject(error),
+          );
+        });
+      }
+
+      return cleanResult;
+    } catch (err) {
+      logger.error('[abortController.getAbortData] Error:', err);
+      return {};
+    }
   };
 
   /**
@@ -74,6 +177,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
     sendMessage(res, { message: userMessage, created: true });
 
     const abortKey = userMessage?.conversationId ?? req.user.id;
+    getReqData({ abortKey });
     const prevRequest = abortControllers.get(abortKey);
     const { overrideUserMessageId } = req?.body ?? {};
 
@@ -81,34 +185,74 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
       const data = prevRequest.abortController.getAbortData();
       getReqData({ userMessage: data?.userMessage });
       const addedAbortKey = `${abortKey}:${responseMessageId}`;
-      abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
-      res.on('finish', function () {
-        abortControllers.delete(addedAbortKey);
-      });
+
+      // Store minimal options
+      const minimalOptions = {
+        endpoint: endpointOption.endpoint,
+        iconURL: endpointOption.iconURL,
+        model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
+      };
+
+      abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
+
+      // Use a simple function for cleanup to avoid capturing context
+      const cleanupHandler = () => {
+        try {
+          cleanupAbortController(addedAbortKey);
+        } catch (e) {
+          // Ignore cleanup errors
+        }
+      };
+
+      res.on('finish', cleanupHandler);
       return;
     }
 
-    abortControllers.set(abortKey, { abortController, ...endpointOption });
+    // Store minimal options
+    const minimalOptions = {
+      endpoint: endpointOption.endpoint,
+      iconURL: endpointOption.iconURL,
+      model: endpointOption.modelOptions?.model || endpointOption.model_parameters?.model,
+    };
 
-    res.on('finish', function () {
-      abortControllers.delete(abortKey);
-    });
+    abortControllers.set(abortKey, { abortController, ...minimalOptions });
+
+    // Use a simple function for cleanup to avoid capturing context
+    const cleanupHandler = () => {
+      try {
+        cleanupAbortController(abortKey);
+      } catch (e) {
+        // Ignore cleanup errors
+      }
+    };
+
+    res.on('finish', cleanupHandler);
   };
 
+  // Define abortCompletion without capturing the entire parent scope
   abortController.abortCompletion = async function () {
-    abortController.abort();
+    this.abort();
+
+    // Get data from WeakMap
+    const ctrlData = abortDataMap.get(this);
+    if (!ctrlData || !ctrlData.getAbortDataFn) {
+      return { final: true, conversation: {}, title: 'New Chat' };
+    }
+
+    // Get abort data using stored function
     const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
-      getAbortData();
+      ctrlData.getAbortDataFn();
+
     const completionTokens = await countTokens(responseData?.text ?? '');
-    const user = req.user.id;
+    const user = ctrlData.userId;
 
     const responseMessage = {
       ...responseData,
       conversationId,
       finish_reason: 'incomplete',
-      endpoint: endpointOption.endpoint,
-      iconURL: endpointOption.iconURL,
-      model: endpointOption.modelOptions?.model ?? endpointOption.model_parameters?.model,
+      endpoint: ctrlData.endpoint,
+      iconURL: ctrlData.iconURL,
+      model: ctrlData.model,
       unfinished: false,
       error: false,
       isCreatedByUser: false,
@@ -120,7 +264,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
       { promptTokens, completionTokens },
     );
 
-    saveMessage(
+    await saveMessage(
       req,
       { ...responseMessage, user },
       { context: 'api/server/middleware/abortMiddleware.js' },
@@ -130,10 +274,12 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
     if (userMessagePromise) {
       const resolved = await userMessagePromise;
       conversation = resolved?.conversation;
+      // Break reference to promise
+      resolved.conversation = null;
     }
 
     if (!conversation) {
-      conversation = await getConvo(req.user.id, conversationId);
+      conversation = await getConvo(user, conversationId);
     }
 
     return {
@@ -148,6 +294,13 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
   return { abortController, onStart };
 };
 
+/**
+ * @param {ServerResponse} res
+ * @param {ServerRequest} req
+ * @param {Error | unknown} error
+ * @param {Partial<TMessage> & { partialText?: string }} data
+ * @returns {Promise<void>}
+ */
 const handleAbortError = async (res, req, error, data) => {
   if (error?.message?.includes('base64')) {
     logger.error('[handleAbortError] Error in base64 encoding', {
@@ -158,7 +311,7 @@ const handleAbortError = async (res, req, error, data) => {
   } else {
     logger.error('[handleAbortError] AI response error; aborting request:', error);
   }
-  const { sender, conversationId, messageId, parentMessageId, partialText } = data;
+  const { sender, conversationId, messageId, parentMessageId, userMessageId, partialText } = data;
 
   if (error.stack && error.stack.includes('google')) {
     logger.warn(
@@ -178,17 +331,30 @@ const handleAbortError = async (res, req, error, data) => {
     errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`;
   }
 
+  /**
+   * @param {string} partialText
+   * @returns {Promise<void>}
+   */
   const respondWithError = async (partialText) => {
+    const endpointOption = req.body?.endpointOption;
     let options = {
       sender,
      messageId,
       conversationId,
       parentMessageId,
       text: errorText,
-      shouldSaveMessage: true,
       user: req.user.id,
+      spec: endpointOption?.spec,
+      iconURL: endpointOption?.iconURL,
+      modelLabel: endpointOption?.modelLabel,
+      shouldSaveMessage: userMessageId != null,
+      model: endpointOption?.modelOptions?.model || req.body?.model,
     };
 
+    if (req.body?.agent_id) {
+      options.agent_id = req.body.agent_id;
+    }
+
     if (partialText) {
       options = {
         ...options,
@@ -198,11 +364,12 @@ const handleAbortError = async (res, req, error, data) => {
       };
     }
 
+    // Create a simple callback without capturing parent scope
     const callback = async () => {
-      if (abortControllers.has(conversationId)) {
-        const { abortController } = abortControllers.get(conversationId);
-        abortController.abort();
-        abortControllers.delete(conversationId);
+      try {
+        cleanupAbortController(conversationId);
+      } catch (e) {
+        // Ignore cleanup errors
       }
     };
 
@@ -223,6 +390,7 @@ const handleAbortError = async (res, req, error, data) => {
 
 module.exports = {
   handleAbort,
-  createAbortController,
   handleAbortError,
+  createAbortController,
+  cleanupAbortController,
 };
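The abort-middleware changes reduce to one idea: key per-request state by the controller in a WeakMap so nothing long-lived closes over `req`, `res`, or the client. A minimal sketch of that pattern, with hypothetical `registry`/`register` names:

const registry = new Map(); // abortKey -> { abortController }
const dataMap = new WeakMap(); // abortController -> minimal request data

function register(abortKey, getAbortData, meta) {
  const abortController = new AbortController();
  dataMap.set(abortController, { getAbortData, ...meta });
  registry.set(abortKey, { abortController });
  return abortController;
}

function cleanup(abortKey) {
  const entry = registry.get(abortKey);
  if (!entry) {
    return false;
  }
  const { abortController } = entry;
  if (!abortController.signal.aborted) {
    abortController.abort();
  }
  registry.delete(abortKey);
  // No explicit dataMap.delete is strictly required: the WeakMap entry
  // becomes collectable once nothing else references the controller.
  return true;
}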
diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js
index a0ce754a1c..8394223b5e 100644
--- a/api/server/middleware/buildEndpointOption.js
+++ b/api/server/middleware/buildEndpointOption.js
@@ -1,6 +1,11 @@
-const { parseCompactConvo, EModelEndpoint, isAgentsEndpoint } = require('librechat-data-provider');
-const { getModelsConfig } = require('~/server/controllers/ModelController');
+const {
+  parseCompactConvo,
+  EModelEndpoint,
+  isAgentsEndpoint,
+  EndpointURLs,
+} = require('librechat-data-provider');
 const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
+const { getModelsConfig } = require('~/server/controllers/ModelController');
 const assistants = require('~/server/services/Endpoints/assistants');
 const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
 const { processFiles } = require('~/server/services/Files/process');
@@ -10,7 +15,6 @@ const openAI = require('~/server/services/Endpoints/openAI');
 const agents = require('~/server/services/Endpoints/agents');
 const custom = require('~/server/services/Endpoints/custom');
 const google = require('~/server/services/Endpoints/google');
-const { getConvoFiles } = require('~/models/Conversation');
 const { handleError } = require('~/server/utils');
 
 const buildFunction = {
@@ -78,8 +82,9 @@ async function buildEndpointOption(req, res, next) {
   }
 
   try {
-    const isAgents = isAgentsEndpoint(endpoint);
-    const endpointFn = buildFunction[endpointType ?? endpoint];
+    const isAgents =
+      isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
+    const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
     const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;
 
     // TODO: use object params
@@ -87,16 +92,8 @@ async function buildEndpointOption(req, res, next) {
     // TODO: use `getModelsConfig` only when necessary
     const modelsConfig = await getModelsConfig(req);
-    const { resendFiles = true } = req.body.endpointOption;
     req.body.endpointOption.modelsConfig = modelsConfig;
-    if (isAgents && resendFiles && req.body.conversationId) {
-      const fileIds = await getConvoFiles(req.body.conversationId);
-      const requestFiles = req.body.files ?? [];
-      if (requestFiles.length || fileIds.length) {
-        req.body.endpointOption.attachments = processFiles(requestFiles, fileIds);
-      }
-    } else if (req.body.files) {
-      // hold the promise
+    if (req.body.files && !isAgents) {
       req.body.endpointOption.attachments = processFiles(req.body.files);
     }
     next();
diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js
index c397ca7d1a..4e0593192a 100644
--- a/api/server/middleware/checkBan.js
+++ b/api/server/middleware/checkBan.js
@@ -1,4 +1,4 @@
-const Keyv = require('keyv');
+const { Keyv } = require('keyv');
 const uap = require('ua-parser-js');
 const { ViolationTypes } = require('librechat-data-provider');
 const { isEnabled, removePorts } = require('~/server/utils');
@@ -41,7 +41,7 @@ const banResponse = async (req, res) => {
  * @function
  * @param {Object} req - Express request object.
  * @param {Object} res - Express response object.
- * @param {Function} next - Next middleware function.
+ * @param {import('express').NextFunction} next - Next middleware function.
  *
 * @returns {Promise} - Returns a Promise which when resolved calls next middleware if user or source IP is not banned. Otherwise calls `banResponse()` and sets ban details in `banCache`.
 */
diff --git a/api/server/middleware/concurrentLimiter.js b/api/server/middleware/concurrentLimiter.js
index 58ff689a0b..73de65dd25 100644
--- a/api/server/middleware/concurrentLimiter.js
+++ b/api/server/middleware/concurrentLimiter.js
@@ -1,4 +1,4 @@
-const { Time } = require('librechat-data-provider');
+const { Time, CacheKeys } = require('librechat-data-provider');
 const clearPendingReq = require('~/cache/clearPendingReq');
 const { logViolation, getLogStores } = require('~/cache');
 const { isEnabled } = require('~/server/utils');
@@ -21,11 +21,11 @@ const {
  * @function
  * @param {Object} req - Express request object containing user information.
  * @param {Object} res - Express response object.
- * @param {function} next - Express next middleware function.
+ * @param {import('express').NextFunction} next - Next middleware function.
  * @throws {Error} Throws an error if the user exceeds the concurrent request limit.
  */
 const concurrentLimiter = async (req, res, next) => {
-  const namespace = 'pending_req';
+  const namespace = CacheKeys.PENDING_REQ;
   const cache = getLogStores(namespace);
   if (!cache) {
     return next();
diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js
index 3da9e06bd6..6a41d6f157 100644
--- a/api/server/middleware/index.js
+++ b/api/server/middleware/index.js
@@ -8,12 +8,14 @@ const concurrentLimiter = require('./concurrentLimiter');
 const validateEndpoint = require('./validateEndpoint');
 const requireLocalAuth = require('./requireLocalAuth');
 const canDeleteAccount = require('./canDeleteAccount');
+const setBalanceConfig = require('./setBalanceConfig');
 const requireLdapAuth = require('./requireLdapAuth');
 const abortMiddleware = require('./abortMiddleware');
 const checkInviteUser = require('./checkInviteUser');
 const requireJwtAuth = require('./requireJwtAuth');
 const validateModel = require('./validateModel');
 const moderateText = require('./moderateText');
+const logHeaders = require('./logHeaders');
 const setHeaders = require('./setHeaders');
 const validate = require('./validate');
 const limiters = require('./limiters');
@@ -31,6 +33,7 @@ module.exports = {
   checkBan,
   uaParser,
   setHeaders,
+  logHeaders,
   moderateText,
   validateModel,
   requireJwtAuth,
@@ -39,6 +42,7 @@ module.exports = {
   requireLocalAuth,
   canDeleteAccount,
   validateEndpoint,
+  setBalanceConfig,
   concurrentLimiter,
   checkDomainAllowed,
   validateMessageReq,
diff --git a/api/server/middleware/limiters/importLimiters.js b/api/server/middleware/limiters/importLimiters.js
index a21fa6453e..f353f5e996 100644
--- a/api/server/middleware/limiters/importLimiters.js
+++ b/api/server/middleware/limiters/importLimiters.js
@@ -1,6 +1,10 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
+const ioredisClient = require('~/cache/ioredisClient');
 const logViolation = require('~/cache/logViolation');
+const { isEnabled } = require('~/server/utils');
+const { logger } = require('~/config');
 
 const getEnvironmentVariables = () => {
   const IMPORT_IP_MAX = parseInt(process.env.IMPORT_IP_MAX) || 100;
@@ -48,21 +52,37 @@ const createImportLimiters = () => {
   const { importIpWindowMs, importIpMax, importUserWindowMs, importUserMax } =
     getEnvironmentVariables();
 
-  const importIpLimiter = rateLimit({
+  const ipLimiterOptions = {
     windowMs: importIpWindowMs,
     max: importIpMax,
     handler: createImportHandler(),
-  });
-
-  const importUserLimiter = rateLimit({
+  };
+  const userLimiterOptions = {
     windowMs: importUserWindowMs,
     max: importUserMax,
     handler: createImportHandler(false),
     keyGenerator: function (req) {
       return req.user?.id; // Use the user ID or NULL if not available
     },
-  });
+  };
+
+  if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+    logger.debug('Using Redis for import rate limiters.');
+    const sendCommand = (...args) => ioredisClient.call(...args);
+    const ipStore = new RedisStore({
+      sendCommand,
+      prefix: 'import_ip_limiter:',
+    });
+    const userStore = new RedisStore({
+      sendCommand,
+      prefix: 'import_user_limiter:',
+    });
+    ipLimiterOptions.store = ipStore;
+    userLimiterOptions.store = userStore;
+  }
+
+  const importIpLimiter = rateLimit(ipLimiterOptions);
+  const importUserLimiter = rateLimit(userLimiterOptions);
 
   return { importIpLimiter, importUserLimiter };
 };
diff --git a/api/server/middleware/limiters/loginLimiter.js b/api/server/middleware/limiters/loginLimiter.js
index 937723e859..d57af29414 100644
--- a/api/server/middleware/limiters/loginLimiter.js
+++ b/api/server/middleware/limiters/loginLimiter.js
@@ -1,6 +1,9 @@
 const rateLimit = require('express-rate-limit');
-const { removePorts } = require('~/server/utils');
+const { RedisStore } = require('rate-limit-redis');
+const { removePorts, isEnabled } = require('~/server/utils');
+const ioredisClient = require('~/cache/ioredisClient');
 const { logViolation } = require('~/cache');
+const { logger } = require('~/config');
 
 const { LOGIN_WINDOW = 5, LOGIN_MAX = 7, LOGIN_VIOLATION_SCORE: score } = process.env;
 const windowMs = LOGIN_WINDOW * 60 * 1000;
@@ -20,11 +23,22 @@ const handler = async (req, res) => {
   return res.status(429).json({ message });
 };
 
-const loginLimiter = rateLimit({
+const limiterOptions = {
   windowMs,
   max,
   handler,
   keyGenerator: removePorts,
-});
+};
+
+if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+  logger.debug('Using Redis for login rate limiter.');
+  const store = new RedisStore({
+    sendCommand: (...args) => ioredisClient.call(...args),
+    prefix: 'login_limiter:',
+  });
+  limiterOptions.store = store;
+}
+
+const loginLimiter = rateLimit(limiterOptions);
 
 module.exports = loginLimiter;
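The Redis wiring added to each limiter module is identical apart from the key prefix. A hypothetical shared helper (not part of the diff) that would express the pattern once; `isEnabled` and `ioredisClient` are the same utilities imported above:

const { RedisStore } = require('rate-limit-redis');
const ioredisClient = require('~/cache/ioredisClient');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');

function withRedisStore(limiterOptions, prefix, label) {
  if (!isEnabled(process.env.USE_REDIS) || !ioredisClient) {
    return limiterOptions; // fall back to the default in-memory store
  }
  logger.debug(`Using Redis for ${label} rate limiter.`);
  limiterOptions.store = new RedisStore({
    // `call` is ioredis' raw-command method, as used throughout the diff
    sendCommand: (...args) => ioredisClient.call(...args),
    prefix,
  });
  return limiterOptions;
}

// e.g. const loginLimiter = rateLimit(withRedisStore({ windowMs, max, handler }, 'login_limiter:', 'login'));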
diff --git a/api/server/middleware/limiters/messageLimiters.js b/api/server/middleware/limiters/messageLimiters.js
index c84db1043c..4191c9fe7c 100644
--- a/api/server/middleware/limiters/messageLimiters.js
+++ b/api/server/middleware/limiters/messageLimiters.js
@@ -1,6 +1,10 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const denyRequest = require('~/server/middleware/denyRequest');
+const ioredisClient = require('~/cache/ioredisClient');
+const { isEnabled } = require('~/server/utils');
 const { logViolation } = require('~/cache');
+const { logger } = require('~/config');
 
 const {
   MESSAGE_IP_MAX = 40,
@@ -41,25 +45,47 @@ const createHandler = (ip = true) => {
 };
 
 /**
- * Message request rate limiter by IP
+ * Message request rate limiters
 */
-const messageIpLimiter = rateLimit({
+const ipLimiterOptions = {
   windowMs: ipWindowMs,
   max: ipMax,
   handler: createHandler(),
-});
+};
 
-/**
- * Message request rate limiter by userId
- */
-const messageUserLimiter = rateLimit({
+const userLimiterOptions = {
   windowMs: userWindowMs,
   max: userMax,
   handler: createHandler(false),
   keyGenerator: function (req) {
     return req.user?.id; // Use the user ID or NULL if not available
   },
-});
+};
+
+if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+  logger.debug('Using Redis for message rate limiters.');
+  const sendCommand = (...args) => ioredisClient.call(...args);
+  const ipStore = new RedisStore({
+    sendCommand,
+    prefix: 'message_ip_limiter:',
+  });
+  const userStore = new RedisStore({
+    sendCommand,
+    prefix: 'message_user_limiter:',
+  });
+  ipLimiterOptions.store = ipStore;
+  userLimiterOptions.store = userStore;
+}
+
+/**
+ * Message request rate limiter by IP
+ */
+const messageIpLimiter = rateLimit(ipLimiterOptions);
+
+/**
+ * Message request rate limiter by userId
+ */
+const messageUserLimiter = rateLimit(userLimiterOptions);
 
 module.exports = {
   messageIpLimiter,
diff --git a/api/server/middleware/limiters/registerLimiter.js b/api/server/middleware/limiters/registerLimiter.js
index b069798b03..7d38b3044e 100644
--- a/api/server/middleware/limiters/registerLimiter.js
+++ b/api/server/middleware/limiters/registerLimiter.js
@@ -1,6 +1,9 @@
 const rateLimit = require('express-rate-limit');
-const { removePorts } = require('~/server/utils');
+const { RedisStore } = require('rate-limit-redis');
+const { removePorts, isEnabled } = require('~/server/utils');
+const ioredisClient = require('~/cache/ioredisClient');
 const { logViolation } = require('~/cache');
+const { logger } = require('~/config');
 
 const { REGISTER_WINDOW = 60, REGISTER_MAX = 5, REGISTRATION_VIOLATION_SCORE: score } = process.env;
 const windowMs = REGISTER_WINDOW * 60 * 1000;
@@ -20,11 +23,22 @@ const handler = async (req, res) => {
   return res.status(429).json({ message });
 };
 
-const registerLimiter = rateLimit({
+const limiterOptions = {
   windowMs,
   max,
   handler,
   keyGenerator: removePorts,
-});
+};
+
+if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+  logger.debug('Using Redis for register rate limiter.');
+  const store = new RedisStore({
+    sendCommand: (...args) => ioredisClient.call(...args),
+    prefix: 'register_limiter:',
+  });
+  limiterOptions.store = store;
+}
+
+const registerLimiter = rateLimit(limiterOptions);
 
 module.exports = registerLimiter;
diff --git a/api/server/middleware/limiters/resetPasswordLimiter.js b/api/server/middleware/limiters/resetPasswordLimiter.js
index 5d2deb0282..673b23e8e5 100644
--- a/api/server/middleware/limiters/resetPasswordLimiter.js
+++ b/api/server/middleware/limiters/resetPasswordLimiter.js
@@ -1,7 +1,10 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
-const { removePorts } = require('~/server/utils');
+const { removePorts, isEnabled } = require('~/server/utils');
+const ioredisClient = require('~/cache/ioredisClient');
 const { logViolation } = require('~/cache');
+const { logger } = require('~/config');
 
 const {
   RESET_PASSWORD_WINDOW = 2,
@@ -25,11 +28,22 @@ const handler = async (req, res) => {
   return res.status(429).json({ message });
 };
 
-const resetPasswordLimiter = rateLimit({
+const limiterOptions = {
   windowMs,
   max,
   handler,
   keyGenerator: removePorts,
-});
+};
+
+if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+  logger.debug('Using Redis for reset password rate limiter.');
+  const store = new RedisStore({
+    sendCommand: (...args) => ioredisClient.call(...args),
+    prefix: 'reset_password_limiter:',
+  });
+  limiterOptions.store = store;
+}
+
+const resetPasswordLimiter = rateLimit(limiterOptions);
 
 module.exports = resetPasswordLimiter;
diff --git a/api/server/middleware/limiters/sttLimiters.js b/api/server/middleware/limiters/sttLimiters.js
index 76f2944f0a..72ed3af6a3 100644
--- a/api/server/middleware/limiters/sttLimiters.js
+++ b/api/server/middleware/limiters/sttLimiters.js
@@ -1,6 +1,10 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
+const ioredisClient = require('~/cache/ioredisClient');
 const logViolation = require('~/cache/logViolation');
+const { isEnabled } = require('~/server/utils');
+const { logger } = require('~/config');
 
 const getEnvironmentVariables = () => {
   const STT_IP_MAX = parseInt(process.env.STT_IP_MAX) || 100;
@@ -47,20 +51,38 @@ const createSTTHandler = (ip = true) => {
 
 const createSTTLimiters = () => {
   const { sttIpWindowMs, sttIpMax, sttUserWindowMs, sttUserMax } = getEnvironmentVariables();
 
-  const sttIpLimiter = rateLimit({
+  const ipLimiterOptions = {
     windowMs: sttIpWindowMs,
     max: sttIpMax,
     handler: createSTTHandler(),
-  });
+  };
 
-  const sttUserLimiter = rateLimit({
+  const userLimiterOptions = {
     windowMs: sttUserWindowMs,
     max: sttUserMax,
     handler: createSTTHandler(false),
     keyGenerator: function (req) {
       return req.user?.id; // Use the user ID or NULL if not available
     },
-  });
+  };
+
+  if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+    logger.debug('Using Redis for STT rate limiters.');
+    const sendCommand = (...args) => ioredisClient.call(...args);
+    const ipStore = new RedisStore({
+      sendCommand,
+      prefix: 'stt_ip_limiter:',
+    });
+    const userStore = new RedisStore({
+      sendCommand,
+      prefix: 'stt_user_limiter:',
+    });
+    ipLimiterOptions.store = ipStore;
+    userLimiterOptions.store = userStore;
+  }
+
+  const sttIpLimiter = rateLimit(ipLimiterOptions);
+  const sttUserLimiter = rateLimit(userLimiterOptions);
 
   return { sttIpLimiter, sttUserLimiter };
 };
diff --git a/api/server/middleware/limiters/toolCallLimiter.js b/api/server/middleware/limiters/toolCallLimiter.js
index 47dcaeabb4..482744a3e9 100644
--- a/api/server/middleware/limiters/toolCallLimiter.js
+++ b/api/server/middleware/limiters/toolCallLimiter.js
@@ -1,25 +1,42 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
+const ioredisClient = require('~/cache/ioredisClient');
 const logViolation = require('~/cache/logViolation');
+const { isEnabled } = require('~/server/utils');
+const { logger } = require('~/config');
 
-const toolCallLimiter = rateLimit({
+const handler = async (req, res) => {
+  const type = ViolationTypes.TOOL_CALL_LIMIT;
+  const errorMessage = {
+    type,
+    max: 1,
+    limiter: 'user',
+    windowInMinutes: 1,
+  };
+
+  await logViolation(req, res, type, errorMessage, 0);
+  res.status(429).json({ message: 'Too many tool call requests. Try again later' });
+};
+
+const limiterOptions = {
   windowMs: 1000,
   max: 1,
-  handler: async (req, res) => {
-    const type = ViolationTypes.TOOL_CALL_LIMIT;
-    const errorMessage = {
-      type,
-      max: 1,
-      limiter: 'user',
-      windowInMinutes: 1,
-    };
-
-    await logViolation(req, res, type, errorMessage, 0);
-    res.status(429).json({ message: 'Too many tool call requests. Try again later' });
-  },
+  handler,
   keyGenerator: function (req) {
     return req.user?.id;
   },
-});
+};
+
+if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+  logger.debug('Using Redis for tool call rate limiter.');
+  const store = new RedisStore({
+    sendCommand: (...args) => ioredisClient.call(...args),
+    prefix: 'tool_call_limiter:',
+  });
+  limiterOptions.store = store;
+}
+
+const toolCallLimiter = rateLimit(limiterOptions);
 
 module.exports = toolCallLimiter;
diff --git a/api/server/middleware/limiters/ttsLimiters.js b/api/server/middleware/limiters/ttsLimiters.js
index 5619a49b63..9054a6beb1 100644
--- a/api/server/middleware/limiters/ttsLimiters.js
+++ b/api/server/middleware/limiters/ttsLimiters.js
@@ -1,6 +1,10 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
+const ioredisClient = require('~/cache/ioredisClient');
 const logViolation = require('~/cache/logViolation');
+const { isEnabled } = require('~/server/utils');
+const { logger } = require('~/config');
 
 const getEnvironmentVariables = () => {
   const TTS_IP_MAX = parseInt(process.env.TTS_IP_MAX) || 100;
@@ -47,20 +51,38 @@ const createTTSHandler = (ip = true) => {
 
 const createTTSLimiters = () => {
   const { ttsIpWindowMs, ttsIpMax, ttsUserWindowMs, ttsUserMax } = getEnvironmentVariables();
 
-  const ttsIpLimiter = rateLimit({
+  const ipLimiterOptions = {
     windowMs: ttsIpWindowMs,
     max: ttsIpMax,
     handler: createTTSHandler(),
-  });
+  };
 
-  const ttsUserLimiter = rateLimit({
+  const userLimiterOptions = {
     windowMs: ttsUserWindowMs,
     max: ttsUserMax,
     handler: createTTSHandler(false),
     keyGenerator: function (req) {
       return req.user?.id; // Use the user ID or NULL if not available
     },
-  });
+  };
+
+  if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+    logger.debug('Using Redis for TTS rate limiters.');
+    const sendCommand = (...args) => ioredisClient.call(...args);
+    const ipStore = new RedisStore({
+      sendCommand,
+      prefix: 'tts_ip_limiter:',
+    });
+    const userStore = new RedisStore({
+      sendCommand,
+      prefix: 'tts_user_limiter:',
+    });
+    ipLimiterOptions.store = ipStore;
+    userLimiterOptions.store = userStore;
+  }
+
+  const ttsIpLimiter = rateLimit(ipLimiterOptions);
+  const ttsUserLimiter = rateLimit(userLimiterOptions);
 
   return { ttsIpLimiter, ttsUserLimiter };
 };
diff --git a/api/server/middleware/limiters/uploadLimiters.js b/api/server/middleware/limiters/uploadLimiters.js
index 71af164fde..d9049f898e 100644
--- a/api/server/middleware/limiters/uploadLimiters.js
+++ b/api/server/middleware/limiters/uploadLimiters.js
@@ -1,6 +1,10 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
+const ioredisClient = require('~/cache/ioredisClient');
 const logViolation = require('~/cache/logViolation');
+const { isEnabled } = require('~/server/utils');
+const { logger } = require('~/config');
 
 const getEnvironmentVariables = () => {
   const FILE_UPLOAD_IP_MAX = parseInt(process.env.FILE_UPLOAD_IP_MAX) || 100;
@@ -52,20 +56,38 @@ const createFileLimiters = () => {
   const { fileUploadIpWindowMs, fileUploadIpMax, fileUploadUserWindowMs, fileUploadUserMax } =
     getEnvironmentVariables();
 
-  const fileUploadIpLimiter = rateLimit({
+  const ipLimiterOptions = {
     windowMs: fileUploadIpWindowMs,
     max: fileUploadIpMax,
     handler: createFileUploadHandler(),
-  });
+  };
 
-  const fileUploadUserLimiter = rateLimit({
+  const userLimiterOptions = {
     windowMs: fileUploadUserWindowMs,
     max: fileUploadUserMax,
     handler: createFileUploadHandler(false),
     keyGenerator: function (req) {
       return req.user?.id; // Use the user ID or NULL if not available
     },
-  });
+  };
+
+  if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+    logger.debug('Using Redis for file upload rate limiters.');
+    const sendCommand = (...args) => ioredisClient.call(...args);
+    const ipStore = new RedisStore({
+      sendCommand,
+      prefix: 'file_upload_ip_limiter:',
+    });
+    const userStore = new RedisStore({
+      sendCommand,
+      prefix: 'file_upload_user_limiter:',
+    });
+    ipLimiterOptions.store = ipStore;
+    userLimiterOptions.store = userStore;
+  }
+
+  const fileUploadIpLimiter = rateLimit(ipLimiterOptions);
+  const fileUploadUserLimiter = rateLimit(userLimiterOptions);
 
   return { fileUploadIpLimiter, fileUploadUserLimiter };
 };
diff --git a/api/server/middleware/limiters/verifyEmailLimiter.js b/api/server/middleware/limiters/verifyEmailLimiter.js
index 770090dba5..73bfa2daf3 100644
--- a/api/server/middleware/limiters/verifyEmailLimiter.js
+++ b/api/server/middleware/limiters/verifyEmailLimiter.js
@@ -1,7 +1,10 @@
 const rateLimit = require('express-rate-limit');
+const { RedisStore } = require('rate-limit-redis');
 const { ViolationTypes } = require('librechat-data-provider');
-const { removePorts } = require('~/server/utils');
+const { removePorts, isEnabled } = require('~/server/utils');
+const ioredisClient = require('~/cache/ioredisClient');
 const { logViolation } = require('~/cache');
+const { logger } = require('~/config');
 
 const {
   VERIFY_EMAIL_WINDOW = 2,
@@ -25,11 +28,22 @@ const handler = async (req, res) => {
   return res.status(429).json({ message });
 };
 
-const verifyEmailLimiter = rateLimit({
+const limiterOptions = {
   windowMs,
   max,
   handler,
   keyGenerator: removePorts,
-});
+};
+
+if (isEnabled(process.env.USE_REDIS) && ioredisClient) {
+  logger.debug('Using Redis for verify email rate limiter.');
+  const store = new RedisStore({
+    sendCommand: (...args) => ioredisClient.call(...args),
+    prefix: 'verify_email_limiter:',
+  });
+  limiterOptions.store = store;
+}
+
+const verifyEmailLimiter = rateLimit(limiterOptions);
 
 module.exports = verifyEmailLimiter;
diff --git a/api/server/middleware/logHeaders.js b/api/server/middleware/logHeaders.js
new file mode 100644
index 0000000000..26ca04da38
--- /dev/null
+++ b/api/server/middleware/logHeaders.js
@@ -0,0 +1,32 @@
+const { logger } = require('~/config');
+
+/**
+ * Middleware to log X-Forwarded headers on incoming requests.
+ * @function
+ * @param {ServerRequest} req - Express request object.
+ * @param {ServerResponse} res - Express response object.
+ * @param {import('express').NextFunction} next - Next middleware function.
+ */
+const logHeaders = (req, res, next) => {
+  try {
+    const forwardedHeaders = {};
+    if (req.headers['x-forwarded-for']) {
+      forwardedHeaders['x-forwarded-for'] = req.headers['x-forwarded-for'];
+    }
+    if (req.headers['x-forwarded-host']) {
+      forwardedHeaders['x-forwarded-host'] = req.headers['x-forwarded-host'];
+    }
+    if (req.headers['x-forwarded-proto']) {
+      forwardedHeaders['x-forwarded-proto'] = req.headers['x-forwarded-proto'];
+    }
+    if (Object.keys(forwardedHeaders).length > 0) {
+      logger.debug('X-Forwarded headers detected in OAuth request:', forwardedHeaders);
+    }
+  } catch (error) {
+    logger.error('Error logging X-Forwarded headers:', error);
+  }
+  next();
+};
+
+module.exports = logHeaders;
diff --git a/api/server/middleware/moderateText.js b/api/server/middleware/moderateText.js
index 18d370b560..ff1a9de856 100644
--- a/api/server/middleware/moderateText.js
+++ b/api/server/middleware/moderateText.js
@@ -1,39 +1,41 @@
 const axios = require('axios');
 const { ErrorTypes } = require('librechat-data-provider');
+const { isEnabled } = require('~/server/utils');
 const denyRequest = require('./denyRequest');
 const { logger } = require('~/config');
 
 async function moderateText(req, res, next) {
-  if (process.env.OPENAI_MODERATION === 'true') {
-    try {
-      const { text } = req.body;
+  if (!isEnabled(process.env.OPENAI_MODERATION)) {
+    return next();
+  }
+  try {
+    const { text } = req.body;
 
-      const response = await axios.post(
-        process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
-        {
-          input: text,
+    const response = await axios.post(
+      process.env.OPENAI_MODERATION_REVERSE_PROXY || 'https://api.openai.com/v1/moderations',
+      {
+        input: text,
+      },
+      {
+        headers: {
+          'Content-Type': 'application/json',
+          Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
         },
-        {
-          headers: {
-            'Content-Type': 'application/json',
-            Authorization: `Bearer ${process.env.OPENAI_MODERATION_API_KEY}`,
-          },
-        },
-      );
+      },
+    );
 
-      const results = response.data.results;
-      const flagged = results.some((result) => result.flagged);
+    const results = response.data.results;
+    const flagged = results.some((result) => result.flagged);
 
-      if (flagged) {
-        const type = ErrorTypes.MODERATION;
-        const errorMessage = { type };
-        return await denyRequest(req, res, errorMessage);
-      }
-    } catch (error) {
-      logger.error('Error in moderateText:', error);
-      const errorMessage = 'error in moderation check';
+    if (flagged) {
+      const type = ErrorTypes.MODERATION;
+      const errorMessage = { type };
       return await denyRequest(req, res, errorMessage);
     }
+  } catch (error) {
+    logger.error('Error in moderateText:', error);
+    const errorMessage = 'error in moderation check';
+    return await denyRequest(req, res, errorMessage);
   }
 
   next();
 }
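The moderation guard now uses `isEnabled` rather than a strict `=== 'true'` comparison. A stand-in showing the assumed semantics (case-insensitive, trimmed 'true'); the real `isEnabled` lives in `~/server/utils`:

function isEnabledSketch(value) {
  if (!value) {
    return false;
  }
  return value.toString().trim().toLowerCase() === 'true';
}

console.log(isEnabledSketch('TRUE')); // true  -- previously fell through the exact-match check
console.log(isEnabledSketch(' true ')); // true
console.log(isEnabledSketch(undefined)); // false -- moderation off, request passes straight to next()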
diff --git a/api/server/middleware/requireLocalAuth.js b/api/server/middleware/requireLocalAuth.js
index 8319baf345..a71bd6c5b0 100644
--- a/api/server/middleware/requireLocalAuth.js
+++ b/api/server/middleware/requireLocalAuth.js
@@ -1,32 +1,18 @@
 const passport = require('passport');
-const DebugControl = require('../../utils/debug.js');
-
-function log({ title, parameters }) {
-  DebugControl.log.functionName(title);
-  if (parameters) {
-    DebugControl.log.parameters(parameters);
-  }
-}
+const { logger } = require('~/config');
 
 const requireLocalAuth = (req, res, next) => {
   passport.authenticate('local', (err, user, info) => {
     if (err) {
-      log({
-        title: '(requireLocalAuth) Error at passport.authenticate',
-        parameters: [{ name: 'error', value: err }],
-      });
+      logger.error('[requireLocalAuth] Error at passport.authenticate:', err);
       return next(err);
     }
     if (!user) {
-      log({
-        title: '(requireLocalAuth) Error: No user',
-      });
+      logger.debug('[requireLocalAuth] Error: No user');
       return res.status(404).send(info);
     }
     if (info && info.message) {
-      log({
-        title: '(requireLocalAuth) Error: ' + info.message,
-      });
+      logger.debug('[requireLocalAuth] Error: ' + info.message);
       return res.status(422).send({ message: info.message });
     }
     req.user = user;
diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/generateCheckAccess.js
index ffc0ddc613..cabbd405b0 100644
--- a/api/server/middleware/roles/generateCheckAccess.js
+++ b/api/server/middleware/roles/generateCheckAccess.js
@@ -1,4 +1,42 @@
 const { getRoleByName } = require('~/models/Role');
+const { logger } = require('~/config');
+
+/**
+ * Core function to check if a user has one or more required permissions
+ *
+ * @param {object} user - The user object
+ * @param {PermissionTypes} permissionType - The type of permission to check
+ * @param {Permissions[]} permissions - The list of specific permissions to check
+ * @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
+ * @param {object} [checkObject] - The object to check properties against
+ * @returns {Promise<boolean>} Whether the user has the required permissions
+ */
+const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
+  if (!user) {
+    return false;
+  }
+
+  const role = await getRoleByName(user.role);
+  if (role && role.permissions && role.permissions[permissionType]) {
+    const hasAnyPermission = permissions.some((permission) => {
+      if (role.permissions[permissionType][permission]) {
+        return true;
+      }
+
+      if (bodyProps[permission] && checkObject) {
+        return bodyProps[permission].some((prop) =>
+          Object.prototype.hasOwnProperty.call(checkObject, prop),
+        );
+      }
+
+      return false;
+    });
+
+    return hasAnyPermission;
+  }
+
+  return false;
+};
 
 /**
  * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
@@ -6,42 +44,35 @@ const { getRoleByName } = require('~/models/Role');
  * @param {PermissionTypes} permissionType - The type of permission to check.
  * @param {Permissions[]} permissions - The list of specific permissions to check.
  * @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
- * @returns {Function} Express middleware function.
+ * @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
 */
 const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
   return async (req, res, next) => {
     try {
-      const { user } = req;
-      if (!user) {
-        return res.status(401).json({ message: 'Authorization required' });
-      }
-
-      const role = await getRoleByName(user.role);
-      if (role && role[permissionType]) {
-        const hasAnyPermission = permissions.some((permission) => {
-          if (role[permissionType][permission]) {
-            return true;
-          }
-
-          if (bodyProps[permission] && req.body) {
-            return bodyProps[permission].some((prop) =>
-              Object.prototype.hasOwnProperty.call(req.body, prop),
-            );
-          }
-
-          return false;
-        });
-
-        if (hasAnyPermission) {
-          return next();
-        }
+      const hasAccess = await checkAccess(
+        req.user,
+        permissionType,
+        permissions,
+        bodyProps,
+        req.body,
+      );
+
+      if (hasAccess) {
+        return next();
       }
 
+      logger.warn(
+        `[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
+      );
       return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
     } catch (error) {
+      logger.error(error);
       return res.status(500).json({ message: `Server error: ${error.message}` });
     }
   };
 };
 
-module.exports = generateCheckAccess;
+module.exports = {
+  checkAccess,
+  generateCheckAccess,
+};
diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js
index 999c36481e..a9fc5b2a08 100644
--- a/api/server/middleware/roles/index.js
+++ b/api/server/middleware/roles/index.js
@@ -1,7 +1,8 @@
 const checkAdmin = require('./checkAdmin');
-const generateCheckAccess = require('./generateCheckAccess');
+const { checkAccess, generateCheckAccess } = require('./generateCheckAccess');
 
 module.exports = {
   checkAdmin,
+  checkAccess,
   generateCheckAccess,
 };
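Because `checkAccess` is now exported alongside the middleware factory, handlers can run the same role check imperatively, as `callTool` does above. A hedged usage sketch (the helper name `assertCanRunCode` is hypothetical):

const { checkAccess } = require('~/server/middleware');
const { PermissionTypes, Permissions } = require('librechat-data-provider');

async function assertCanRunCode(user) {
  const hasAccess = await checkAccess(user, PermissionTypes.RUN_CODE, [Permissions.USE]);
  if (!hasAccess) {
    const err = new Error('Forbidden: Insufficient permissions');
    err.status = 403;
    throw err;
  }
}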
+ */
+const setBalanceConfig = async (req, res, next) => {
+  try {
+    const balanceConfig = await getBalanceConfig();
+    if (!balanceConfig?.enabled) {
+      return next();
+    }
+    if (balanceConfig.startBalance == null) {
+      return next();
+    }
+
+    const userId = req.user._id;
+    const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
+    const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
+
+    if (Object.keys(updateFields).length === 0) {
+      return next();
+    }
+
+    await Balance.findOneAndUpdate(
+      { user: userId },
+      { $set: updateFields },
+      { upsert: true, new: true },
+    );
+
+    next();
+  } catch (error) {
+    logger.error('Error setting user balance:', error);
+    next(error);
+  }
+};
+
+/**
+ * Build an object containing fields that need updating
+ * @param {Object} config - The balance configuration
+ * @param {Object|null} userRecord - The user's current balance record, if any
+ * @returns {Object} Fields that need updating
+ */
+function buildUpdateFields(config, userRecord) {
+  const updateFields = {};
+
+  // No record yet: seed the starting balance (the findOneAndUpdate upsert filter supplies the user id)
+  if (!userRecord) {
+    updateFields.tokenCredits = config.startBalance;
+  }
+
+  if (userRecord?.tokenCredits == null && config.startBalance != null) {
+    updateFields.tokenCredits = config.startBalance;
+  }
+
+  const isAutoRefillConfigValid =
+    config.autoRefillEnabled &&
+    config.refillIntervalValue != null &&
+    config.refillIntervalUnit != null &&
+    config.refillAmount != null;
+
+  if (!isAutoRefillConfigValid) {
+    return updateFields;
+  }
+
+  if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
+    updateFields.autoRefillEnabled = config.autoRefillEnabled;
+  }
+
+  if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
+    updateFields.refillIntervalValue = config.refillIntervalValue;
+  }
+
+  if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
+    updateFields.refillIntervalUnit = config.refillIntervalUnit;
+  }
+
+  if (userRecord?.refillAmount !== config.refillAmount) {
+    updateFields.refillAmount = config.refillAmount;
+  }
+
+  return updateFields;
+}
+
+module.exports = setBalanceConfig;
diff --git a/api/server/routes/__tests__/config.spec.js b/api/server/routes/__tests__/config.spec.js
index 13af53f299..0bb80bb9ee 100644
--- a/api/server/routes/__tests__/config.spec.js
+++ b/api/server/routes/__tests__/config.spec.js
@@ -18,6 +18,7 @@ afterEach(() => {
   delete process.env.OPENID_ISSUER;
   delete process.env.OPENID_SESSION_SECRET;
   delete process.env.OPENID_BUTTON_LABEL;
+  delete process.env.OPENID_AUTO_REDIRECT;
   delete process.env.OPENID_AUTH_URL;
   delete process.env.GITHUB_CLIENT_ID;
   delete process.env.GITHUB_CLIENT_SECRET;
diff --git a/api/server/routes/actions.js b/api/server/routes/actions.js
index 454f4be6c7..dc474d1a67 100644
--- a/api/server/routes/actions.js
+++ b/api/server/routes/actions.js
@@ -1,5 +1,6 @@
 const express = require('express');
 const jwt = require('jsonwebtoken');
+const { CacheKeys } = require('librechat-data-provider');
 const { getAccessToken } = require('~/server/services/TokenService');
 const { logger, getFlowStateManager } = require('~/config');
 const { getLogStores } = require('~/cache');
@@ -19,8 +20,8 @@ const JWT_SECRET = process.env.JWT_SECRET;
 router.get('/:action_id/oauth/callback', async (req, res) => {
   const { action_id } = req.params;
   const { code, state } = req.query;
-
-  const flowManager = await getFlowStateManager(getLogStores);
+  const flowsCache = getLogStores(CacheKeys.FLOWS);
+  const 
flowManager = getFlowStateManager(flowsCache); let identifier = action_id; try { let decodedState; diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index 786f44dd8e..5413bc1d68 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -58,7 +58,7 @@ router.post('/:agent_id', async (req, res) => { } let { domain } = metadata; - domain = await domainParser(req, domain, true); + domain = await domainParser(domain, true); if (!domain) { return res.status(400).json({ message: 'No domain provided' }); @@ -164,7 +164,7 @@ router.delete('/:agent_id/:action_id', async (req, res) => { return true; }); - domain = await domainParser(req, domain, true); + domain = await domainParser(domain, true); if (!domain) { return res.status(400).json({ message: 'No domain provided' }); diff --git a/api/server/routes/agents/chat.js b/api/server/routes/agents/chat.js index fdb2db54d3..ef66ef7896 100644 --- a/api/server/routes/agents/chat.js +++ b/api/server/routes/agents/chat.js @@ -2,7 +2,7 @@ const express = require('express'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { setHeaders, - handleAbort, + moderateText, // validateModel, generateCheckAccess, validateConvoAccess, @@ -14,28 +14,37 @@ const addTitle = require('~/server/services/Endpoints/agents/title'); const router = express.Router(); -router.post('/abort', handleAbort()); +router.use(moderateText); const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); +router.use(checkAgentAccess); +router.use(validateConvoAccess); +router.use(buildEndpointOption); +router.use(setHeaders); + +const controller = async (req, res, next) => { + await AgentController(req, res, next, initializeClient, addTitle); +}; + /** - * @route POST / + * @route POST / (regular endpoint) * @desc Chat with an assistant * @access Public * @param {express.Request} req - The request object, containing the request data. * @param {express.Response} res - The response object, used to send back a response. * @returns {void} */ -router.post( - '/', - // validateModel, - checkAgentAccess, - validateConvoAccess, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AgentController(req, res, next, initializeClient, addTitle); - }, -); +router.post('/', controller); + +/** + * @route POST /:endpoint (ephemeral agents) + * @desc Chat with an assistant + * @access Public + * @param {express.Request} req - The request object, containing the request data. + * @param {express.Response} res - The response object, used to send back a response. + * @returns {void} + */ +router.post('/:endpoint', controller); module.exports = router; diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index d7ef93af73..1834d2e2bc 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -1,21 +1,40 @@ const express = require('express'); -const router = express.Router(); const { uaParser, checkBan, requireJwtAuth, - // concurrentLimiter, - // messageIpLimiter, - // messageUserLimiter, + messageIpLimiter, + concurrentLimiter, + messageUserLimiter, } = require('~/server/middleware'); - +const { isEnabled } = require('~/server/utils'); const { v1 } = require('./v1'); const chat = require('./chat'); +const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? 
{}; + +const router = express.Router(); + router.use(requireJwtAuth); router.use(checkBan); router.use(uaParser); + router.use('/', v1); -router.use('/chat', chat); + +const chatRouter = express.Router(); +if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { + chatRouter.use(concurrentLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_IP)) { + chatRouter.use(messageIpLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_USER)) { + chatRouter.use(messageUserLimiter); +} + +chatRouter.use('/', chat); +router.use('/chat', chatRouter); module.exports = router; diff --git a/api/server/routes/ask/addToCache.js b/api/server/routes/ask/addToCache.js index 6e21edd2b8..a2f427098f 100644 --- a/api/server/routes/ask/addToCache.js +++ b/api/server/routes/ask/addToCache.js @@ -1,4 +1,4 @@ -const Keyv = require('keyv'); +const { Keyv } = require('keyv'); const { KeyvFile } = require('keyv-file'); const { logger } = require('~/config'); diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js index a08d1d2570..afe1720d84 100644 --- a/api/server/routes/ask/anthropic.js +++ b/api/server/routes/ask/anthropic.js @@ -11,8 +11,6 @@ const { const router = express.Router(); -router.post('/abort', handleAbort()); - router.post( '/', validateEndpoint, diff --git a/api/server/routes/ask/custom.js b/api/server/routes/ask/custom.js index 668a9902cb..8fc343cf17 100644 --- a/api/server/routes/ask/custom.js +++ b/api/server/routes/ask/custom.js @@ -3,7 +3,6 @@ const AskController = require('~/server/controllers/AskController'); const { initializeClient } = require('~/server/services/Endpoints/custom'); const { addTitle } = require('~/server/services/Endpoints/openAI'); const { - handleAbort, setHeaders, validateModel, validateEndpoint, @@ -12,8 +11,6 @@ const { const router = express.Router(); -router.post('/abort', handleAbort()); - router.post( '/', validateEndpoint, diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js index 2b3378bf6c..16c7e265f4 100644 --- a/api/server/routes/ask/google.js +++ b/api/server/routes/ask/google.js @@ -3,7 +3,6 @@ const AskController = require('~/server/controllers/AskController'); const { initializeClient, addTitle } = require('~/server/services/Endpoints/google'); const { setHeaders, - handleAbort, validateModel, validateEndpoint, buildEndpointOption, @@ -11,8 +10,6 @@ const { const router = express.Router(); -router.post('/abort', handleAbort()); - router.post( '/', validateEndpoint, diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 036654f845..a40022848a 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -20,7 +20,6 @@ const { logger } = require('~/config'); const router = express.Router(); router.use(moderateText); -router.post('/abort', handleAbort()); router.post( '/', @@ -196,7 +195,8 @@ router.post( logger.debug('[/ask/gptPlugins]', response); - const { conversation = {} } = await client.responsePromise; + const { conversation = {} } = await response.databasePromise; + delete response.databasePromise; conversation.title = conversation && !conversation.title ? 
null : conversation?.title || 'New Chat'; diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js index bd5666153f..525bd8e29d 100644 --- a/api/server/routes/ask/index.js +++ b/api/server/routes/ask/index.js @@ -1,10 +1,4 @@ const express = require('express'); -const openAI = require('./openAI'); -const custom = require('./custom'); -const google = require('./google'); -const anthropic = require('./anthropic'); -const gptPlugins = require('./gptPlugins'); -const { isEnabled } = require('~/server/utils'); const { EModelEndpoint } = require('librechat-data-provider'); const { uaParser, @@ -15,6 +9,12 @@ const { messageUserLimiter, validateConvoAccess, } = require('~/server/middleware'); +const { isEnabled } = require('~/server/utils'); +const gptPlugins = require('./gptPlugins'); +const anthropic = require('./anthropic'); +const custom = require('./custom'); +const google = require('./google'); +const openAI = require('./openAI'); const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {}; diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js index 5083a08b10..dadf00def4 100644 --- a/api/server/routes/ask/openAI.js +++ b/api/server/routes/ask/openAI.js @@ -12,7 +12,6 @@ const { const router = express.Router(); router.use(moderateText); -router.post('/abort', handleAbort()); router.post( '/', diff --git a/api/server/routes/assistants/actions.js b/api/server/routes/assistants/actions.js index 9f4db5d6b8..3dc3923503 100644 --- a/api/server/routes/assistants/actions.js +++ b/api/server/routes/assistants/actions.js @@ -36,7 +36,7 @@ router.post('/:assistant_id', async (req, res) => { } let { domain } = metadata; - domain = await domainParser(req, domain, true); + domain = await domainParser(domain, true); if (!domain) { return res.status(400).json({ message: 'No domain provided' }); @@ -172,7 +172,7 @@ router.delete('/:assistant_id/:action_id/:model', async (req, res) => { return true; }); - domain = await domainParser(req, domain, true); + domain = await domainParser(domain, true); if (!domain) { return res.status(400).json({ message: 'No domain provided' }); diff --git a/api/server/routes/auth.js b/api/server/routes/auth.js index 3e86ffd868..187d908abd 100644 --- a/api/server/routes/auth.js +++ b/api/server/routes/auth.js @@ -7,13 +7,23 @@ const { } = require('~/server/controllers/AuthController'); const { loginController } = require('~/server/controllers/auth/LoginController'); const { logoutController } = require('~/server/controllers/auth/LogoutController'); +const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController'); +const { + enable2FA, + verify2FA, + disable2FA, + regenerateBackupCodes, + confirm2FA, +} = require('~/server/controllers/TwoFactorController'); const { checkBan, + logHeaders, loginLimiter, requireJwtAuth, checkInviteUser, registerLimiter, requireLdapAuth, + setBalanceConfig, requireLocalAuth, resetPasswordLimiter, validateRegistration, @@ -27,9 +37,11 @@ const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE; router.post('/logout', requireJwtAuth, logoutController); router.post( '/login', + logHeaders, loginLimiter, checkBan, ldapAuth ? 
requireLdapAuth : requireLocalAuth, + setBalanceConfig, loginController, ); router.post('/refresh', refreshController); @@ -50,4 +62,11 @@ router.post( ); router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController); +router.get('/2fa/enable', requireJwtAuth, enable2FA); +router.post('/2fa/verify', requireJwtAuth, verify2FA); +router.post('/2fa/verify-temp', checkBan, verify2FAWithTempToken); +router.post('/2fa/confirm', requireJwtAuth, confirm2FA); +router.post('/2fa/disable', requireJwtAuth, disable2FA); +router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodes); + module.exports = router; diff --git a/api/server/routes/bedrock/chat.js b/api/server/routes/bedrock/chat.js index c8d6be35de..263ca96002 100644 --- a/api/server/routes/bedrock/chat.js +++ b/api/server/routes/bedrock/chat.js @@ -4,6 +4,7 @@ const router = express.Router(); const { setHeaders, handleAbort, + moderateText, // validateModel, // validateEndpoint, buildEndpointOption, @@ -12,7 +13,7 @@ const { initializeClient } = require('~/server/services/Endpoints/bedrock'); const AgentController = require('~/server/controllers/agents/request'); const addTitle = require('~/server/services/Endpoints/agents/title'); -router.post('/abort', handleAbort()); +router.use(moderateText); /** * @route POST / diff --git a/api/server/routes/bedrock/index.js b/api/server/routes/bedrock/index.js index b1a9efec4c..ce440a7c0e 100644 --- a/api/server/routes/bedrock/index.js +++ b/api/server/routes/bedrock/index.js @@ -1,19 +1,35 @@ const express = require('express'); -const router = express.Router(); const { uaParser, checkBan, requireJwtAuth, - // concurrentLimiter, - // messageIpLimiter, - // messageUserLimiter, + messageIpLimiter, + concurrentLimiter, + messageUserLimiter, } = require('~/server/middleware'); - +const { isEnabled } = require('~/server/utils'); const chat = require('./chat'); +const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? 
{}; + +const router = express.Router(); + router.use(requireJwtAuth); router.use(checkBan); router.use(uaParser); + +if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { + router.use(concurrentLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_IP)) { + router.use(messageIpLimiter); +} + +if (isEnabled(LIMIT_MESSAGE_USER)) { + router.use(messageUserLimiter); +} + router.use('/chat', chat); module.exports = router; diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 705a1d3cb1..ebafb05c30 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -47,10 +47,10 @@ router.get('/', async function (req, res) { githubLoginEnabled: !!process.env.GITHUB_CLIENT_ID && !!process.env.GITHUB_CLIENT_SECRET, googleLoginEnabled: !!process.env.GOOGLE_CLIENT_ID && !!process.env.GOOGLE_CLIENT_SECRET, appleLoginEnabled: - !!process.env.APPLE_CLIENT_ID && - !!process.env.APPLE_TEAM_ID && - !!process.env.APPLE_KEY_ID && - !!process.env.APPLE_PRIVATE_KEY_PATH, + !!process.env.APPLE_CLIENT_ID && + !!process.env.APPLE_TEAM_ID && + !!process.env.APPLE_KEY_ID && + !!process.env.APPLE_PRIVATE_KEY_PATH, openidLoginEnabled: !!process.env.OPENID_CLIENT_ID && !!process.env.OPENID_CLIENT_SECRET && @@ -58,6 +58,7 @@ router.get('/', async function (req, res) { !!process.env.OPENID_SESSION_SECRET, openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID', openidImageUrl: process.env.OPENID_IMAGE_URL, + openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT), serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080', emailLoginEnabled, registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION), @@ -68,7 +69,6 @@ router.get('/', async function (req, res) { !!process.env.EMAIL_PASSWORD && !!process.env.EMAIL_FROM, passwordResetEnabled, - checkBalance: isEnabled(process.env.CHECK_BALANCE), showBirthdayIcon: isBirthday() || isEnabled(process.env.SHOW_BIRTHDAY_ICON) || @@ -76,10 +76,13 @@ router.get('/', async function (req, res) { helpAndFaqURL: process.env.HELP_AND_FAQ_URL || 'https://librechat.ai', interface: req.app.locals.interfaceConfig, modelSpecs: req.app.locals.modelSpecs, + balance: req.app.locals.balance, sharedLinksEnabled, publicSharedLinksEnabled, analyticsGtmId: process.env.ANALYTICS_GTM_ID, instanceProjectId: instanceProject._id.toString(), + bundlerURL: process.env.SANDPACK_BUNDLER_URL, + staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL, }; if (ldap) { diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index a4d81e24e6..2473eb68f9 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -1,16 +1,17 @@ const multer = require('multer'); const express = require('express'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); -const { getConvosByPage, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation'); +const { getConvosByCursor, deleteConvos, getConvo, saveConvo } = require('~/models/Conversation'); const { forkConversation, duplicateConversation } = require('~/server/utils/import/fork'); const { storage, importFileFilter } = require('~/server/routes/files/multer'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); const { importConversations } = require('~/server/utils/import'); const { createImportLimiters } = require('~/server/middleware'); const { deleteToolCalls } = require('~/models/ToolCall'); +const { isEnabled, sleep } = require('~/server/utils'); const getLogStores = require('~/cache/getLogStores'); -const { sleep } = 
require('~/server/utils'); const { logger } = require('~/config'); + const assistantClients = { [EModelEndpoint.azureAssistants]: require('~/server/services/Endpoints/azureAssistants'), [EModelEndpoint.assistants]: require('~/server/services/Endpoints/assistants'), @@ -20,28 +21,30 @@ const router = express.Router(); router.use(requireJwtAuth); router.get('/', async (req, res) => { - let pageNumber = req.query.pageNumber || 1; - pageNumber = parseInt(pageNumber, 10); + const limit = parseInt(req.query.limit, 10) || 25; + const cursor = req.query.cursor; + const isArchived = isEnabled(req.query.isArchived); + const search = req.query.search ? decodeURIComponent(req.query.search) : undefined; + const order = req.query.order || 'desc'; - if (isNaN(pageNumber) || pageNumber < 1) { - return res.status(400).json({ error: 'Invalid page number' }); - } - - let pageSize = req.query.pageSize || 25; - pageSize = parseInt(pageSize, 10); - - if (isNaN(pageSize) || pageSize < 1) { - return res.status(400).json({ error: 'Invalid page size' }); - } - const isArchived = req.query.isArchived === 'true'; let tags; if (req.query.tags) { tags = Array.isArray(req.query.tags) ? req.query.tags : [req.query.tags]; - } else { - tags = undefined; } - res.status(200).send(await getConvosByPage(req.user.id, pageNumber, pageSize, isArchived, tags)); + try { + const result = await getConvosByCursor(req.user.id, { + cursor, + limit, + isArchived, + tags, + search, + order, + }); + res.status(200).json(result); + } catch (error) { + res.status(500).json({ error: 'Error fetching conversations' }); + } }); router.get('/:conversationId', async (req, res) => { @@ -76,22 +79,28 @@ router.post('/gen_title', async (req, res) => { } }); -router.post('/clear', async (req, res) => { +router.delete('/', async (req, res) => { let filter = {}; const { conversationId, source, thread_id, endpoint } = req.body.arg; - if (conversationId) { - filter = { conversationId }; + + // Prevent deletion of all conversations + if (!conversationId && !source && !thread_id && !endpoint) { + return res.status(400).json({ + error: 'no parameters provided', + }); } - if (source === 'button' && !conversationId) { + if (conversationId) { + filter = { conversationId }; + } else if (source === 'button') { return res.status(200).send('No conversationId provided'); } if ( - typeof endpoint != 'undefined' && + typeof endpoint !== 'undefined' && Object.prototype.propertyIsEnumerable.call(assistantClients, endpoint) ) { - /** @type {{ openai: OpenAI}} */ + /** @type {{ openai: OpenAI }} */ const { openai } = await assistantClients[endpoint].initializeClient({ req, res }); try { const response = await openai.beta.threads.del(thread_id); @@ -101,9 +110,6 @@ router.post('/clear', async (req, res) => { } } - // for debugging deletion source - // logger.debug('source:', source); - try { const dbResponse = await deleteConvos(req.user.id, filter); await deleteToolCalls(req.user.id, filter.conversationId); @@ -114,6 +120,17 @@ router.post('/clear', async (req, res) => { } }); +router.delete('/all', async (req, res) => { + try { + const dbResponse = await deleteConvos(req.user.id, {}); + await deleteToolCalls(req.user.id); + res.status(201).json(dbResponse); + } catch (error) { + logger.error('Error clearing conversations', error); + res.status(500).send('Error clearing conversations'); + } +}); + router.post('/update', async (req, res) => { const update = req.body.arg; diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js index 
c7bf128d7c..704a9f4ea4 100644 --- a/api/server/routes/edit/anthropic.js +++ b/api/server/routes/edit/anthropic.js @@ -3,7 +3,6 @@ const EditController = require('~/server/controllers/EditController'); const { initializeClient } = require('~/server/services/Endpoints/anthropic'); const { setHeaders, - handleAbort, validateModel, validateEndpoint, buildEndpointOption, @@ -11,8 +10,6 @@ const { const router = express.Router(); -router.post('/abort', handleAbort()); - router.post( '/', validateEndpoint, diff --git a/api/server/routes/edit/custom.js b/api/server/routes/edit/custom.js index 0bf97ba180..a6fd804763 100644 --- a/api/server/routes/edit/custom.js +++ b/api/server/routes/edit/custom.js @@ -12,8 +12,6 @@ const { const router = express.Router(); -router.post('/abort', handleAbort()); - router.post( '/', validateEndpoint, diff --git a/api/server/routes/edit/google.js b/api/server/routes/edit/google.js index 7482f11b4c..187f4f6158 100644 --- a/api/server/routes/edit/google.js +++ b/api/server/routes/edit/google.js @@ -3,7 +3,6 @@ const EditController = require('~/server/controllers/EditController'); const { initializeClient } = require('~/server/services/Endpoints/google'); const { setHeaders, - handleAbort, validateModel, validateEndpoint, buildEndpointOption, @@ -11,8 +10,6 @@ const { const router = express.Router(); -router.post('/abort', handleAbort()); - router.post( '/', validateEndpoint, diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index 5547a1fcdf..94d9b91d0b 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -2,7 +2,6 @@ const express = require('express'); const { getResponseSender } = require('librechat-data-provider'); const { setHeaders, - handleAbort, moderateText, validateModel, handleAbortError, @@ -19,7 +18,6 @@ const { logger } = require('~/config'); const router = express.Router(); router.use(moderateText); -router.post('/abort', handleAbort()); router.post( '/', @@ -173,7 +171,8 @@ router.post( logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); - const { conversation = {} } = await client.responsePromise; + const { conversation = {} } = await response.databasePromise; + delete response.databasePromise; conversation.title = conversation && !conversation.title ? 
null : conversation?.title || 'New Chat'; diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js index ae26b235c7..ee25a42ee3 100644 --- a/api/server/routes/edit/openAI.js +++ b/api/server/routes/edit/openAI.js @@ -2,7 +2,6 @@ const express = require('express'); const EditController = require('~/server/controllers/EditController'); const { initializeClient } = require('~/server/services/Endpoints/openAI'); const { - handleAbort, setHeaders, validateModel, validateEndpoint, @@ -12,7 +11,6 @@ const { const router = express.Router(); router.use(moderateText); -router.post('/abort', handleAbort()); router.post( '/', diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index c320f7705b..5a520bdb65 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -2,7 +2,9 @@ const fs = require('fs').promises; const express = require('express'); const { EnvVar } = require('@librechat/agents'); const { + Time, isUUID, + CacheKeys, FileSources, EModelEndpoint, isAgentsEndpoint, @@ -16,9 +18,12 @@ const { } = require('~/server/services/Files/process'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); -const { loadAuthValues } = require('~/app/clients/tools/util'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud'); +const { getFiles, batchUpdateFiles } = require('~/models/File'); +const { getAssistant } = require('~/models/Assistant'); const { getAgent } = require('~/models/Agent'); -const { getFiles } = require('~/models/File'); +const { getLogStores } = require('~/cache'); const { logger } = require('~/config'); const router = express.Router(); @@ -26,6 +31,18 @@ const router = express.Router(); router.get('/', async (req, res) => { try { const files = await getFiles({ user: req.user.id }); + if (req.app.locals.fileStrategy === FileSources.s3) { + try { + const cache = getLogStores(CacheKeys.S3_EXPIRY_INTERVAL); + const alreadyChecked = await cache.get(req.user.id); + if (!alreadyChecked) { + await refreshS3FileUrls(files, batchUpdateFiles); + await cache.set(req.user.id, true, Time.THIRTY_MINUTES); + } + } catch (error) { + logger.warn('[/files] Error refreshing S3 file URLs:', error); + } + } res.status(200).send(files); } catch (error) { logger.error('[/files] Error getting files:', error); @@ -78,7 +95,7 @@ router.delete('/', async (req, res) => { }); } - /* Handle entity unlinking even if no valid files to delete */ + /* Handle agent unlinking even if no valid files to delete */ if (req.body.agent_id && req.body.tool_resource && dbFiles.length === 0) { const agent = await getAgent({ id: req.body.agent_id, @@ -88,7 +105,21 @@ router.delete('/', async (req, res) => { const agentFiles = files.filter((f) => toolResourceFiles.includes(f.file_id)); await processDeleteRequest({ req, files: agentFiles }); - res.status(200).json({ message: 'File associations removed successfully' }); + res.status(200).json({ message: 'File associations removed successfully from agent' }); + return; + } + + /* Handle assistant unlinking even if no valid files to delete */ + if (req.body.assistant_id && req.body.tool_resource && dbFiles.length === 0) { + const assistant = await getAssistant({ + id: req.body.assistant_id, + }); + + const toolResourceFiles = assistant.tool_resources?.[req.body.tool_resource]?.file_ids ?? 
[]; + const assistantFiles = files.filter((f) => toolResourceFiles.includes(f.file_id)); + + await processDeleteRequest({ req, files: assistantFiles }); + res.status(200).json({ message: 'File associations removed successfully from assistant' }); return; } diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 4b34029c7b..449759383d 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -10,6 +10,7 @@ const balance = require('./balance'); const plugins = require('./plugins'); const bedrock = require('./bedrock'); const actions = require('./actions'); +const banner = require('./banner'); const search = require('./search'); const models = require('./models'); const convos = require('./convos'); @@ -25,7 +26,6 @@ const edit = require('./edit'); const keys = require('./keys'); const user = require('./user'); const ask = require('./ask'); -const banner = require('./banner'); module.exports = { ask, @@ -38,13 +38,14 @@ module.exports = { oauth, files, share, + banner, agents, - bedrock, convos, search, - prompts, config, models, + bedrock, + prompts, plugins, actions, presets, @@ -55,5 +56,4 @@ module.exports = { assistants, categories, staticRoute, - banner, }; diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 54c4aab1c2..d5980ae55b 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -10,12 +10,90 @@ const { } = require('~/models'); const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update'); const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); +const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); +const { getConvosQueried } = require('~/models/Conversation'); const { countTokens } = require('~/server/utils'); +const { Message } = require('~/models/Message'); const { logger } = require('~/config'); const router = express.Router(); router.use(requireJwtAuth); +router.get('/', async (req, res) => { + try { + const user = req.user.id ?? ''; + const { + cursor = null, + sortBy = 'createdAt', + sortDirection = 'desc', + pageSize: pageSizeRaw, + conversationId, + messageId, + search, + } = req.query; + const pageSize = parseInt(pageSizeRaw, 10) || 25; + + let response; + const sortField = ['endpoint', 'createdAt', 'updatedAt'].includes(sortBy) + ? sortBy + : 'createdAt'; + const sortOrder = sortDirection === 'asc' ? 1 : -1; + + if (conversationId && messageId) { + const message = await Message.findOne({ conversationId, messageId, user: user }).lean(); + response = { messages: message ? [message] : [], nextCursor: null }; + } else if (conversationId) { + const filter = { conversationId, user: user }; + if (cursor) { + filter[sortField] = sortOrder === 1 ? { $gt: cursor } : { $lt: cursor }; + } + const messages = await Message.find(filter) + .sort({ [sortField]: sortOrder }) + .limit(pageSize + 1) + .lean(); + const nextCursor = messages.length > pageSize ? 
messages.pop()[sortField] : null; + response = { messages, nextCursor }; + } else if (search) { + const searchResults = await Message.meiliSearch(search, undefined, true); + + const messages = searchResults.hits || []; + + const result = await getConvosQueried(req.user.id, messages, cursor); + + const activeMessages = []; + for (let i = 0; i < messages.length; i++) { + let message = messages[i]; + if (message.conversationId.includes('--')) { + message.conversationId = cleanUpPrimaryKeyValue(message.conversationId); + } + if (result.convoMap[message.conversationId]) { + const convo = result.convoMap[message.conversationId]; + + const dbMessage = await getMessage({ user, messageId: message.messageId }); + activeMessages.push({ + ...message, + title: convo.title, + conversationId: message.conversationId, + model: convo.model, + isCreatedByUser: dbMessage?.isCreatedByUser, + endpoint: dbMessage?.endpoint, + iconURL: dbMessage?.iconURL, + }); + } + } + + response = { messages: activeMessages, nextCursor: null }; + } else { + response = { messages: [], nextCursor: null }; + } + + res.status(200).json(response); + } catch (error) { + logger.error('Error fetching messages:', error); + res.status(500).json({ error: 'Internal server error' }); + } +}); + router.post('/artifact/:messageId', async (req, res) => { try { const { messageId } = req.params; diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 046370798b..b2037683d2 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -1,7 +1,13 @@ // file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware const express = require('express'); const passport = require('passport'); -const { loginLimiter, checkBan, checkDomainAllowed } = require('~/server/middleware'); +const { + checkBan, + logHeaders, + loginLimiter, + setBalanceConfig, + checkDomainAllowed, +} = require('~/server/middleware'); const { setAuthTokens } = require('~/server/services/AuthService'); const { logger } = require('~/config'); @@ -12,6 +18,7 @@ const domains = { server: process.env.DOMAIN_SERVER, }; +router.use(logHeaders); router.use(loginLimiter); const oauthHandler = async (req, res) => { @@ -31,7 +38,9 @@ const oauthHandler = async (req, res) => { router.get('/error', (req, res) => { // A single error message is pushed by passport when authentication fails. 
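  // Review note (assumption): on a direct visit to /error, req.session.messages may be empty or undefined, so pop() can yield undefined or throw; the log below assumes a failed passport flow preceded it.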
logger.error('Error in OAuth authentication:', { message: req.session.messages.pop() }); - res.redirect(`${domains.client}/login`); + + // Redirect to login page with auth_failed parameter to prevent infinite redirect loops + res.redirect(`${domains.client}/login?redirect=false`); }); /** @@ -53,6 +62,7 @@ router.get( session: false, scope: ['openid', 'profile', 'email'], }), + setBalanceConfig, oauthHandler, ); @@ -77,6 +87,7 @@ router.get( scope: ['public_profile'], profileFields: ['id', 'email', 'name'], }), + setBalanceConfig, oauthHandler, ); @@ -97,6 +108,7 @@ router.get( failureMessage: true, session: false, }), + setBalanceConfig, oauthHandler, ); @@ -119,6 +131,7 @@ router.get( session: false, scope: ['user:email', 'read:user'], }), + setBalanceConfig, oauthHandler, ); @@ -141,6 +154,7 @@ router.get( session: false, scope: ['identify', 'email'], }), + setBalanceConfig, oauthHandler, ); @@ -161,6 +175,7 @@ router.post( failureMessage: true, session: false, }), + setBalanceConfig, oauthHandler, ); diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js index e58ebb6fe7..17768c7de6 100644 --- a/api/server/routes/roles.js +++ b/api/server/routes/roles.js @@ -48,7 +48,7 @@ router.put('/:roleName/prompts', checkAdmin, async (req, res) => { const { roleName: _r } = req.params; // TODO: TEMP, use a better parsing for roleName const roleName = _r.toUpperCase(); - /** @type {TRole['PROMPTS']} */ + /** @type {TRole['permissions']['PROMPTS']} */ const updates = req.body; try { @@ -59,10 +59,16 @@ router.put('/:roleName/prompts', checkAdmin, async (req, res) => { return res.status(404).send({ message: 'Role not found' }); } + const currentPermissions = + role.permissions?.[PermissionTypes.PROMPTS] || role[PermissionTypes.PROMPTS] || {}; + const mergedUpdates = { - [PermissionTypes.PROMPTS]: { - ...role[PermissionTypes.PROMPTS], - ...parsedUpdates, + permissions: { + ...role.permissions, + [PermissionTypes.PROMPTS]: { + ...currentPermissions, + ...parsedUpdates, + }, }, }; @@ -81,7 +87,7 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => { const { roleName: _r } = req.params; // TODO: TEMP, use a better parsing for roleName const roleName = _r.toUpperCase(); - /** @type {TRole['AGENTS']} */ + /** @type {TRole['permissions']['AGENTS']} */ const updates = req.body; try { @@ -92,17 +98,23 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => { return res.status(404).send({ message: 'Role not found' }); } + const currentPermissions = + role.permissions?.[PermissionTypes.AGENTS] || role[PermissionTypes.AGENTS] || {}; + const mergedUpdates = { - [PermissionTypes.AGENTS]: { - ...role[PermissionTypes.AGENTS], - ...parsedUpdates, + permissions: { + ...role.permissions, + [PermissionTypes.AGENTS]: { + ...currentPermissions, + ...parsedUpdates, + }, }, }; const updatedRole = await updateRoleByName(roleName, mergedUpdates); res.status(200).send(updatedRole); } catch (error) { - return res.status(400).send({ message: 'Invalid prompt permissions.', error: error.errors }); + return res.status(400).send({ message: 'Invalid agent permissions.', error: error.errors }); } }); diff --git a/api/server/routes/search.js b/api/server/routes/search.js index 68cff7532b..5c7846aee1 100644 --- a/api/server/routes/search.js +++ b/api/server/routes/search.js @@ -1,93 +1,17 @@ -const Keyv = require('keyv'); const express = require('express'); const { MeiliSearch } = require('meilisearch'); -const { Conversation, getConvosQueried } = require('~/models/Conversation'); const 
requireJwtAuth = require('~/server/middleware/requireJwtAuth'); -const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); -const { reduceHits } = require('~/lib/utils/reduceHits'); const { isEnabled } = require('~/server/utils'); -const { Message } = require('~/models/Message'); -const keyvRedis = require('~/cache/keyvRedis'); -const { logger } = require('~/config'); const router = express.Router(); -const expiration = 60 * 1000; -const cache = isEnabled(process.env.USE_REDIS) - ? new Keyv({ store: keyvRedis }) - : new Keyv({ namespace: 'search', ttl: expiration }); - router.use(requireJwtAuth); -router.get('/sync', async function (req, res) { - await Message.syncWithMeili(); - await Conversation.syncWithMeili(); - res.send('synced'); -}); - -router.get('/', async function (req, res) { - try { - let user = req.user.id ?? ''; - const { q } = req.query; - const pageNumber = req.query.pageNumber || 1; - const key = `${user}:search:${q}`; - const cached = await cache.get(key); - if (cached) { - logger.debug('[/search] cache hit: ' + key); - const { pages, pageSize, messages } = cached; - res - .status(200) - .send({ conversations: cached[pageNumber], pages, pageNumber, pageSize, messages }); - return; - } - - const messages = (await Message.meiliSearch(q, undefined, true)).hits; - const titles = (await Conversation.meiliSearch(q)).hits; - - const sortedHits = reduceHits(messages, titles); - const result = await getConvosQueried(user, sortedHits, pageNumber); - - const activeMessages = []; - for (let i = 0; i < messages.length; i++) { - let message = messages[i]; - if (message.conversationId.includes('--')) { - message.conversationId = cleanUpPrimaryKeyValue(message.conversationId); - } - if (result.convoMap[message.conversationId]) { - const convo = result.convoMap[message.conversationId]; - const { title, chatGptLabel, model } = convo; - message = { ...message, ...{ title, chatGptLabel, model } }; - activeMessages.push(message); - } - } - result.messages = activeMessages; - if (result.cache) { - result.cache.messages = activeMessages; - cache.set(key, result.cache, expiration); - delete result.cache; - } - delete result.convoMap; - - res.status(200).send(result); - } catch (error) { - logger.error('[/search] Error while searching messages & conversations', error); - res.status(500).send({ message: 'Error searching' }); - } -}); - -router.get('/test', async function (req, res) { - const { q } = req.query; - const messages = ( - await Message.meiliSearch(q, { attributesToHighlight: ['text'] }, true) - ).hits.map((message) => { - const { _formatted, ...rest } = message; - return { ...rest, searchResult: true, text: _formatted.text }; - }); - res.send(messages); -}); - router.get('/enable', async function (req, res) { - let result = false; + if (!isEnabled(process.env.SEARCH)) { + return res.send(false); + } + try { const client = new MeiliSearch({ host: process.env.MEILI_HOST, @@ -95,8 +19,7 @@ router.get('/enable', async function (req, res) { }); const { status } = await client.health(); - result = status === 'available' && !!process.env.SEARCH; - return res.send(result); + return res.send(status === 'available'); } catch (error) { return res.send(false); } diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 660e7aeb0d..c8a7955427 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -13,7 +13,6 @@ const { actionDomainSeparator, } = require('librechat-data-provider'); const { refreshAccessToken } = 
require('~/server/services/TokenService'); -const { isActionDomainAllowed } = require('~/server/services/domains'); const { logger, getFlowStateManager, sendEvent } = require('~/config'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); const { getActions, deleteActions } = require('~/models/Action'); @@ -51,7 +50,7 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { return null; } - const parsedDomain = await domainParser(req, domain, true); + const parsedDomain = await domainParser(domain, true); if (!parsedDomain) { return null; @@ -67,16 +66,14 @@ const validateAndUpdateTool = async ({ req, tool, assistant_id }) => { * * Necessary due to `[a-zA-Z0-9_-]*` Regex Validation, limited to a 64-character maximum. * - * @param {Express.Request} req - The Express Request object. * @param {string} domain - The domain name to encode/decode. * @param {boolean} inverse - False to decode from base64, true to encode to base64. * @returns {Promise} Encoded or decoded domain string. */ -async function domainParser(req, domain, inverse = false) { +async function domainParser(domain, inverse = false) { if (!domain) { return; } - const domainsCache = getLogStores(CacheKeys.ENCODED_DOMAINS); const cachedDomain = await domainsCache.get(domain); if (inverse && cachedDomain) { @@ -123,47 +120,39 @@ async function loadActionSets(searchParams) { * Creates a general tool for an entire action set. * * @param {Object} params - The parameters for loading action sets. - * @param {ServerRequest} params.req + * @param {string} params.userId * @param {ServerResponse} params.res * @param {Action} params.action - The action set. Necessary for decrypting authentication values. * @param {ActionRequest} params.requestBuilder - The ActionRequest builder class to execute the API call. * @param {string | undefined} [params.name] - The name of the tool. * @param {string | undefined} [params.description] - The description for the tool. * @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition + * @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action. * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ async function createActionTool({ - req, + userId, res, action, requestBuilder, zodSchema, name, description, + encrypted, }) { - const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain); - if (!isDomainAllowed) { - return null; - } - const encrypted = { - oauth_client_id: action.metadata.oauth_client_id, - oauth_client_secret: action.metadata.oauth_client_secret, - }; - action.metadata = await decryptMetadata(action.metadata); - /** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise} */ const _call = async (toolInput, config) => { try { /** @type {import('librechat-data-provider').ActionMetadataRuntime} */ const metadata = action.metadata; const executor = requestBuilder.createExecutor(); - const preparedExecutor = executor.setParams(toolInput); + const preparedExecutor = executor.setParams(toolInput ?? 
{}); if (metadata.auth && metadata.auth.type !== AuthTypeEnum.None) { try { - const action_id = action.action_id; - const identifier = `${req.user.id}:${action.action_id}`; if (metadata.auth.type === AuthTypeEnum.OAuth && metadata.auth.authorization_url) { + const action_id = action.action_id; + const identifier = `${userId}:${action.action_id}`; const requestLogin = async () => { const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; if (!stepId) { @@ -171,7 +160,7 @@ async function createActionTool({ } const statePayload = { nonce: nanoid(), - user: req.user.id, + user: userId, action_id, }; @@ -198,26 +187,33 @@ async function createActionTool({ expires_at: Date.now() + Time.TWO_MINUTES, }, }; - const flowManager = await getFlowStateManager(getLogStores); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); await flowManager.createFlowWithHandler( - `${identifier}:login`, + `${identifier}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`, 'oauth_login', async () => { sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); logger.debug('Sent OAuth login request to client', { action_id, identifier }); return true; }, + config?.signal, ); logger.debug('Waiting for OAuth Authorization response', { action_id, identifier }); - const result = await flowManager.createFlow(identifier, 'oauth', { - state: stateToken, - userId: req.user.id, - client_url: metadata.auth.client_url, - redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`, - /** Encrypted values */ - encrypted_oauth_client_id: encrypted.oauth_client_id, - encrypted_oauth_client_secret: encrypted.oauth_client_secret, - }); + const result = await flowManager.createFlow( + identifier, + 'oauth', + { + state: stateToken, + userId: userId, + client_url: metadata.auth.client_url, + redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`, + /** Encrypted values */ + encrypted_oauth_client_id: encrypted.oauth_client_id, + encrypted_oauth_client_secret: encrypted.oauth_client_secret, + }, + config?.signal, + ); logger.debug('Received OAuth Authorization response', { action_id, identifier }); data.delta.auth = undefined; data.delta.expires_at = undefined; @@ -235,10 +231,10 @@ async function createActionTool({ }; const tokenPromises = []; - tokenPromises.push(findToken({ userId: req.user.id, type: 'oauth', identifier })); + tokenPromises.push(findToken({ userId, type: 'oauth', identifier })); tokenPromises.push( findToken({ - userId: req.user.id, + userId, type: 'oauth_refresh', identifier: `${identifier}:refresh`, }), @@ -261,18 +257,20 @@ async function createActionTool({ const refresh_token = await decryptV2(refreshTokenData.token); const refreshTokens = async () => await refreshAccessToken({ + userId, identifier, refresh_token, - userId: req.user.id, client_url: metadata.auth.client_url, encrypted_oauth_client_id: encrypted.oauth_client_id, encrypted_oauth_client_secret: encrypted.oauth_client_secret, }); - const flowManager = await getFlowStateManager(getLogStores); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); const refreshData = await flowManager.createFlowWithHandler( `${identifier}:refresh`, 'oauth_refresh', refreshTokens, + config?.signal, ); metadata.oauth_access_token = refreshData.access_token; if (refreshData.refresh_token) { @@ -308,9 +306,8 @@ async function createActionTool({ } return response.data; } catch (error) { 
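      // Review note: the change below stops rethrowing and instead returns the logged error to the
      // caller, so failed API calls surface as tool output; that logAxiosError returns the formatted
      // message is an assumption read from this diff, not verified against the utility itself.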
- const logMessage = `API call to ${action.metadata.domain} failed`; - logAxiosError({ message: logMessage, error }); - throw error; + const message = `API call to ${action.metadata.domain} failed:`; + return logAxiosError({ message, error }); } }; @@ -327,6 +324,27 @@ async function createActionTool({ }; } +/** + * Encrypts a sensitive value. + * @param {string} value + * @returns {Promise} + */ +async function encryptSensitiveValue(value) { + // Encode API key to handle special characters like ":" + const encodedValue = encodeURIComponent(value); + return await encryptV2(encodedValue); +} + +/** + * Decrypts a sensitive value. + * @param {string} value + * @returns {Promise} + */ +async function decryptSensitiveValue(value) { + const decryptedValue = await decryptV2(value); + return decodeURIComponent(decryptedValue); +} + /** * Encrypts sensitive metadata values for an action. * @@ -339,17 +357,19 @@ async function encryptMetadata(metadata) { // ServiceHttp if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) { if (metadata.api_key) { - encryptedMetadata.api_key = await encryptV2(metadata.api_key); + encryptedMetadata.api_key = await encryptSensitiveValue(metadata.api_key); } } // OAuth else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) { if (metadata.oauth_client_id) { - encryptedMetadata.oauth_client_id = await encryptV2(metadata.oauth_client_id); + encryptedMetadata.oauth_client_id = await encryptSensitiveValue(metadata.oauth_client_id); } if (metadata.oauth_client_secret) { - encryptedMetadata.oauth_client_secret = await encryptV2(metadata.oauth_client_secret); + encryptedMetadata.oauth_client_secret = await encryptSensitiveValue( + metadata.oauth_client_secret, + ); } } @@ -368,17 +388,19 @@ async function decryptMetadata(metadata) { // ServiceHttp if (metadata.auth && metadata.auth.type === AuthTypeEnum.ServiceHttp) { if (metadata.api_key) { - decryptedMetadata.api_key = await decryptV2(metadata.api_key); + decryptedMetadata.api_key = await decryptSensitiveValue(metadata.api_key); } } // OAuth else if (metadata.auth && metadata.auth.type === AuthTypeEnum.OAuth) { if (metadata.oauth_client_id) { - decryptedMetadata.oauth_client_id = await decryptV2(metadata.oauth_client_id); + decryptedMetadata.oauth_client_id = await decryptSensitiveValue(metadata.oauth_client_id); } if (metadata.oauth_client_secret) { - decryptedMetadata.oauth_client_secret = await decryptV2(metadata.oauth_client_secret); + decryptedMetadata.oauth_client_secret = await decryptSensitiveValue( + metadata.oauth_client_secret, + ); } } diff --git a/api/server/services/ActionService.spec.js b/api/server/services/ActionService.spec.js index 8f9d67a9d1..f3b4423197 100644 --- a/api/server/services/ActionService.spec.js +++ b/api/server/services/ActionService.spec.js @@ -78,20 +78,20 @@ describe('domainParser', () => { // Non-azure request it('does not return domain as is if not azure', async () => { const domain = `example.com${actionDomainSeparator}test${actionDomainSeparator}`; - const result1 = await domainParser(reqNoAzure, domain, false); - const result2 = await domainParser(reqNoAzure, domain, true); + const result1 = await domainParser(domain, false); + const result2 = await domainParser(domain, true); expect(result1).not.toEqual(domain); expect(result2).not.toEqual(domain); }); // Test for Empty or Null Inputs it('returns undefined for null domain input', async () => { - const result = await domainParser(req, null, true); + const result = await domainParser(null, true); 
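    // (review note) domainParser's `if (!domain) { return; }` guard returns before any cache lookup, so null input resolves to undefined.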
expect(result).toBeUndefined(); }); it('returns undefined for empty domain input', async () => { - const result = await domainParser(req, '', true); + const result = await domainParser('', true); expect(result).toBeUndefined(); }); @@ -102,7 +102,7 @@ describe('domainParser', () => { .toString('base64') .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - await domainParser(req, domain, true); + await domainParser(domain, true); const cachedValue = await globalCache[encodedDomain]; expect(cachedValue).toEqual(Buffer.from(domain).toString('base64')); @@ -112,14 +112,14 @@ describe('domainParser', () => { it('encodes domain exactly at threshold without modification', async () => { const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - TLD.length) + TLD; const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(req, domain, true); + const result = await domainParser(domain, true); expect(result).toEqual(expected); }); it('encodes domain just below threshold without modification', async () => { const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH - 1 - TLD.length) + TLD; const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(req, domain, true); + const result = await domainParser(domain, true); expect(result).toEqual(expected); }); @@ -129,7 +129,7 @@ describe('domainParser', () => { const encodedDomain = Buffer.from(unicodeDomain) .toString('base64') .substring(0, Constants.ENCODED_DOMAIN_LENGTH); - const result = await domainParser(req, unicodeDomain, true); + const result = await domainParser(unicodeDomain, true); expect(result).toEqual(encodedDomain); }); @@ -139,7 +139,6 @@ describe('domainParser', () => { globalCache[encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH)] = encodedDomain; // Simulate caching const result = await domainParser( - req, encodedDomain.substring(0, Constants.ENCODED_DOMAIN_LENGTH), false, ); @@ -150,27 +149,27 @@ describe('domainParser', () => { it('returns domain with replaced separators if no cached domain exists', async () => { const domain = 'example.com'; const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(req, withSeparator, false); + const result = await domainParser(withSeparator, false); expect(result).toEqual(domain); }); it('returns domain with replaced separators when inverse is false and under encoding length', async () => { const domain = 'examp.com'; const withSeparator = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(req, withSeparator, false); + const result = await domainParser(withSeparator, false); expect(result).toEqual(domain); }); it('replaces periods with actionDomainSeparator when inverse is true and under encoding length', async () => { const domain = 'examp.com'; const expected = domain.replace(/\./g, actionDomainSeparator); - const result = await domainParser(req, domain, true); + const result = await domainParser(domain, true); expect(result).toEqual(expected); }); it('encodes domain when length is above threshold and inverse is true', async () => { const domain = 'a'.repeat(Constants.ENCODED_DOMAIN_LENGTH + 1).concat('.com'); - const result = await domainParser(req, domain, true); + const result = await domainParser(domain, true); expect(result).not.toEqual(domain); expect(result.length).toBeLessThanOrEqual(Constants.ENCODED_DOMAIN_LENGTH); }); @@ -180,20 +179,20 @@ describe('domainParser', () => { const encodedDomain = Buffer.from( 
originalDomain.replace(/\./g, actionDomainSeparator), ).toString('base64'); - const result = await domainParser(req, encodedDomain, false); + const result = await domainParser(encodedDomain, false); expect(result).toEqual(encodedDomain); }); it('decodes encoded value if cached and encoded value is provided, and inverse is false', async () => { const originalDomain = 'example.com'; - const encodedDomain = await domainParser(req, originalDomain, true); - const result = await domainParser(req, encodedDomain, false); + const encodedDomain = await domainParser(originalDomain, true); + const result = await domainParser(encodedDomain, false); expect(result).toEqual(originalDomain); }); it('handles invalid base64 encoded values gracefully', async () => { const invalidBase64Domain = 'not_base64_encoded'; - const result = await domainParser(req, invalidBase64Domain, false); + const result = await domainParser(invalidBase64Domain, false); expect(result).toEqual(invalidBase64Domain); }); }); diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index d194d31a6b..1ad3aaace6 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -1,15 +1,24 @@ -const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider'); +const { + FileSources, + EModelEndpoint, + loadOCRConfig, + processMCPEnv, + getConfigDefaults, +} = require('librechat-data-provider'); const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks'); const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants'); +const { initializeAzureBlobService } = require('./Files/Azure/initialize'); const { initializeFirebase } = require('./Files/Firebase/initialize'); const loadCustomConfig = require('./Config/loadCustomConfig'); const handleRateLimits = require('./Config/handleRateLimits'); const { loadDefaultInterface } = require('./start/interface'); const { azureConfigSetup } = require('./start/azureOpenAI'); const { processModelSpecs } = require('./start/modelSpecs'); +const { initializeS3 } = require('./Files/S3/initialize'); const { loadAndFormatTools } = require('./ToolService'); const { agentsConfigSetup } = require('./start/agents'); const { initializeRoles } = require('~/models/Role'); +const { isEnabled } = require('~/server/utils'); const { getMCPManager } = require('~/config'); const paths = require('~/config/paths'); @@ -21,13 +30,19 @@ const paths = require('~/config/paths'); */ const AppService = async (app) => { await initializeRoles(); - /** @type {TCustomConfig}*/ + /** @type {TCustomConfig} */ const config = (await loadCustomConfig()) ?? {}; const configDefaults = getConfigDefaults(); + const ocr = loadOCRConfig(config.ocr); const filteredTools = config.filteredTools; const includedTools = config.includedTools; const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy; + const startBalance = process.env.START_BALANCE; + const balance = config.balance ?? { + enabled: isEnabled(process.env.CHECK_BALANCE), + startBalance: startBalance ? parseInt(startBalance, 10) : undefined, + }; const imageOutputType = config?.imageOutputType ?? 
diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js
index d194d31a6b..1ad3aaace6 100644
--- a/api/server/services/AppService.js
+++ b/api/server/services/AppService.js
@@ -1,15 +1,24 @@
-const { FileSources, EModelEndpoint, getConfigDefaults } = require('librechat-data-provider');
+const {
+  FileSources,
+  EModelEndpoint,
+  loadOCRConfig,
+  processMCPEnv,
+  getConfigDefaults,
+} = require('librechat-data-provider');
 const { checkVariables, checkHealth, checkConfig, checkAzureVariables } = require('./start/checks');
 const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
+const { initializeAzureBlobService } = require('./Files/Azure/initialize');
 const { initializeFirebase } = require('./Files/Firebase/initialize');
 const loadCustomConfig = require('./Config/loadCustomConfig');
 const handleRateLimits = require('./Config/handleRateLimits');
 const { loadDefaultInterface } = require('./start/interface');
 const { azureConfigSetup } = require('./start/azureOpenAI');
 const { processModelSpecs } = require('./start/modelSpecs');
+const { initializeS3 } = require('./Files/S3/initialize');
 const { loadAndFormatTools } = require('./ToolService');
 const { agentsConfigSetup } = require('./start/agents');
 const { initializeRoles } = require('~/models/Role');
+const { isEnabled } = require('~/server/utils');
 const { getMCPManager } = require('~/config');
 const paths = require('~/config/paths');

@@ -21,13 +30,19 @@ const paths = require('~/config/paths');
  */
 const AppService = async (app) => {
   await initializeRoles();
-  /** @type {TCustomConfig}*/
+  /** @type {TCustomConfig} */
   const config = (await loadCustomConfig()) ?? {};
   const configDefaults = getConfigDefaults();
+  const ocr = loadOCRConfig(config.ocr);

   const filteredTools = config.filteredTools;
   const includedTools = config.includedTools;
   const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy;
+  const startBalance = process.env.START_BALANCE;
+  const balance = config.balance ?? {
+    enabled: isEnabled(process.env.CHECK_BALANCE),
+    startBalance: startBalance ? parseInt(startBalance, 10) : undefined,
+  };
   const imageOutputType = config?.imageOutputType ?? configDefaults.imageOutputType;

   process.env.CDN_PROVIDER = fileStrategy;

@@ -37,9 +52,13 @@ const AppService = async (app) => {

   if (fileStrategy === FileSources.firebase) {
     initializeFirebase();
+  } else if (fileStrategy === FileSources.azure_blob) {
+    initializeAzureBlobService();
+  } else if (fileStrategy === FileSources.s3) {
+    initializeS3();
   }

-  /** @type {Record} */
   const availableTools = loadAndFormatTools({
     adminFilter: filteredTools,
     adminIncluded: includedTools,
@@ -47,8 +66,8 @@ const AppService = async (app) => {
   });

   if (config.mcpServers != null) {
-    const mcpManager = await getMCPManager();
-    await mcpManager.initializeMCP(config.mcpServers);
+    const mcpManager = getMCPManager();
+    await mcpManager.initializeMCP(config.mcpServers, processMCPEnv);
     await mcpManager.mapAvailableTools(availableTools);
   }

@@ -57,6 +76,7 @@ const AppService = async (app) => {
   const interfaceConfig = await loadDefaultInterface(config, configDefaults);

   const defaultLocals = {
+    ocr,
     paths,
     fileStrategy,
     socialLogins,
@@ -65,6 +85,7 @@ const AppService = async (app) => {
     availableTools,
     imageOutputType,
     interfaceConfig,
+    balance,
   };

   if (!Object.keys(config).length) {
@@ -125,7 +146,7 @@ const AppService = async (app) => {
     ...defaultLocals,
     fileConfig: config?.fileConfig,
     secureImageLinks: config?.secureImageLinks,
-    modelSpecs: processModelSpecs(endpoints, config.modelSpecs),
+    modelSpecs: processModelSpecs(endpoints, config.modelSpecs, interfaceConfig),
     ...endpointLocals,
   };
 };
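Reviewer note: the `config.balance ?? { ... }` fallback above is all-or-nothing, so a partial `balance` block in librechat.yaml silently discards the legacy env values rather than merging with them. A sketch with hypothetical values:

```js
// librechat.yaml:  balance: { enabled: true }
// .env:            CHECK_BALANCE=true, START_BALANCE=5000
const fromYaml = { enabled: true };
const fromEnv = { enabled: true, startBalance: 5000 };
const balance = fromYaml ?? fromEnv;
// -> { enabled: true }; startBalance is NOT carried over from the env fallback
```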
diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js
index 61ac80fc6c..465ec9fdd6 100644
--- a/api/server/services/AppService.spec.js
+++ b/api/server/services/AppService.spec.js
@@ -15,6 +15,9 @@ jest.mock('./Config/loadCustomConfig', () => {
     Promise.resolve({
       registration: { socialLogins: ['testLogin'] },
       fileStrategy: 'testStrategy',
+      balance: {
+        enabled: true,
+      },
     }),
   );
 });
@@ -120,9 +123,13 @@ describe('AppService', () => {
         },
       },
       paths: expect.anything(),
+      ocr: expect.anything(),
       imageOutputType: expect.any(String),
       fileConfig: undefined,
       secureImageLinks: undefined,
+      balance: { enabled: true },
+      filteredTools: undefined,
+      includedTools: undefined,
     });
   });

@@ -340,9 +347,6 @@ describe('AppService', () => {
     process.env.FILE_UPLOAD_USER_MAX = 'initialUserMax';
     process.env.FILE_UPLOAD_USER_WINDOW = 'initialUserWindow';

-    // Mock a custom configuration without specific rate limits
-    require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
-
     await AppService(app);

     // Verify that process.env falls back to the initial values
@@ -403,9 +407,6 @@ describe('AppService', () => {
     process.env.IMPORT_USER_MAX = 'initialUserMax';
     process.env.IMPORT_USER_WINDOW = 'initialUserWindow';

-    // Mock a custom configuration without specific rate limits
-    require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({}));
-
     await AppService(app);

     // Verify that process.env falls back to the initial values
@@ -444,13 +445,27 @@ describe('AppService updating app.locals and issuing warnings', () => {
     expect(app.locals.availableTools).toBeDefined();
     expect(app.locals.fileStrategy).toEqual(FileSources.local);
     expect(app.locals.socialLogins).toEqual(defaultSocialLogins);
+    expect(app.locals.balance).toEqual(
+      expect.objectContaining({
+        enabled: false,
+        startBalance: undefined,
+      }),
+    );
   });

   it('should update app.locals with values from loadCustomConfig', async () => {
-    // Mock loadCustomConfig to return a specific config object
+    // Mock loadCustomConfig to return a specific config object with a complete balance config
     const customConfig = {
       fileStrategy: 'firebase',
       registration: { socialLogins: ['testLogin'] },
+      balance: {
+        enabled: false,
+        startBalance: 5000,
+        autoRefillEnabled: true,
+        refillIntervalValue: 15,
+        refillIntervalUnit: 'hours',
+        refillAmount: 5000,
+      },
     };
     require('./Config/loadCustomConfig').mockImplementationOnce(() =>
       Promise.resolve(customConfig),
@@ -463,6 +478,7 @@ describe('AppService updating app.locals and issuing warnings', () => {
     expect(app.locals.availableTools).toBeDefined();
     expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy);
     expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins);
+    expect(app.locals.balance).toEqual(customConfig.balance);
   });

   it('should apply the assistants endpoint configuration correctly to app.locals', async () => {
@@ -588,4 +604,33 @@ describe('AppService updating app.locals and issuing warnings', () => {
     );
   });
 });
+
+  it('should not parse environment variable references in OCR config', async () => {
+    // Mock custom configuration with env variable references in OCR config
+    const mockConfig = {
+      ocr: {
+        apiKey: '${OCR_API_KEY_CUSTOM_VAR_NAME}',
+        baseURL: '${OCR_BASEURL_CUSTOM_VAR_NAME}',
+        strategy: 'mistral_ocr',
+        mistralModel: 'mistral-medium',
+      },
+    };
+
+    require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));
+
+    // Set actual environment variables with different values
+    process.env.OCR_API_KEY_CUSTOM_VAR_NAME = 'actual-api-key';
+    process.env.OCR_BASEURL_CUSTOM_VAR_NAME = 'https://actual-ocr-url.com';
+
+    // Initialize app
+    const app = { locals: {} };
+    await AppService(app);
+
+    // Verify that the raw string references were preserved and not interpolated
+    expect(app.locals.ocr).toBeDefined();
+    expect(app.locals.ocr.apiKey).toEqual('${OCR_API_KEY_CUSTOM_VAR_NAME}');
+    expect(app.locals.ocr.baseURL).toEqual('${OCR_BASEURL_CUSTOM_VAR_NAME}');
+    expect(app.locals.ocr.strategy).toEqual('mistral_ocr');
+    expect(app.locals.ocr.mistralModel).toEqual('mistral-medium');
+  });
 });
diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js
index 3c02b7eea0..0bb1e22cf8 100644
--- a/api/server/services/AuthService.js
+++ b/api/server/services/AuthService.js
@@ -56,7 +56,7 @@ const logoutUser = async (req, refreshToken) => {
     try {
       req.session.destroy();
     } catch (destroyErr) {
-      logger.error('[logoutUser] Failed to destroy session.', destroyErr);
+      logger.debug('[logoutUser] Failed to destroy session.', destroyErr);
     }

   return { status: 200, message: 'Logout successful' };
@@ -91,7 +91,7 @@ const sendVerificationEmail = async (user) => {
     subject: 'Verify your email',
     payload: {
       appName: process.env.APP_TITLE || 'LibreChat',
-      name: user.name,
+      name: user.name || user.username || user.email,
       verificationLink: verificationLink,
       year: new Date().getFullYear(),
     },
@@ -278,7 +278,7 @@ const requestPasswordReset = async (req) => {
       subject: 'Password Reset Request',
       payload: {
         appName: process.env.APP_TITLE || 'LibreChat',
-        name: user.name,
+        name: user.name || user.username || user.email,
         link: link,
         year: new Date().getFullYear(),
       },
@@ -331,7 +331,7 @@ const resetPassword = async (userId, token, password) => {
       subject: 'Password Reset Successfully',
       payload: {
         appName: process.env.APP_TITLE || 'LibreChat',
-        name: user.name,
+        name: user.name || user.username || user.email,
         year: new Date().getFullYear(),
       },
       template: 'passwordReset.handlebars',
@@ -414,7 +414,7 @@ const resendVerificationEmail = async (req) => {
       subject: 'Verify your email',
       payload: {
         appName: process.env.APP_TITLE || 'LibreChat',
-        name: user.name,
+        name: user.name || user.username || user.email,
         verificationLink: verificationLink,
         year: new Date().getFullYear(),
       },
diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js
index 5b9b2dd186..fdd84878eb 100644
--- a/api/server/services/Config/getCustomConfig.js
+++ b/api/server/services/Config/getCustomConfig.js
@@ -1,5 +1,5 @@
 const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
-const { normalizeEndpointName } = require('~/server/utils');
+const { normalizeEndpointName, isEnabled } = require('~/server/utils');
 const loadCustomConfig = require('./loadCustomConfig');
 const getLogStores = require('~/cache/getLogStores');

@@ -23,6 +23,26 @@ async function getCustomConfig() {
   return customConfig;
 }

+/**
+ * Retrieves the configuration object
+ * @function getBalanceConfig
+ * @returns {Promise}
+ * */
+async function getBalanceConfig() {
+  const isLegacyEnabled = isEnabled(process.env.CHECK_BALANCE);
+  const startBalance = process.env.START_BALANCE;
+  /** @type {TCustomConfig['balance']} */
+  const config = {
+    enabled: isLegacyEnabled,
+    startBalance: startBalance != null && startBalance ? parseInt(startBalance, 10) : undefined,
+  };
+  const customConfig = await getCustomConfig();
+  if (!customConfig) {
+    return config;
+  }
+  return { ...config, ...(customConfig?.['balance'] ?? {}) };
+}
+
 /**
  *
  * @param {string | EModelEndpoint} endpoint
@@ -40,4 +60,4 @@ const getCustomEndpointConfig = async (endpoint) => {
   );
 };

-module.exports = { getCustomConfig, getCustomEndpointConfig };
+module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig };
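By contrast, `getBalanceConfig` above shallow-merges the YAML `balance` block over the env-derived defaults, so keys the YAML leaves unset keep their legacy values:

```js
const fromEnv = { enabled: true, startBalance: 5000 }; // CHECK_BALANCE / START_BALANCE
const fromYaml = { enabled: false };                   // partial balance block in librechat.yaml
const merged = { ...fromEnv, ...(fromYaml ?? {}) };
// -> { enabled: false, startBalance: 5000 }; startBalance survives the override
```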
diff --git a/api/server/services/Config/getEndpointsConfig.js b/api/server/services/Config/getEndpointsConfig.js
index 4f8bde68ad..8ae022e4b3 100644
--- a/api/server/services/Config/getEndpointsConfig.js
+++ b/api/server/services/Config/getEndpointsConfig.js
@@ -33,10 +33,12 @@ async function getEndpointsConfig(req) {
     };
   }
   if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
-    const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];
+    const { disableBuilder, capabilities, allowedProviders, ..._rest } =
+      req.app.locals[EModelEndpoint.agents];

     mergedConfig[EModelEndpoint.agents] = {
       ...mergedConfig[EModelEndpoint.agents],
+      allowedProviders,
       disableBuilder,
       capabilities,
     };
@@ -72,4 +74,15 @@ async function getEndpointsConfig(req) {
   return endpointsConfig;
 }

-module.exports = { getEndpointsConfig };
+/**
+ * @param {ServerRequest} req
+ * @param {import('librechat-data-provider').AgentCapabilities} capability
+ * @returns {Promise}
+ */
+const checkCapability = async (req, capability) => {
+  const endpointsConfig = await getEndpointsConfig(req);
+  const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
+  return capabilities.includes(capability);
+};
+
+module.exports = { getEndpointsConfig, checkCapability };
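A sketch of how a caller might use the new `checkCapability` helper inside a route handler; the gate shown is illustrative, though `AgentCapabilities.ocr` is checked the same way in agents/initialize.js further down:

```js
// e.g. gate an OCR-dependent code path on the agents endpoint config
const ocrEnabled = await checkCapability(req, AgentCapabilities.ocr);
if (!ocrEnabled) {
  return res.status(403).json({ message: 'OCR capability is not enabled' });
}
```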
diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js
index 0df811468b..fc255b8c47 100644
--- a/api/server/services/Config/loadConfigModels.js
+++ b/api/server/services/Config/loadConfigModels.js
@@ -47,7 +47,7 @@ async function loadConfigModels(req) {
   );

   /**
-   * @type {Record}
+   * @type {Record>}
    * Map for promises keyed by unique combination of baseURL and apiKey
    */
   const fetchPromisesMap = {};

  /**
@@ -102,7 +102,7 @@ async function loadConfigModels(req) {

     for (const name of associatedNames) {
       const endpoint = endpointsMap[name];
-      modelsConfig[name] = !modelData?.length ? endpoint.models.default ?? [] : modelData;
+      modelsConfig[name] = !modelData?.length ? (endpoint.models.default ?? []) : modelData;
     }
   }
diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js
index 82db356841..db91c4101b 100644
--- a/api/server/services/Config/loadDefaultModels.js
+++ b/api/server/services/Config/loadDefaultModels.js
@@ -5,8 +5,8 @@ const {
   getGoogleModels,
   getBedrockModels,
   getAnthropicModels,
-  getChatGPTBrowserModels,
 } = require('~/server/services/ModelService');
+const { logger } = require('~/config');

 /**
  * Loads the default models for the application.
@@ -15,31 +15,68 @@ const {
  * @param {Express.Request} req - The Express request object.
  */
 async function loadDefaultModels(req) {
-  const google = getGoogleModels();
-  const openAI = await getOpenAIModels({ user: req.user.id });
-  const anthropic = getAnthropicModels();
-  const chatGPTBrowser = getChatGPTBrowserModels();
-  const azureOpenAI = await getOpenAIModels({ user: req.user.id, azure: true });
-  const gptPlugins = await getOpenAIModels({
-    user: req.user.id,
-    azure: useAzurePlugins,
-    plugins: true,
-  });
-  const assistants = await getOpenAIModels({ assistants: true });
-  const azureAssistants = await getOpenAIModels({ azureAssistants: true });
+  try {
+    const [
+      openAI,
+      anthropic,
+      azureOpenAI,
+      gptPlugins,
+      assistants,
+      azureAssistants,
+      google,
+      bedrock,
+    ] = await Promise.all([
+      getOpenAIModels({ user: req.user.id }).catch((error) => {
+        logger.error('Error fetching OpenAI models:', error);
+        return [];
+      }),
+      getAnthropicModels({ user: req.user.id }).catch((error) => {
+        logger.error('Error fetching Anthropic models:', error);
+        return [];
+      }),
+      getOpenAIModels({ user: req.user.id, azure: true }).catch((error) => {
+        logger.error('Error fetching Azure OpenAI models:', error);
+        return [];
+      }),
+      getOpenAIModels({ user: req.user.id, azure: useAzurePlugins, plugins: true }).catch(
+        (error) => {
+          logger.error('Error fetching Plugin models:', error);
+          return [];
+        },
+      ),
+      getOpenAIModels({ assistants: true }).catch((error) => {
+        logger.error('Error fetching OpenAI Assistants API models:', error);
+        return [];
+      }),
+      getOpenAIModels({ azureAssistants: true }).catch((error) => {
+        logger.error('Error fetching Azure OpenAI Assistants API models:', error);
+        return [];
+      }),
+      Promise.resolve(getGoogleModels()).catch((error) => {
+        logger.error('Error getting Google models:', error);
+        return [];
+      }),
+      Promise.resolve(getBedrockModels()).catch((error) => {
+        logger.error('Error getting Bedrock models:', error);
+        return [];
+      }),
+    ]);

-  return {
-    [EModelEndpoint.openAI]: openAI,
-    [EModelEndpoint.agents]: openAI,
-    [EModelEndpoint.google]: google,
-    [EModelEndpoint.anthropic]: anthropic,
-    [EModelEndpoint.gptPlugins]: gptPlugins,
-    [EModelEndpoint.azureOpenAI]: azureOpenAI,
-    [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser,
-    [EModelEndpoint.assistants]: assistants,
-    [EModelEndpoint.azureAssistants]: azureAssistants,
-    [EModelEndpoint.bedrock]: getBedrockModels(),
-  };
+    return {
+      [EModelEndpoint.openAI]: openAI,
+      [EModelEndpoint.agents]: openAI,
+      [EModelEndpoint.google]: google,
+      [EModelEndpoint.anthropic]: anthropic,
+      [EModelEndpoint.gptPlugins]: gptPlugins,
+      [EModelEndpoint.azureOpenAI]: azureOpenAI,
+      [EModelEndpoint.assistants]: assistants,
+      [EModelEndpoint.azureAssistants]: azureAssistants,
+      [EModelEndpoint.bedrock]: bedrock,
+    };
+  } catch (error) {
+    logger.error('Error fetching default models:', error);
+    throw new Error(`Failed to load default models: ${error.message}`);
+  }
 }

 module.exports = loadDefaultModels;
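The refactor above works because every entry in the `Promise.all` carries its own `.catch` that resolves to `[]`; one provider failing can no longer reject the whole batch. The pattern in isolation, with placeholder fetchers:

```js
// each entry resolves to a fallback instead of rejecting the batch
const [modelsA, modelsB] = await Promise.all([
  fetchProviderA().catch(() => []), // placeholder fetchers, not real API names
  fetchProviderB().catch(() => []),
]);
// modelsA / modelsB are always arrays, even when a fetch throws
```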
diff --git a/api/server/services/Endpoints/agents/build.js b/api/server/services/Endpoints/agents/build.js
index 027937e7fd..77ebbc58dc 100644
--- a/api/server/services/Endpoints/agents/build.js
+++ b/api/server/services/Endpoints/agents/build.js
@@ -1,19 +1,15 @@
+const { isAgentsEndpoint, Constants } = require('librechat-data-provider');
 const { loadAgent } = require('~/models/Agent');
 const { logger } = require('~/config');

-const buildOptions = (req, endpoint, parsedBody) => {
-  const {
-    spec,
-    iconURL,
-    agent_id,
-    instructions,
-    maxContextTokens,
-    resendFiles = true,
-    ...model_parameters
-  } = parsedBody;
+const buildOptions = (req, endpoint, parsedBody, endpointType) => {
+  const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
+    parsedBody;
   const agentPromise = loadAgent({
     req,
-    agent_id,
+    agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID,
+    endpoint,
+    model_parameters,
   }).catch((error) => {
     logger.error(`[/agents/:${agent_id}] Error retrieving agent during build options step`, error);
     return undefined;
@@ -24,7 +20,7 @@ const buildOptions = (req, endpoint, parsedBody) => {
     iconURL,
     endpoint,
     agent_id,
-    resendFiles,
+    endpointType,
     instructions,
     maxContextTokens,
     model_parameters,
diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js
index 3e03a45125..c9e363e815 100644
--- a/api/server/services/Endpoints/agents/initialize.js
+++ b/api/server/services/Endpoints/agents/initialize.js
@@ -1,7 +1,12 @@
 const { createContentAggregator, Providers } = require('@librechat/agents');
 const {
+  Constants,
+  ErrorTypes,
   EModelEndpoint,
+  EToolResources,
   getResponseSender,
+  AgentCapabilities,
+  replaceSpecialVars,
   providerEndpointMap,
 } = require('librechat-data-provider');
 const {
@@ -15,53 +20,96 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize');
 const initGoogle = require('~/server/services/Endpoints/google/initialize');
 const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
 const { getCustomEndpointConfig } = require('~/server/services/Config');
+const { processFiles } = require('~/server/services/Files/process');
 const { loadAgentTools } = require('~/server/services/ToolService');
 const AgentClient = require('~/server/controllers/agents/client');
+const { getConvoFiles } = require('~/models/Conversation');
+const { getToolFilesByIds } = require('~/models/File');
 const { getModelMaxTokens } = require('~/utils');
 const { getAgent } = require('~/models/Agent');
+const { getFiles } = require('~/models/File');
 const { logger } = require('~/config');

 const providerConfigMap = {
+  [Providers.XAI]: initCustom,
+  [Providers.OLLAMA]: initCustom,
+  [Providers.DEEPSEEK]: initCustom,
+  [Providers.OPENROUTER]: initCustom,
   [EModelEndpoint.openAI]: initOpenAI,
+  [EModelEndpoint.google]: initGoogle,
   [EModelEndpoint.azureOpenAI]: initOpenAI,
   [EModelEndpoint.anthropic]: initAnthropic,
   [EModelEndpoint.bedrock]: getBedrockOptions,
-  [EModelEndpoint.google]: initGoogle,
-  [Providers.OLLAMA]: initCustom,
 };

 /**
- *
- * @param {Promise> | undefined} _attachments
- * @param {AgentToolResources | undefined} _tool_resources
+ * @param {Object} params
+ * @param {ServerRequest} params.req
+ * @param {Promise> | undefined} [params.attachments]
+ * @param {Set} params.requestFileSet
+ * @param {AgentToolResources | undefined} [params.tool_resources]
  * @returns {Promise<{ attachments: Array | undefined, tool_resources: AgentToolResources | undefined }>}
 */
-const primeResources = async (_attachments, _tool_resources) => {
+const primeResources = async ({
+  req,
+  attachments: _attachments,
+  tool_resources: _tool_resources,
+  requestFileSet,
+}) => {
   try {
+    /** @type {Array | undefined} */
+    let attachments;
+    const tool_resources = _tool_resources ?? {};
+    const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
+      AgentCapabilities.ocr,
+    );
+    if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
+      const context = await getFiles(
+        {
+          file_id: { $in: tool_resources.ocr.file_ids },
+        },
+        {},
+        {},
+      );
+      attachments = (attachments ?? []).concat(context);
+    }
     if (!_attachments) {
-      return { attachments: undefined, tool_resources: _tool_resources };
+      return { attachments, tool_resources };
     }
     /** @type {Array | undefined} */
     const files = await _attachments;
-    const attachments = [];
-    const tool_resources = _tool_resources ?? {};
+    if (!attachments) {
+      /** @type {Array} */
+      attachments = [];
+    }

     for (const file of files) {
       if (!file) {
         continue;
       }
       if (file.metadata?.fileIdentifier) {
-        const execute_code = tool_resources.execute_code ?? {};
+        const execute_code = tool_resources[EToolResources.execute_code] ?? {};
         if (!execute_code.files) {
-          tool_resources.execute_code = { ...execute_code, files: [] };
+          tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] };
         }
-        tool_resources.execute_code.files.push(file);
+        tool_resources[EToolResources.execute_code].files.push(file);
       } else if (file.embedded === true) {
-        const file_search = tool_resources.file_search ?? {};
+        const file_search = tool_resources[EToolResources.file_search] ?? {};
         if (!file_search.files) {
-          tool_resources.file_search = { ...file_search, files: [] };
+          tool_resources[EToolResources.file_search] = { ...file_search, files: [] };
         }
-        tool_resources.file_search.files.push(file);
+        tool_resources[EToolResources.file_search].files.push(file);
+      } else if (
+        requestFileSet.has(file.file_id) &&
+        file.type.startsWith('image') &&
+        file.height &&
+        file.width
+      ) {
+        const image_edit = tool_resources[EToolResources.image_edit] ?? {};
+        if (!image_edit.files) {
+          tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] };
+        }
+        tool_resources[EToolResources.image_edit].files.push(file);
       }

       attachments.push(file);
@@ -73,13 +121,26 @@ const primeResources = async (_attachments, _tool_resources) => {
   }
 };

+/**
+ * @param {...string | number} values
+ * @returns {string | number | undefined}
+ */
+function optionalChainWithEmptyCheck(...values) {
+  for (const value of values) {
+    if (value !== undefined && value !== null && value !== '') {
+      return value;
+    }
+  }
+  return values[values.length - 1];
+}
+
 /**
  * @param {object} params
  * @param {ServerRequest} params.req
  * @param {ServerResponse} params.res
  * @param {Agent} params.agent
+ * @param {Set} [params.allowedProviders]
  * @param {object} [params.endpointOption]
- * @param {AgentToolResources} [params.tool_resources]
  * @param {boolean} [params.isInitialAgent]
 * @returns {Promise}
 */
@@ -88,29 +149,71 @@ const initializeAgentOptions = async ({
   res,
   agent,
   endpointOption,
-  tool_resources,
+  allowedProviders,
   isInitialAgent = false,
 }) => {
-  const { tools, toolContextMap } = await loadAgentTools({
+  if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) {
+    throw new Error(
+      `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`,
+    );
+  }
+  let currentFiles;
+  /** @type {Array} */
+  const requestFiles = req.body.files ?? [];
+  if (
+    isInitialAgent &&
+    req.body.conversationId != null &&
+    (agent.model_parameters?.resendFiles ?? true) === true
+  ) {
+    const fileIds = (await getConvoFiles(req.body.conversationId)) ?? [];
+    /** @type {Set} */
+    const toolResourceSet = new Set();
+    for (const tool of agent.tools) {
+      if (EToolResources[tool]) {
+        toolResourceSet.add(EToolResources[tool]);
+      }
+    }
+    const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet);
+    if (requestFiles.length || toolFiles.length) {
+      currentFiles = await processFiles(requestFiles.concat(toolFiles));
+    }
+  } else if (isInitialAgent && requestFiles.length) {
+    currentFiles = await processFiles(requestFiles);
+  }
+
+  const { attachments, tool_resources } = await primeResources({
     req,
-    res,
-    agent,
-    tool_resources,
+    attachments: currentFiles,
+    tool_resources: agent.tool_resources,
+    requestFileSet: new Set(requestFiles.map((file) => file.file_id)),
   });

   const provider = agent.provider;
-  let getOptions = providerConfigMap[provider];
+  const { tools, toolContextMap } = await loadAgentTools({
+    req,
+    res,
+    agent: {
+      id: agent.id,
+      tools: agent.tools,
+      provider,
+      model: agent.model,
+    },
+    tool_resources,
+  });

-  if (!getOptions) {
+  agent.endpoint = provider;
+  let getOptions = providerConfigMap[provider];
+  if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
+    agent.provider = provider.toLowerCase();
+    getOptions = providerConfigMap[agent.provider];
+  } else if (!getOptions) {
     const customEndpointConfig = await getCustomEndpointConfig(provider);
     if (!customEndpointConfig) {
       throw new Error(`Provider ${provider} not supported`);
     }
     getOptions = initCustom;
     agent.provider = Providers.OPENAI;
-    agent.endpoint = provider.toLowerCase();
   }
-
   const model_parameters = Object.assign(
     {},
     agent.model_parameters ?? { model: agent.model },
@@ -130,10 +233,18 @@ const initializeAgentOptions = async ({
     endpointOption: _endpointOption,
   });

+  if (
+    agent.endpoint === EModelEndpoint.azureOpenAI &&
+    options.llmConfig?.azureOpenAIApiInstanceName == null
+  ) {
+    agent.provider = Providers.OPENAI;
+  }
+
   if (options.provider != null) {
     agent.provider = options.provider;
   }

+  /** @type {import('@librechat/agents').ClientOptions} */
   agent.model_parameters = Object.assign(model_parameters, options.llmConfig);
   if (options.configOptions) {
     agent.model_parameters.configuration = options.configOptions;
@@ -143,6 +254,13 @@ const initializeAgentOptions = async ({
     agent.model_parameters.model = agent.model;
   }

+  if (agent.instructions && agent.instructions !== '') {
+    agent.instructions = replaceSpecialVars({
+      text: agent.instructions,
+      user: req.user,
+    });
+  }
+
   if (typeof agent.artifacts === 'string' && agent.artifacts !== '') {
     agent.additional_instructions = generateArtifactsPrompt({
       endpoint: agent.provider,
@@ -152,15 +270,23 @@ const initializeAgentOptions = async ({

   const tokensModel =
     agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model;
-
+  const maxTokens = optionalChainWithEmptyCheck(
+    agent.model_parameters.maxOutputTokens,
+    agent.model_parameters.maxTokens,
+    0,
+  );
+  const maxContextTokens = optionalChainWithEmptyCheck(
+    agent.model_parameters.maxContextTokens,
+    agent.max_context_tokens,
+    getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
+    4096,
+  );
   return {
     ...agent,
     tools,
+    attachments,
     toolContextMap,
-    maxContextTokens:
-      agent.max_context_tokens ??
-      getModelMaxTokens(tokensModel, providerEndpointMap[provider]) ??
-      4000,
+    maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
   };
 };
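The returned context budget now reserves the output allocation up front and keeps a 10% safety margin. A worked example with illustrative numbers (note also that `optionalChainWithEmptyCheck` treats `''` as missing, unlike plain `??`):

```js
// optionalChainWithEmptyCheck('', null, 4096) -> 4096; '' would have satisfied `??`
const maxContextTokens = 8192; // e.g. model window resolved via getModelMaxTokens
const maxTokens = 2048;        // e.g. maxOutputTokens from model_parameters
const budget = (maxContextTokens - maxTokens) * 0.9; // -> 5529.6 tokens of input context
```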
@@ -193,12 +319,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     throw new Error('Agent not found');
   }

-  const { attachments, tool_resources } = await primeResources(
-    endpointOption.attachments,
-    primaryAgent.tool_resources,
-  );
-
   const agentConfigs = new Map();
+  /** @type {Set} */
+  const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);

   // Handle primary agent
   const primaryConfig = await initializeAgentOptions({
@@ -206,7 +329,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
     res,
     agent: primaryAgent,
     endpointOption,
-    tool_resources,
+    allowedProviders,
     isInitialAgent: true,
   });

@@ -222,6 +345,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
       res,
       agent,
       endpointOption,
+      allowedProviders,
     });
     agentConfigs.set(agentId, config);
   }
@@ -236,18 +360,25 @@ const initializeClient = async ({ req, res, endpointOption }) => {

   const client = new AgentClient({
     req,
-    agent: primaryConfig,
+    res,
     sender,
-    attachments,
     contentParts,
+    agentConfigs,
     eventHandlers,
     collectedUsage,
+    aggregateContent,
     artifactPromises,
+    agent: primaryConfig,
     spec: endpointOption.spec,
     iconURL: endpointOption.iconURL,
-    agentConfigs,
-    endpoint: EModelEndpoint.agents,
+    attachments: primaryConfig.attachments,
+    endpointType: endpointOption.endpointType,
     maxContextTokens: primaryConfig.maxContextTokens,
+    resendFiles: primaryConfig.model_parameters?.resendFiles ?? true,
+    endpoint:
+      primaryConfig.id === Constants.EPHEMERAL_AGENT_ID
+        ? primaryConfig.endpoint
+        : EModelEndpoint.agents,
   });

   return { client };
diff --git a/api/server/services/Endpoints/agents/title.js b/api/server/services/Endpoints/agents/title.js
index 56fd28668d..ab171bc79d 100644
--- a/api/server/services/Endpoints/agents/title.js
+++ b/api/server/services/Endpoints/agents/title.js
@@ -2,7 +2,11 @@ const { CacheKeys } = require('librechat-data-provider');
 const getLogStores = require('~/cache/getLogStores');
 const { isEnabled } = require('~/server/utils');
 const { saveConvo } = require('~/models');
+const { logger } = require('~/config');

+/**
+ * Add title to conversation in a way that avoids memory retention
+ */
 const addTitle = async (req, { text, response, client }) => {
   const { TITLE_CONVO = true } = process.env ?? {};
   if (!isEnabled(TITLE_CONVO)) {
@@ -13,28 +17,55 @@ const addTitle = async (req, { text, response, client }) => {
     return;
   }

-  // If the request was aborted, don't generate the title.
-  if (client.abortController.signal.aborted) {
-    return;
-  }
-
   const titleCache = getLogStores(CacheKeys.GEN_TITLE);
   const key = `${req.user.id}-${response.conversationId}`;

+  /** @type {NodeJS.Timeout} */
+  let timeoutId;
+  try {
+    const timeoutPromise = new Promise((_, reject) => {
+      timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
+    }).catch((error) => {
+      logger.error('Title error:', error);
+    });

-  const title = await client.titleConvo({
-    text,
-    responseText: response?.text ?? '',
-    conversationId: response.conversationId,
-  });
-  await titleCache.set(key, title, 120000);
-  await saveConvo(
-    req,
-    {
-      conversationId: response.conversationId,
-      title,
-    },
-    { context: 'api/server/services/Endpoints/agents/title.js' },
-  );
+    let titlePromise;
+    let abortController = new AbortController();
+    if (client && typeof client.titleConvo === 'function') {
+      titlePromise = Promise.race([
+        client
+          .titleConvo({
+            text,
+            abortController,
+          })
+          .catch((error) => {
+            logger.error('Client title error:', error);
+          }),
+        timeoutPromise,
+      ]);
+    } else {
+      return;
+    }
+
+    const title = await titlePromise;
+    if (!abortController.signal.aborted) {
+      abortController.abort();
+    }
+    if (timeoutId) {
+      clearTimeout(timeoutId);
+    }
+
+    await titleCache.set(key, title, 120000);
+    await saveConvo(
+      req,
+      {
+        conversationId: response.conversationId,
+        title,
+      },
+      { context: 'api/server/services/Endpoints/agents/title.js' },
+    );
+  } catch (error) {
+    logger.error('Error generating title:', error);
+  }
 };

 module.exports = addTitle;
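The new title flow races the provider call against a 25-second timer, and both arms swallow their own rejections so the loser never surfaces as an unhandled rejection. The skeleton of the pattern (`generateTitle` is a placeholder):

```js
let timeoutId;
const timeout = new Promise((_, reject) => {
  timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
}).catch(() => undefined); // pre-caught: losing the race must not throw later

const title = await Promise.race([generateTitle().catch(() => undefined), timeout]);
clearTimeout(timeoutId); // always clear, or the timer keeps the event loop alive
```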
diff --git a/api/server/services/Endpoints/anthropic/build.js b/api/server/services/Endpoints/anthropic/build.js
index 028da36407..2deab4b975 100644
--- a/api/server/services/Endpoints/anthropic/build.js
+++ b/api/server/services/Endpoints/anthropic/build.js
@@ -1,4 +1,4 @@
-const { removeNullishValues } = require('librechat-data-provider');
+const { removeNullishValues, anthropicSettings } = require('librechat-data-provider');
 const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');

 const buildOptions = (endpoint, parsedBody) => {
@@ -6,8 +6,10 @@ const buildOptions = (endpoint, parsedBody) => {
     modelLabel,
     promptPrefix,
     maxContextTokens,
-    resendFiles = true,
-    promptCache = true,
+    resendFiles = anthropicSettings.resendFiles.default,
+    promptCache = anthropicSettings.promptCache.default,
+    thinking = anthropicSettings.thinking.default,
+    thinkingBudget = anthropicSettings.thinkingBudget.default,
     iconURL,
     greeting,
     spec,
@@ -21,6 +23,8 @@ const buildOptions = (endpoint, parsedBody) => {
     promptPrefix,
     resendFiles,
     promptCache,
+    thinking,
+    thinkingBudget,
     iconURL,
     greeting,
     spec,
diff --git a/api/server/services/Endpoints/anthropic/helpers.js b/api/server/services/Endpoints/anthropic/helpers.js
new file mode 100644
index 0000000000..04e4efc61c
--- /dev/null
+++ b/api/server/services/Endpoints/anthropic/helpers.js
@@ -0,0 +1,111 @@
+const { EModelEndpoint, anthropicSettings } = require('librechat-data-provider');
+const { matchModelName } = require('~/utils');
+const { logger } = require('~/config');
+
+/**
+ * @param {string} modelName
+ * @returns {boolean}
+ */
+function checkPromptCacheSupport(modelName) {
+  const modelMatch = matchModelName(modelName, EModelEndpoint.anthropic);
+  if (
+    modelMatch.includes('claude-3-5-sonnet-latest') ||
+    modelMatch.includes('claude-3.5-sonnet-latest')
+  ) {
+    return false;
+  }
+
+  if (
+    modelMatch === 'claude-3-7-sonnet' ||
+    modelMatch === 'claude-3-5-sonnet' ||
+    modelMatch === 'claude-3-5-haiku' ||
+    modelMatch === 'claude-3-haiku' ||
+    modelMatch === 'claude-3-opus' ||
+    modelMatch === 'claude-3.7-sonnet' ||
+    modelMatch === 'claude-3.5-sonnet' ||
+    modelMatch === 'claude-3.5-haiku'
+  ) {
+    return true;
+  }
+
+  return false;
+}
+
+/**
+ * Gets the appropriate headers for Claude models with cache control
+ * @param {string} model The model name
+ * @param {boolean} supportsCacheControl Whether the model supports cache control
+ * @returns {AnthropicClientOptions['extendedOptions']['defaultHeaders']|undefined} The headers object or undefined if not applicable
+ */
+function getClaudeHeaders(model, supportsCacheControl) {
+  if (!supportsCacheControl) {
+    return undefined;
+  }
+
+  if (/claude-3[-.]5-sonnet/.test(model)) {
+    return {
+      'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
+    };
+  } else if (/claude-3[-.]7/.test(model)) {
+    return {
+      'anthropic-beta':
+        'token-efficient-tools-2025-02-19,output-128k-2025-02-19,prompt-caching-2024-07-31',
+    };
+  } else {
+    return {
+      'anthropic-beta': 'prompt-caching-2024-07-31',
+    };
+  }
+}
+
+/**
+ * Configures reasoning-related options for Claude models
+ * @param {AnthropicClientOptions & { max_tokens?: number }} anthropicInput The request options object
+ * @param {Object} extendedOptions Additional client configuration options
+ * @param {boolean} extendedOptions.thinking Whether thinking is enabled in client config
+ * @param {number|null} extendedOptions.thinkingBudget The token budget for thinking
+ * @returns {Object} Updated request options
+ */
+function configureReasoning(anthropicInput, extendedOptions = {}) {
+  const updatedOptions = { ...anthropicInput };
+  const currentMaxTokens = updatedOptions.max_tokens ?? updatedOptions.maxTokens;
+  if (
+    extendedOptions.thinking &&
+    updatedOptions?.model &&
+    /claude-3[-.]7/.test(updatedOptions.model)
+  ) {
+    updatedOptions.thinking = {
+      type: 'enabled',
+    };
+  }
+
+  if (updatedOptions.thinking != null && extendedOptions.thinkingBudget != null) {
+    updatedOptions.thinking = {
+      ...updatedOptions.thinking,
+      budget_tokens: extendedOptions.thinkingBudget,
+    };
+  }
+
+  if (
+    updatedOptions.thinking != null &&
+    (currentMaxTokens == null || updatedOptions.thinking.budget_tokens > currentMaxTokens)
+  ) {
+    const maxTokens = anthropicSettings.maxOutputTokens.reset(updatedOptions.model);
+    updatedOptions.max_tokens = currentMaxTokens ?? maxTokens;
+
+    logger.warn(
+      updatedOptions.max_tokens === maxTokens
+        ? '[AnthropicClient] max_tokens is not defined while thinking is enabled. Setting max_tokens to model default.'
+        : `[AnthropicClient] thinking budget_tokens (${updatedOptions.thinking.budget_tokens}) exceeds max_tokens (${updatedOptions.max_tokens}). Adjusting budget_tokens.`,
+    );
+
+    updatedOptions.thinking.budget_tokens = Math.min(
+      updatedOptions.thinking.budget_tokens,
+      Math.floor(updatedOptions.max_tokens * 0.9),
+    );
+  }
+
+  return updatedOptions;
+}
+
+module.exports = { checkPromptCacheSupport, getClaudeHeaders, configureReasoning };
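A quick sanity check of the clamping logic in `configureReasoning` for a Claude 3.7 model whose thinking budget exceeds `max_tokens` (numbers illustrative):

```js
const input = { model: 'claude-3-7-sonnet', max_tokens: 2048 };
const out = configureReasoning(input, { thinking: true, thinkingBudget: 4096 });
// 4096 > 2048, so the budget is clamped to fit under max_tokens:
// out.thinking -> { type: 'enabled', budget_tokens: Math.min(4096, Math.floor(2048 * 0.9)) }
// i.e. budget_tokens === 1843 while max_tokens stays 2048
```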
diff --git a/api/server/services/Endpoints/anthropic/initialize.js b/api/server/services/Endpoints/anthropic/initialize.js
index ffd61441be..d4c6dd1795 100644
--- a/api/server/services/Endpoints/anthropic/initialize.js
+++ b/api/server/services/Endpoints/anthropic/initialize.js
@@ -1,7 +1,7 @@
 const { EModelEndpoint } = require('librechat-data-provider');
 const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
 const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
-const { AnthropicClient } = require('~/app');
+const AnthropicClient = require('~/app/clients/AnthropicClient');

 const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
   const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY, PROXY } = process.env;
@@ -27,6 +27,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio

   if (anthropicConfig) {
     clientOptions.streamRate = anthropicConfig.streamRate;
+    clientOptions.titleModel = anthropicConfig.titleModel;
   }

   /** @type {undefined | TBaseEndpoint} */
diff --git a/api/server/services/Endpoints/anthropic/llm.js b/api/server/services/Endpoints/anthropic/llm.js
index 301d42712a..9f20b8e61d 100644
--- a/api/server/services/Endpoints/anthropic/llm.js
+++ b/api/server/services/Endpoints/anthropic/llm.js
@@ -1,5 +1,6 @@
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
+const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');

 /**
  * Generates configuration options for creating an Anthropic language model (LLM) instance.
@@ -20,6 +21,14 @@ const { anthropicSettings, removeNullishValues } = require('librechat-data-provi
  * @returns {Object} Configuration options for creating an Anthropic LLM instance, with null and undefined values removed.
  */
 function getLLMConfig(apiKey, options = {}) {
+  const systemOptions = {
+    thinking: options.modelOptions.thinking ?? anthropicSettings.thinking.default,
+    promptCache: options.modelOptions.promptCache ?? anthropicSettings.promptCache.default,
+    thinkingBudget: options.modelOptions.thinkingBudget ?? anthropicSettings.thinkingBudget.default,
+  };
+  for (let key in systemOptions) {
+    delete options.modelOptions[key];
+  }
   const defaultOptions = {
     model: anthropicSettings.model.default,
     maxOutputTokens: anthropicSettings.maxOutputTokens.default,
@@ -29,19 +38,34 @@ function getLLMConfig(apiKey, options = {}) {
   const mergedOptions = Object.assign(defaultOptions, options.modelOptions);

   /** @type {AnthropicClientOptions} */
-  const requestOptions = {
+  let requestOptions = {
     apiKey,
     model: mergedOptions.model,
     stream: mergedOptions.stream,
     temperature: mergedOptions.temperature,
-    topP: mergedOptions.topP,
-    topK: mergedOptions.topK,
     stopSequences: mergedOptions.stop,
     maxTokens:
       mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
     clientOptions: {},
   };

+  requestOptions = configureReasoning(requestOptions, systemOptions);
+
+  if (!/claude-3[-.]7/.test(mergedOptions.model)) {
+    requestOptions.topP = mergedOptions.topP;
+    requestOptions.topK = mergedOptions.topK;
+  } else if (requestOptions.thinking == null) {
+    requestOptions.topP = mergedOptions.topP;
+    requestOptions.topK = mergedOptions.topK;
+  }
+
+  const supportsCacheControl =
+    systemOptions.promptCache === true && checkPromptCacheSupport(requestOptions.model);
+  const headers = getClaudeHeaders(requestOptions.model, supportsCacheControl);
+  if (headers) {
+    requestOptions.clientOptions.defaultHeaders = headers;
+  }
+
   if (options.proxy) {
     requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy);
   }
diff --git a/api/server/services/Endpoints/anthropic/llm.spec.js b/api/server/services/Endpoints/anthropic/llm.spec.js
new file mode 100644
index 0000000000..9c453efb92
--- /dev/null
+++ b/api/server/services/Endpoints/anthropic/llm.spec.js
@@ -0,0 +1,153 @@
+const { anthropicSettings } = require('librechat-data-provider');
+const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
+
+jest.mock('https-proxy-agent', () => ({
+  HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
+}));
+
+describe('getLLMConfig', () => {
+  it('should create a basic configuration with default values', () => {
+    const result = getLLMConfig('test-api-key', { modelOptions: {} });
+
+    expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
+    expect(result.llmConfig).toHaveProperty('model', anthropicSettings.model.default);
+    expect(result.llmConfig).toHaveProperty('stream', true);
+    expect(result.llmConfig).toHaveProperty('maxTokens');
+  });
+
+  it('should include proxy settings when provided', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {},
+      proxy: 'http://proxy:8080',
+    });
+
+    expect(result.llmConfig.clientOptions).toHaveProperty('httpAgent');
+    expect(result.llmConfig.clientOptions.httpAgent).toHaveProperty('proxy', 'http://proxy:8080');
+  });
+
+  it('should include reverse proxy URL when provided', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {},
+      reverseProxyUrl: 'http://reverse-proxy',
+    });
+
+    expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
+  });
+
+  it('should include topK and topP for non-Claude-3.7 models', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3-opus',
+        topK: 10,
+        topP: 0.9,
+      },
+    });
+
+    expect(result.llmConfig).toHaveProperty('topK', 10);
+    expect(result.llmConfig).toHaveProperty('topP', 0.9);
+  });
+
+  it('should include topK and topP for Claude-3.5 models', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3-5-sonnet',
+        topK: 10,
+        topP: 0.9,
+      },
+    });
+
+    expect(result.llmConfig).toHaveProperty('topK', 10);
+    expect(result.llmConfig).toHaveProperty('topP', 0.9);
+  });
+
+  it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3-7-sonnet',
+        topK: 10,
+        topP: 0.9,
+      },
+    });
+
+    expect(result.llmConfig).not.toHaveProperty('topK');
+    expect(result.llmConfig).not.toHaveProperty('topP');
+  });
+
+  it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3.7-sonnet',
+        topK: 10,
+        topP: 0.9,
+      },
+    });
+
+    expect(result.llmConfig).not.toHaveProperty('topK');
+    expect(result.llmConfig).not.toHaveProperty('topP');
+  });
+
+  it('should handle custom maxOutputTokens', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3-opus',
+        maxOutputTokens: 2048,
+      },
+    });
+
+    expect(result.llmConfig).toHaveProperty('maxTokens', 2048);
+  });
+
+  it('should handle promptCache setting', () => {
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3-5-sonnet',
+        promptCache: true,
+      },
+    });
+
+    // We're not checking specific header values since that depends on the actual helper function
+    // Just verifying that the promptCache setting is processed
+    expect(result.llmConfig).toBeDefined();
+  });
+
+  it('should include topK and topP for Claude-3.7 models when thinking is not enabled', () => {
+    // Test with thinking explicitly set to false (hyphen notation)
+    const result = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3-7-sonnet',
+        topK: 10,
+        topP: 0.9,
+        thinking: false,
+      },
+    });
+
+    expect(result.llmConfig).toHaveProperty('topK', 10);
+    expect(result.llmConfig).toHaveProperty('topP', 0.9);
+
+    // Test with thinking explicitly set to false
+    const result2 = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3-7-sonnet',
+        topK: 10,
+        topP: 0.9,
+        thinking: false,
+      },
+    });
+
+    expect(result2.llmConfig).toHaveProperty('topK', 10);
+    expect(result2.llmConfig).toHaveProperty('topP', 0.9);
+
+    // Test with decimal notation as well
+    const result3 = getLLMConfig('test-api-key', {
+      modelOptions: {
+        model: 'claude-3.7-sonnet',
+        topK: 10,
+        topP: 0.9,
+        thinking: false,
+      },
+    });
+
+    expect(result3.llmConfig).toHaveProperty('topK', 10);
+    expect(result3.llmConfig).toHaveProperty('topP', 0.9);
+  });
+});
diff --git a/api/server/services/Endpoints/anthropic/title.js b/api/server/services/Endpoints/anthropic/title.js
index 5c477632d2..0f9a5e97d0 100644
--- a/api/server/services/Endpoints/anthropic/title.js
+++ b/api/server/services/Endpoints/anthropic/title.js
@@ -13,11 +13,6 @@ const addTitle = async (req, { text, response, client }) => {
     return;
   }

-  // If the request was aborted, don't generate the title.
-  if (client.abortController.signal.aborted) {
-    return;
-  }
-
   const titleCache = getLogStores(CacheKeys.GEN_TITLE);
   const key = `${req.user.id}-${response.conversationId}`;
diff --git a/api/server/services/Endpoints/assistants/build.js b/api/server/services/Endpoints/assistants/build.js
index 544567dd01..00a2abf606 100644
--- a/api/server/services/Endpoints/assistants/build.js
+++ b/api/server/services/Endpoints/assistants/build.js
@@ -3,7 +3,6 @@ const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
 const { getAssistant } = require('~/models/Assistant');

 const buildOptions = async (endpoint, parsedBody) => {
-
   const { promptPrefix, assistant_id, iconURL, greeting, spec, artifacts, ...modelOptions } =
     parsedBody;
   const endpointOption = removeNullishValues({
diff --git a/api/server/services/Endpoints/bedrock/build.js b/api/server/services/Endpoints/bedrock/build.js
index d6fb0636a9..f5228160fc 100644
--- a/api/server/services/Endpoints/bedrock/build.js
+++ b/api/server/services/Endpoints/bedrock/build.js
@@ -1,6 +1,5 @@
-const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider');
+const { removeNullishValues } = require('librechat-data-provider');
 const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
-const { logger } = require('~/config');

 const buildOptions = (endpoint, parsedBody) => {
   const {
@@ -15,12 +14,6 @@ const buildOptions = (endpoint, parsedBody) => {
     artifacts,
     ...model_parameters
   } = parsedBody;
-  let parsedParams = model_parameters;
-  try {
-    parsedParams = bedrockInputParser.parse(model_parameters);
-  } catch (error) {
-    logger.warn('Failed to parse bedrock input', error);
-  }
   const endpointOption = removeNullishValues({
     endpoint,
     name,
@@ -31,7 +24,7 @@ const buildOptions = (endpoint, parsedBody) => {
     spec,
     promptPrefix,
     maxContextTokens,
-    model_parameters: parsedParams,
+    model_parameters,
   });

   if (typeof artifacts === 'string') {
diff --git a/api/server/services/Endpoints/bedrock/initialize.js b/api/server/services/Endpoints/bedrock/initialize.js
index 3ffa03393d..4d9ba361cf 100644
--- a/api/server/services/Endpoints/bedrock/initialize.js
+++ b/api/server/services/Endpoints/bedrock/initialize.js
@@ -23,8 +23,9 @@ const initializeClient = async ({ req, res, endpointOption }) => {
   const agent = {
     id: EModelEndpoint.bedrock,
     name: endpointOption.name,
-    instructions: endpointOption.promptPrefix,
     provider: EModelEndpoint.bedrock,
+    endpoint: EModelEndpoint.bedrock,
+    instructions: endpointOption.promptPrefix,
     model: endpointOption.model_parameters.model,
     model_parameters: endpointOption.model_parameters,
   };
@@ -54,6 +55,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {

   const client = new AgentClient({
     req,
+    res,
     agent,
     sender,
     // tools,
diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js
index 11b33a5357..1936a8f483 100644
--- a/api/server/services/Endpoints/bedrock/options.js
+++ b/api/server/services/Endpoints/bedrock/options.js
@@ -1,14 +1,16 @@
 const { HttpsProxyAgent } = require('https-proxy-agent');
 const {
-  EModelEndpoint,
-  Constants,
   AuthType,
+  Constants,
+  EModelEndpoint,
+  bedrockInputParser,
+  bedrockOutputParser,
   removeNullishValues,
 } = require('librechat-data-provider');
 const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
-const { sleep } = require('~/server/utils');
+const { createHandleLLMNewToken } = require('~/app/clients/generators');

-const getOptions = async ({ req, endpointOption }) => {
+const getOptions = async ({ req, overrideModel, endpointOption }) => {
   const {
     BEDROCK_AWS_SECRET_ACCESS_KEY,
     BEDROCK_AWS_ACCESS_KEY_ID,
@@ -62,39 +64,39 @@ const getOptions = async ({ req, endpointOption }) => {

   /** @type {BedrockClientOptions} */
   const requestOptions = {
-    model: endpointOption.model,
+    model: overrideModel ?? endpointOption.model,
     region: BEDROCK_AWS_DEFAULT_REGION,
-    streaming: true,
-    streamUsage: true,
-    callbacks: [
-      {
-        handleLLMNewToken: async () => {
-          if (!streamRate) {
-            return;
-          }
-          await sleep(streamRate);
-        },
-      },
-    ],
   };

-  if (credentials) {
-    requestOptions.credentials = credentials;
-  }
-
-  if (BEDROCK_REVERSE_PROXY) {
-    requestOptions.endpointHost = BEDROCK_REVERSE_PROXY;
-  }
-
   const configOptions = {};
   if (PROXY) {
     /** NOTE: NOT SUPPORTED BY BEDROCK */
     configOptions.httpAgent = new HttpsProxyAgent(PROXY);
   }

+  const llmConfig = bedrockOutputParser(
+    bedrockInputParser.parse(
+      removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
+    ),
+  );
+
+  if (credentials) {
+    llmConfig.credentials = credentials;
+  }
+
+  if (BEDROCK_REVERSE_PROXY) {
+    llmConfig.endpointHost = BEDROCK_REVERSE_PROXY;
+  }
+
+  llmConfig.callbacks = [
+    {
+      handleLLMNewToken: createHandleLLMNewToken(streamRate),
+    },
+  ];
+
   return {
     /** @type {BedrockClientOptions} */
-    llmConfig: removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
+    llmConfig,
     configOptions,
   };
 };
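Note the ordering above: user parameters are merged and run through the parsers first, and runtime-only fields (`credentials`, `endpointHost`, `callbacks`) are attached to the parsed result afterwards, presumably so the parsers never see or strip them. Reduced to its shape (`modelParams` stands in for `endpointOption.model_parameters`):

```js
// 1) merge + strip nullish -> 2) validate/normalize input -> 3) map to Bedrock's field names
const llmConfig = bedrockOutputParser(
  bedrockInputParser.parse(removeNullishValues({ ...requestOptions, ...modelParams })),
);
// 4) only now attach fields the parsers should not handle
if (credentials) {
  llmConfig.credentials = credentials;
}
```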
diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js
index fe2beba582..592440db54 100644
--- a/api/server/services/Endpoints/custom/initialize.js
+++ b/api/server/services/Endpoints/custom/initialize.js
@@ -9,10 +9,11 @@ const { Providers } = require('@librechat/agents');
 const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
 const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
 const { getCustomEndpointConfig } = require('~/server/services/Config');
+const { createHandleLLMNewToken } = require('~/app/clients/generators');
 const { fetchModels } = require('~/server/services/ModelService');
-const { isUserProvided, sleep } = require('~/server/utils');
+const OpenAIClient = require('~/app/clients/OpenAIClient');
+const { isUserProvided } = require('~/server/utils');
 const getLogStores = require('~/cache/getLogStores');
-const { OpenAIClient } = require('~/app');

 const { PROXY } = process.env;

@@ -141,15 +142,14 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
     },
     clientOptions,
   );
-  const options = getLLMConfig(apiKey, clientOptions);
+  clientOptions.modelOptions.user = req.user.id;
+  const options = getLLMConfig(apiKey, clientOptions, endpoint);
   if (!customOptions.streamRate) {
     return options;
   }
   options.llmConfig.callbacks = [
     {
-      handleLLMNewToken: async () => {
-        await sleep(customOptions.streamRate);
-      },
+      handleLLMNewToken: createHandleLLMNewToken(clientOptions.streamRate),
     },
   ];
   return options;
diff --git a/api/server/services/Endpoints/google/initialize.js b/api/server/services/Endpoints/google/initialize.js
index c157dd8b28..b7419a8a87 100644
--- a/api/server/services/Endpoints/google/initialize.js
+++ b/api/server/services/Endpoints/google/initialize.js
@@ -5,12 +5,7 @@ const { isEnabled } = require('~/server/utils');
 const { GoogleClient } = require('~/app');

 const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
-  const {
-    GOOGLE_KEY,
-    GOOGLE_REVERSE_PROXY,
-    GOOGLE_AUTH_HEADER,
-    PROXY,
-  } = process.env;
+  const { GOOGLE_KEY, GOOGLE_REVERSE_PROXY, GOOGLE_AUTH_HEADER, PROXY } = process.env;
   const isUserProvided = GOOGLE_KEY === 'user_provided';
   const { key: expiresAt } = req.body;

@@ -43,6 +38,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio

   if (googleConfig) {
     clientOptions.streamRate = googleConfig.streamRate;
+    clientOptions.titleModel = googleConfig.titleModel;
   }

   if (allConfig) {
diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js
index 0eb0d566b9..714ed5a1e6 100644
--- a/api/server/services/Endpoints/openAI/initialize.js
+++ b/api/server/services/Endpoints/openAI/initialize.js
@@ -6,9 +6,10 @@ const {
 } = require('librechat-data-provider');
 const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
 const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm');
-const { isEnabled, isUserProvided, sleep } = require('~/server/utils');
+const { createHandleLLMNewToken } = require('~/app/clients/generators');
+const { isEnabled, isUserProvided } = require('~/server/utils');
+const OpenAIClient = require('~/app/clients/OpenAIClient');
 const { getAzureCredentials } = require('~/utils');
-const { OpenAIClient } = require('~/app');

 const initializeClient = async ({
   req,
@@ -113,6 +114,7 @@ const initializeClient = async ({

   if (!isAzureOpenAI && openAIConfig) {
     clientOptions.streamRate = openAIConfig.streamRate;
+    clientOptions.titleModel = openAIConfig.titleModel;
   }

   /** @type {undefined | TBaseEndpoint} */
@@ -134,21 +136,18 @@ const initializeClient = async ({
   }

   if (optionsOnly) {
-    clientOptions = Object.assign(
-      {
-        modelOptions: endpointOption.model_parameters,
-      },
-      clientOptions,
-    );
+    const modelOptions = endpointOption.model_parameters;
+    modelOptions.model = modelName;
+    clientOptions = Object.assign({ modelOptions }, clientOptions);
+    clientOptions.modelOptions.user = req.user.id;
     const options = getLLMConfig(apiKey, clientOptions);
-    if (!clientOptions.streamRate) {
+    const streamRate = clientOptions.streamRate;
+    if (!streamRate) {
       return options;
     }
     options.llmConfig.callbacks = [
       {
-        handleLLMNewToken: async () => {
-          await sleep(clientOptions.streamRate);
-        },
+        handleLLMNewToken: createHandleLLMNewToken(streamRate),
       },
     ];
     return options;
diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js
index 2587b242c9..c1fd090b28 100644
--- a/api/server/services/Endpoints/openAI/llm.js
+++ b/api/server/services/Endpoints/openAI/llm.js
@@ -1,4 +1,5 @@
 const { HttpsProxyAgent } = require('https-proxy-agent');
+const { KnownEndpoints } = require('librechat-data-provider');
 const { sanitizeModelName, constructAzureURL } = require('~/utils');
 const { isEnabled } = require('~/server/utils');

@@ -8,6 +9,7 @@ const { isEnabled } = require('~/server/utils');
  * @param {Object} options - Additional options for configuring the LLM.
  * @param {Object} [options.modelOptions] - Model-specific options.
  * @param {string} [options.modelOptions.model] - The name of the model to use.
+ * @param {string} [options.modelOptions.user] - The user ID
  * @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2).
  * @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1).
 * @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2).
@@ -22,13 +24,13 @@ const { isEnabled } = require('~/server/utils');
  * @param {boolean} [options.streaming] - Whether to use streaming mode.
  * @param {Object} [options.addParams] - Additional parameters to add to the model options.
  * @param {string[]} [options.dropParams] - Parameters to remove from the model options.
+ * @param {string|null} [endpoint=null] - The endpoint name
  * @returns {Object} Configuration options for creating an LLM instance.
  */
-function getLLMConfig(apiKey, options = {}) {
-  const {
+function getLLMConfig(apiKey, options = {}, endpoint = null) {
+  let {
     modelOptions = {},
     reverseProxyUrl,
-    useOpenRouter,
     defaultQuery,
     headers,
     proxy,
@@ -48,19 +50,45 @@ function getLLMConfig(apiKey, options = {}) {
   if (addParams && typeof addParams === 'object') {
     Object.assign(llmConfig, addParams);
   }
+  /** Note: OpenAI Web Search models do not support any known parameters besides `max_tokens` */
+  if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
+    const searchExcludeParams = [
+      'frequency_penalty',
+      'presence_penalty',
+      'temperature',
+      'top_p',
+      'top_k',
+      'stop',
+      'logit_bias',
+      'seed',
+      'response_format',
+      'n',
+      'logprobs',
+      'user',
+    ];
+
+    dropParams = dropParams || [];
+    dropParams = [...new Set([...dropParams, ...searchExcludeParams])];
+  }

   if (dropParams && Array.isArray(dropParams)) {
     dropParams.forEach((param) => {
-      delete llmConfig[param];
+      if (llmConfig[param]) {
+        llmConfig[param] = undefined;
+      }
     });
   }

+  let useOpenRouter;
   /** @type {OpenAIClientOptions['configuration']} */
   const configOptions = {};
-
-  // Handle OpenRouter or custom reverse proxy
-  if (useOpenRouter || reverseProxyUrl === 'https://openrouter.ai/api/v1') {
-    configOptions.baseURL = 'https://openrouter.ai/api/v1';
+  if (
+    (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
+    (endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
+  ) {
+    useOpenRouter = true;
+    llmConfig.include_reasoning = true;
+    configOptions.baseURL = reverseProxyUrl;
     configOptions.defaultHeaders = Object.assign(
       {
         'HTTP-Referer': 'https://librechat.ai',
@@ -108,7 +136,7 @@ function getLLMConfig(apiKey, options = {}) {
     Object.assign(llmConfig, azure);
     llmConfig.model = llmConfig.azureOpenAIApiDeploymentName;
   } else {
-    llmConfig.openAIApiKey = apiKey;
+    llmConfig.apiKey = apiKey;
     // Object.assign(llmConfig, {
     //   configuration: { apiKey },
     // });
@@ -118,6 +146,19 @@ function getLLMConfig(apiKey, options = {}) {
     llmConfig.organization = process.env.OPENAI_ORGANIZATION;
   }

+  if (useOpenRouter && llmConfig.reasoning_effort != null) {
+    llmConfig.reasoning = {
+      effort: llmConfig.reasoning_effort,
+    };
+    delete llmConfig.reasoning_effort;
+  }
+
+  if (llmConfig?.['max_tokens'] != null) {
+    /** @type {number} */
+    llmConfig.maxTokens = llmConfig['max_tokens'];
+    delete llmConfig['max_tokens'];
+  }
+
   return {
     /** @type {OpenAIClientOptions} */
     llmConfig,
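One behavioral detail worth flagging in review: `dropParams` now assigns `undefined` instead of using `delete`, and the truthiness guard means falsy-but-set values survive. Illustration:

```js
const llmConfig = { temperature: 0, top_p: 0.9 };
for (const param of ['temperature', 'top_p']) {
  if (llmConfig[param]) {
    llmConfig[param] = undefined;
  }
}
// -> { temperature: 0, top_p: undefined }: temperature 0 is falsy, so it is kept
```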
- if (!client.azure && client.abortController.signal.aborted) { - return; - } - const titleCache = getLogStores(CacheKeys.GEN_TITLE); const key = `${req.user.id}-${response.conversationId}`; diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index ea8d6ffaac..d6c8cc4146 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -7,6 +7,78 @@ const { getCustomConfig } = require('~/server/services/Config'); const { genAzureEndpoint } = require('~/utils'); const { logger } = require('~/config'); +/** + * Maps MIME types to their corresponding file extensions for audio files. + * @type {Object} + */ +const MIME_TO_EXTENSION_MAP = { + // MP4 container formats + 'audio/mp4': 'm4a', + 'audio/x-m4a': 'm4a', + // Ogg formats + 'audio/ogg': 'ogg', + 'audio/vorbis': 'ogg', + 'application/ogg': 'ogg', + // Wave formats + 'audio/wav': 'wav', + 'audio/x-wav': 'wav', + 'audio/wave': 'wav', + // MP3 formats + 'audio/mp3': 'mp3', + 'audio/mpeg': 'mp3', + 'audio/mpeg3': 'mp3', + // WebM formats + 'audio/webm': 'webm', + // Additional formats + 'audio/flac': 'flac', + 'audio/x-flac': 'flac', +}; + +/** + * Gets the file extension from the MIME type. + * @param {string} mimeType - The MIME type. + * @returns {string} The file extension. + */ +function getFileExtensionFromMime(mimeType) { + // Default fallback + if (!mimeType) { + return 'webm'; + } + + // Direct lookup (fastest) + const extension = MIME_TO_EXTENSION_MAP[mimeType]; + if (extension) { + return extension; + } + + // Try to extract subtype as fallback + const subtype = mimeType.split('/')[1]?.toLowerCase(); + + // If subtype matches a known extension + if (['mp3', 'mp4', 'ogg', 'wav', 'webm', 'm4a', 'flac'].includes(subtype)) { + return subtype === 'mp4' ? 'm4a' : subtype; + } + + // Generic checks for partial matches + if (subtype?.includes('mp4') || subtype?.includes('m4a')) { + return 'm4a'; + } + if (subtype?.includes('ogg')) { + return 'ogg'; + } + if (subtype?.includes('wav')) { + return 'wav'; + } + if (subtype?.includes('mp3') || subtype?.includes('mpeg')) { + return 'mp3'; + } + if (subtype?.includes('webm')) { + return 'webm'; + } + + return 'webm'; // Default fallback +} + /** * Service class for handling Speech-to-Text (STT) operations. * @class @@ -170,8 +242,10 @@ class STTService { throw new Error('Invalid provider'); } + const fileExtension = getFileExtensionFromMime(audioFile.mimetype); + const audioReadStream = Readable.from(audioBuffer); - audioReadStream.path = 'audio.wav'; + audioReadStream.path = `audio.${fileExtension}`; const [url, data, headers] = strategy.call(this, sttSchema, audioReadStream, audioFile); diff --git a/api/server/services/Files/Audio/streamAudio.js b/api/server/services/Files/Audio/streamAudio.js index ac046e68a6..a1d7c7a649 100644 --- a/api/server/services/Files/Audio/streamAudio.js +++ b/api/server/services/Files/Audio/streamAudio.js @@ -1,4 +1,10 @@ -const { CacheKeys, findLastSeparatorIndex, SEPARATORS, Time } = require('librechat-data-provider'); +const { + Time, + CacheKeys, + SEPARATORS, + parseTextParts, + findLastSeparatorIndex, +} = require('librechat-data-provider'); const { getMessage } = require('~/models/Message'); const { getLogStores } = require('~/cache'); @@ -84,10 +90,11 @@ function createChunkProcessor(user, messageId) { notFoundCount++; return []; } else { + const text = message.content?.length > 0 ? 
parseTextParts(message.content) : message.text;
       messageCache.set(
         messageId,
         {
-          text: message.text,
+          text,
           complete: true,
         },
         Time.FIVE_MINUTES,
       );
     }

     const text = typeof message === 'string' ? message : message.text;
-    const complete = typeof message === 'string' ? false : message.complete ?? true;
+    const complete = typeof message === 'string' ? false : (message.complete ?? true);

     if (text === processedText) {
       noChangeCount++;
diff --git a/api/server/services/Files/Azure/crud.js b/api/server/services/Files/Azure/crud.js
new file mode 100644
index 0000000000..cb52de8317
--- /dev/null
+++ b/api/server/services/Files/Azure/crud.js
@@ -0,0 +1,253 @@
+const fs = require('fs');
+const path = require('path');
+const mime = require('mime');
+const axios = require('axios');
+const fetch = require('node-fetch');
+const { logger } = require('~/config');
+const { getAzureContainerClient } = require('./initialize');
+
+const defaultBasePath = 'images';
+const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env;
+
+/**
+ * Uploads a buffer to Azure Blob Storage.
+ *
+ * Files will be stored at the path: {basePath}/{userId}/{fileName} within the container.
+ *
+ * @param {Object} params
+ * @param {string} params.userId - The user's id.
+ * @param {Buffer} params.buffer - The buffer to upload.
+ * @param {string} params.fileName - The name of the file.
+ * @param {string} [params.basePath='images'] - The base folder within the container.
+ * @param {string} [params.containerName] - The Azure Blob container name.
+ * @returns {Promise<string>} The URL of the uploaded blob.
+ */
+async function saveBufferToAzure({
+  userId,
+  buffer,
+  fileName,
+  basePath = defaultBasePath,
+  containerName,
+}) {
+  try {
+    const containerClient = getAzureContainerClient(containerName);
+    const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
+    // Create the container if it doesn't exist. This is done per operation.
+    await containerClient.createIfNotExists({ access });
+    const blobPath = `${basePath}/${userId}/${fileName}`;
+    const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
+    await blockBlobClient.uploadData(buffer);
+    return blockBlobClient.url;
+  } catch (error) {
+    logger.error('[saveBufferToAzure] Error uploading buffer:', error);
+    throw error;
+  }
+}
+
+/**
+ * Saves a file from a URL to Azure Blob Storage.
+ *
+ * @param {Object} params
+ * @param {string} params.userId - The user's id.
+ * @param {string} params.URL - The URL of the file.
+ * @param {string} params.fileName - The name of the file.
+ * @param {string} [params.basePath='images'] - The base folder within the container.
+ * @param {string} [params.containerName] - The Azure Blob container name.
+ * @returns {Promise<string>} The URL of the uploaded blob.
+ */
+async function saveURLToAzure({
+  userId,
+  URL,
+  fileName,
+  basePath = defaultBasePath,
+  containerName,
+}) {
+  try {
+    const response = await fetch(URL);
+    const buffer = await response.buffer();
+    return await saveBufferToAzure({ userId, buffer, fileName, basePath, containerName });
+  } catch (error) {
+    logger.error('[saveURLToAzure] Error uploading file from URL:', error);
+    throw error;
+  }
+}
+
+/**
+ * Retrieves a blob URL from Azure Blob Storage.
+ *
+ * @param {Object} params
+ * @param {string} params.fileName - The file name.
+ * @param {string} [params.basePath='images'] - The base folder used during upload.
+ * @param {string} [params.userId] - If files are stored in a user-specific directory.
+ * @param {string} [params.containerName] - The Azure Blob container name.
+ * @returns {Promise<string>} The blob's URL.
+ */
+async function getAzureURL({ fileName, basePath = defaultBasePath, userId, containerName }) {
+  try {
+    const containerClient = getAzureContainerClient(containerName);
+    const blobPath = userId ? `${basePath}/${userId}/${fileName}` : `${basePath}/${fileName}`;
+    const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
+    return blockBlobClient.url;
+  } catch (error) {
+    logger.error('[getAzureURL] Error retrieving blob URL:', error);
+    throw error;
+  }
+}
+
+/**
+ * Deletes a blob from Azure Blob Storage.
+ *
+ * @param {ServerRequest} req - The Express request object.
+ * @param {MongoFile} file - The file object.
+ */
+async function deleteFileFromAzure(req, file) {
+  try {
+    const containerClient = getAzureContainerClient(AZURE_CONTAINER_NAME);
+    const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1];
+    if (!blobPath.includes(req.user.id)) {
+      throw new Error('User ID not found in blob path');
+    }
+    const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
+    await blockBlobClient.delete();
+    logger.debug('[deleteFileFromAzure] Blob deleted successfully from Azure Blob Storage');
+  } catch (error) {
+    logger.error('[deleteFileFromAzure] Error deleting blob:', error);
+    if (error.statusCode === 404) {
+      return;
+    }
+    throw error;
+  }
+}
+
+/**
+ * Streams a file from disk directly to Azure Blob Storage without loading
+ * the entire file into memory.
+ *
+ * @param {Object} params
+ * @param {string} params.userId - The user's id.
+ * @param {string} params.filePath - The local file path to upload.
+ * @param {string} params.fileName - The name of the file in Azure.
+ * @param {string} [params.basePath='images'] - The base folder within the container.
+ * @param {string} [params.containerName] - The Azure Blob container name.
+ * @returns {Promise<string>} The URL of the uploaded blob.
+ */
+async function streamFileToAzure({
+  userId,
+  filePath,
+  fileName,
+  basePath = defaultBasePath,
+  containerName,
+}) {
+  try {
+    const containerClient = getAzureContainerClient(containerName);
+    const access = AZURE_STORAGE_PUBLIC_ACCESS?.toLowerCase() === 'true' ? 'blob' : undefined;
+
+    // Create the container if it doesn't exist
+    await containerClient.createIfNotExists({ access });
+
+    const blobPath = `${basePath}/${userId}/${fileName}`;
+    const blockBlobClient = containerClient.getBlockBlobClient(blobPath);
+
+    // Get file size for proper content length
+    const stats = await fs.promises.stat(filePath);
+
+    // Create read stream from the file
+    const fileStream = fs.createReadStream(filePath);
+
+    const blobContentType = mime.getType(fileName);
+    await blockBlobClient.uploadStream(
+      fileStream,
+      undefined, // Use default concurrency (5)
+      undefined, // Use default buffer size (8MB)
+      {
+        blobHTTPHeaders: {
+          blobContentType,
+        },
+        onProgress: (progress) => {
+          logger.debug(
+            `[streamFileToAzure] Upload progress: ${progress.loadedBytes} bytes of ${stats.size}`,
+          );
+        },
+      },
+    );
+
+    return blockBlobClient.url;
+  } catch (error) {
+    logger.error('[streamFileToAzure] Error streaming file:', error);
+    throw error;
+  }
+}
+
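
// --- Editor's note: a minimal usage sketch (not part of the diff) ---
// How the exported buffer helper above might be called once the new Azure
// index module ships; the user id, file name, and base path are hypothetical.
const { saveBufferToAzure } = require('~/server/services/Files/Azure');

async function example() {
  const url = await saveBufferToAzure({
    userId: 'user-123',
    buffer: Buffer.from('hello world'),
    fileName: 'hello.txt',
    basePath: 'uploads',
  });
  // The blob lands at files/uploads/user-123/hello.txt in the container
  console.log(url);
}
// --- End editor's note ---

+/**
+ * Uploads a file from the local file system to Azure Blob Storage.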
+ *
+ * This function reads the file from disk and then uploads it to Azure Blob Storage
+ * at the path: {basePath}/{userId}/{fileName}.
+ *
+ * @param {Object} params
+ * @param {object} params.req - The Express request object.
+ * @param {Express.Multer.File} params.file - The file object.
+ * @param {string} params.file_id - The file id.
+ * @param {string} [params.basePath='images'] - The base folder within the container.
+ * @param {string} [params.containerName] - The Azure Blob container name.
+ * @returns {Promise<{ filepath: string, bytes: number }>} An object containing the blob URL and its byte size.
+ */
+async function uploadFileToAzure({
+  req,
+  file,
+  file_id,
+  basePath = defaultBasePath,
+  containerName,
+}) {
+  try {
+    const inputFilePath = file.path;
+    const stats = await fs.promises.stat(inputFilePath);
+    const bytes = stats.size;
+    const userId = req.user.id;
+    const fileName = `${file_id}__${path.basename(inputFilePath)}`;
+
+    const fileURL = await streamFileToAzure({
+      userId,
+      filePath: inputFilePath,
+      fileName,
+      basePath,
+      containerName,
+    });
+
+    return { filepath: fileURL, bytes };
+  } catch (error) {
+    logger.error('[uploadFileToAzure] Error uploading file:', error);
+    throw error;
+  }
+}
+
+/**
+ * Retrieves a readable stream for a blob from Azure Blob Storage.
+ *
+ * @param {object} _req - The Express request object.
+ * @param {string} fileURL - The URL of the blob.
+ * @returns {Promise<ReadableStream>} A readable stream of the blob.
+ */
+async function getAzureFileStream(_req, fileURL) {
+  try {
+    const response = await axios({
+      method: 'get',
+      url: fileURL,
+      responseType: 'stream',
+    });
+    return response.data;
+  } catch (error) {
+    logger.error('[getAzureFileStream] Error getting blob stream:', error);
+    throw error;
+  }
+}
+
+module.exports = {
+  saveBufferToAzure,
+  saveURLToAzure,
+  getAzureURL,
+  deleteFileFromAzure,
+  uploadFileToAzure,
+  getAzureFileStream,
+};
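
// --- Editor's note: a minimal usage sketch (not part of the diff) ---
// How the exported uploadFileToAzure helper might be wired into an upload
// handler; the file_id and the multer-style `req.file` below are hypothetical.
const { uploadFileToAzure } = require('~/server/services/Files/Azure');

async function handleUpload(req) {
  // Streams req.file from disk to {basePath}/{userId}/{fileName} in the container
  const { filepath, bytes } = await uploadFileToAzure({
    req,
    file: req.file,
    file_id: 'file-abc123',
    basePath: 'uploads',
  });
  return { filepath, bytes };
}
// --- End editor's note ---

diff --git a/api/server/services/Files/Azure/images.js b/api/server/services/Files/Azure/images.js
new file mode 100644
index 0000000000..a83b700af3
--- /dev/null
+++ b/api/server/services/Files/Azure/images.js
@@ -0,0 +1,124 @@
+const fs = require('fs');
+const path = require('path');
+const sharp = require('sharp');
+const { resizeImageBuffer } = require('../images/resize');
+const { updateUser } = require('~/models/userMethods');
+const { updateFile } = require('~/models/File');
+const { logger } = require('~/config');
+const { saveBufferToAzure } = require('./crud');
+
+/**
+ * Uploads an image file to Azure Blob Storage.
+ * It resizes and converts the image similar to the Firebase implementation.
+ *
+ * @param {Object} params
+ * @param {object} params.req - The Express request object.
+ * @param {Express.Multer.File} params.file - The file object.
+ * @param {string} params.file_id - The file id.
+ * @param {EModelEndpoint} params.endpoint - The endpoint parameters.
+ * @param {string} [params.resolution='high'] - The image resolution.
+ * @param {string} [params.basePath='images'] - The base folder within the container.
+ * @param {string} [params.containerName] - The Azure Blob container name.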
+ * @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>} + */ +async function uploadImageToAzure({ + req, + file, + file_id, + endpoint, + resolution = 'high', + basePath = 'images', + containerName, +}) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution, endpoint); + const extension = path.extname(inputFilePath); + const userId = req.user.id; + let webPBuffer; + let fileName = `${file_id}__${path.basename(inputFilePath)}`; + const targetExtension = `.${req.app.locals.imageOutputType}`; + + if (extension.toLowerCase() === targetExtension) { + webPBuffer = resizedBuffer; + } else { + webPBuffer = await sharp(resizedBuffer).toFormat(req.app.locals.imageOutputType).toBuffer(); + const extRegExp = new RegExp(path.extname(fileName) + '$'); + fileName = fileName.replace(extRegExp, targetExtension); + if (!path.extname(fileName)) { + fileName += targetExtension; + } + } + const downloadURL = await saveBufferToAzure({ + userId, + buffer: webPBuffer, + fileName, + basePath, + containerName, + }); + await fs.promises.unlink(inputFilePath); + const bytes = Buffer.byteLength(webPBuffer); + return { filepath: downloadURL, bytes, width, height }; + } catch (error) { + logger.error('[uploadImageToAzure] Error uploading image:', error); + throw error; + } +} + +/** + * Prepares the image URL and updates the file record. + * + * @param {object} req - The Express request object. + * @param {MongoFile} file - The file object. + * @returns {Promise<[MongoFile, string]>} + */ +async function prepareAzureImageURL(req, file) { + const { filepath } = file; + const promises = []; + promises.push(updateFile({ file_id: file.file_id })); + promises.push(filepath); + return await Promise.all(promises); +} + +/** + * Uploads and processes a user's avatar to Azure Blob Storage. + * + * @param {Object} params + * @param {Buffer} params.buffer - The avatar image buffer. + * @param {string} params.userId - The user's id. + * @param {string} params.manual - Flag to indicate manual update. + * @param {string} [params.basePath='images'] - The base folder within the container. + * @param {string} [params.containerName] - The Azure Blob container name. + * @returns {Promise} The URL of the avatar. 
+ */ +async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) { + try { + const downloadURL = await saveBufferToAzure({ + userId, + buffer, + fileName: 'avatar.png', + basePath, + containerName, + }); + const isManual = manual === 'true'; + const url = `${downloadURL}?manual=${isManual}`; + if (isManual) { + await updateUser(userId, { avatar: url }); + } + return url; + } catch (error) { + logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error); + throw error; + } +} + +module.exports = { + uploadImageToAzure, + prepareAzureImageURL, + processAzureAvatar, +}; diff --git a/api/server/services/Files/Azure/index.js b/api/server/services/Files/Azure/index.js new file mode 100644 index 0000000000..27ad97a852 --- /dev/null +++ b/api/server/services/Files/Azure/index.js @@ -0,0 +1,9 @@ +const crud = require('./crud'); +const images = require('./images'); +const initialize = require('./initialize'); + +module.exports = { + ...crud, + ...images, + ...initialize, +}; diff --git a/api/server/services/Files/Azure/initialize.js b/api/server/services/Files/Azure/initialize.js new file mode 100644 index 0000000000..56df24d04a --- /dev/null +++ b/api/server/services/Files/Azure/initialize.js @@ -0,0 +1,55 @@ +const { BlobServiceClient } = require('@azure/storage-blob'); +const { logger } = require('~/config'); + +let blobServiceClient = null; +let azureWarningLogged = false; + +/** + * Initializes the Azure Blob Service client. + * This function establishes a connection by checking if a connection string is provided. + * If available, the connection string is used; otherwise, Managed Identity (via DefaultAzureCredential) is utilized. + * Note: Container creation (and its public access settings) is handled later in the CRUD functions. + * @returns {BlobServiceClient|null} The initialized client, or null if the required configuration is missing. + */ +const initializeAzureBlobService = () => { + if (blobServiceClient) { + return blobServiceClient; + } + const connectionString = process.env.AZURE_STORAGE_CONNECTION_STRING; + if (connectionString) { + blobServiceClient = BlobServiceClient.fromConnectionString(connectionString); + logger.info('Azure Blob Service initialized using connection string'); + } else { + const { DefaultAzureCredential } = require('@azure/identity'); + const accountName = process.env.AZURE_STORAGE_ACCOUNT_NAME; + if (!accountName) { + if (!azureWarningLogged) { + logger.error( + '[initializeAzureBlobService] Azure Blob Service not initialized. Connection string missing and AZURE_STORAGE_ACCOUNT_NAME not provided.', + ); + azureWarningLogged = true; + } + return null; + } + const url = `https://${accountName}.blob.core.windows.net`; + const credential = new DefaultAzureCredential(); + blobServiceClient = new BlobServiceClient(url, credential); + logger.info('Azure Blob Service initialized using Managed Identity'); + } + return blobServiceClient; +}; + +/** + * Retrieves the Azure ContainerClient for the given container name. + * @param {string} [containerName=process.env.AZURE_CONTAINER_NAME || 'files'] - The container name. + * @returns {ContainerClient|null} The Azure ContainerClient. + */ +const getAzureContainerClient = (containerName = process.env.AZURE_CONTAINER_NAME || 'files') => { + const serviceClient = initializeAzureBlobService(); + return serviceClient ? 
serviceClient.getContainerClient(containerName) : null; +}; + +module.exports = { + initializeAzureBlobService, + getAzureContainerClient, +}; diff --git a/api/server/services/Files/Code/crud.js b/api/server/services/Files/Code/crud.js index 076a4d9f13..caea9ab30a 100644 --- a/api/server/services/Files/Code/crud.js +++ b/api/server/services/Files/Code/crud.js @@ -1,7 +1,9 @@ -// Code Files -const axios = require('axios'); const FormData = require('form-data'); const { getCodeBaseURL } = require('@librechat/agents'); +const { createAxiosInstance } = require('~/config'); +const { logAxiosError } = require('~/utils'); + +const axios = createAxiosInstance(); const MAX_FILE_SIZE = 150 * 1024 * 1024; @@ -15,7 +17,8 @@ const MAX_FILE_SIZE = 150 * 1024 * 1024; async function getCodeOutputDownloadStream(fileIdentifier, apiKey) { try { const baseURL = getCodeBaseURL(); - const response = await axios({ + /** @type {import('axios').AxiosRequestConfig} */ + const options = { method: 'get', url: `${baseURL}/download/${fileIdentifier}`, responseType: 'stream', @@ -24,11 +27,17 @@ async function getCodeOutputDownloadStream(fileIdentifier, apiKey) { 'X-API-Key': apiKey, }, timeout: 15000, - }); + }; + const response = await axios(options); return response; } catch (error) { - throw new Error(`Error downloading file: ${error.message}`); + throw new Error( + logAxiosError({ + message: `Error downloading code environment file stream: ${error.message}`, + error, + }), + ); } } @@ -53,7 +62,8 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' form.append('file', stream, filename); const baseURL = getCodeBaseURL(); - const response = await axios.post(`${baseURL}/upload`, form, { + /** @type {import('axios').AxiosRequestConfig} */ + const options = { headers: { ...form.getHeaders(), 'Content-Type': 'multipart/form-data', @@ -63,7 +73,9 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' }, maxContentLength: MAX_FILE_SIZE, maxBodyLength: MAX_FILE_SIZE, - }); + }; + + const response = await axios.post(`${baseURL}/upload`, form, options); /** @type {{ message: string; session_id: string; files: Array<{ fileId: string; filename: string }> }} */ const result = response.data; @@ -78,7 +90,12 @@ async function uploadCodeEnvFile({ req, stream, filename, apiKey, entity_id = '' return `${fileIdentifier}?entity_id=${entity_id}`; } catch (error) { - throw new Error(`Error uploading file: ${error.message}`); + throw new Error( + logAxiosError({ + message: `Error uploading code environment file: ${error.message}`, + error, + }), + ); } } diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index 2a941a4647..c92e628589 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -12,6 +12,7 @@ const { const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); const { createFile, getFiles, updateFile } = require('~/models/File'); +const { logAxiosError } = require('~/utils'); const { logger } = require('~/config'); /** @@ -85,7 +86,10 @@ const processCodeOutput = async ({ /** Note: `messageId` & `toolCallId` are not part of file DB schema; message object records associated file ID */ return Object.assign(file, { messageId, toolCallId }); } catch (error) { - logger.error('Error downloading file:', error); + logAxiosError({ + message: 'Error downloading code environment file', + 
error, + }); } }; @@ -135,7 +139,10 @@ async function getSessionInfo(fileIdentifier, apiKey) { return response.data.find((file) => file.name.startsWith(path))?.lastModified; } catch (error) { - logger.error(`Error fetching session info: ${error.message}`, error); + logAxiosError({ + message: `Error fetching session info: ${error.message}`, + error, + }); return null; } } @@ -202,7 +209,7 @@ const primeFiles = async (options, apiKey) => { const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions( FileSources.execute_code, ); - const stream = await getDownloadStream(file.filepath); + const stream = await getDownloadStream(options.req, file.filepath); const fileIdentifier = await uploadCodeEnvFile({ req: options.req, stream, diff --git a/api/server/services/Files/Firebase/crud.js b/api/server/services/Files/Firebase/crud.js index 76a6c1d8d4..8319f908ef 100644 --- a/api/server/services/Files/Firebase/crud.js +++ b/api/server/services/Files/Firebase/crud.js @@ -224,10 +224,11 @@ async function uploadFileToFirebase({ req, file, file_id }) { /** * Retrieves a readable stream for a file from Firebase storage. * + * @param {ServerRequest} _req * @param {string} filepath - The filepath. * @returns {Promise} A readable stream of the file. */ -async function getFirebaseFileStream(filepath) { +async function getFirebaseFileStream(_req, filepath) { try { const storage = getFirebaseStorage(); if (!storage) { diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index e004eab79e..783230f2f6 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -175,6 +175,17 @@ const isValidPath = (req, base, subfolder, filepath) => { return normalizedFilepath.startsWith(normalizedBase); }; +/** + * @param {string} filepath + */ +const unlinkFile = async (filepath) => { + try { + await fs.promises.unlink(filepath); + } catch (error) { + logger.error('Error deleting file:', error); + } +}; + /** * Deletes a file from the filesystem. This function takes a file object, constructs the full path, and * verifies the path's validity before deleting the file. If the path is invalid, an error is thrown. @@ -217,7 +228,7 @@ const deleteLocalFile = async (req, file) => { throw new Error(`Invalid file path: ${file.filepath}`); } - await fs.promises.unlink(filepath); + await unlinkFile(filepath); return; } @@ -233,7 +244,7 @@ const deleteLocalFile = async (req, file) => { throw new Error('Invalid file path'); } - await fs.promises.unlink(filepath); + await unlinkFile(filepath); }; /** @@ -275,11 +286,49 @@ async function uploadLocalFile({ req, file, file_id }) { /** * Retrieves a readable stream for a file from local storage. * + * @param {ServerRequest} req - The request object from Express * @param {string} filepath - The filepath. * @returns {ReadableStream} A readable stream of the file. 
*/ -function getLocalFileStream(filepath) { +function getLocalFileStream(req, filepath) { try { + if (filepath.includes('/uploads/')) { + const basePath = filepath.split('/uploads/')[1]; + + if (!basePath) { + logger.warn(`Invalid base path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + const fullPath = path.join(req.app.locals.paths.uploads, basePath); + const uploadsDir = req.app.locals.paths.uploads; + + const rel = path.relative(uploadsDir, fullPath); + if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { + logger.warn(`Invalid relative file path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + return fs.createReadStream(fullPath); + } else if (filepath.includes('/images/')) { + const basePath = filepath.split('/images/')[1]; + + if (!basePath) { + logger.warn(`Invalid base path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + const fullPath = path.join(req.app.locals.paths.imageOutput, basePath); + const publicDir = req.app.locals.paths.imageOutput; + + const rel = path.relative(publicDir, fullPath); + if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { + logger.warn(`Invalid relative file path: ${filepath}`); + throw new Error(`Invalid file path: ${filepath}`); + } + + return fs.createReadStream(fullPath); + } return fs.createReadStream(filepath); } catch (error) { logger.error('Error getting local file stream:', error); diff --git a/api/server/services/Files/MistralOCR/crud.js b/api/server/services/Files/MistralOCR/crud.js new file mode 100644 index 0000000000..0c544b9eb4 --- /dev/null +++ b/api/server/services/Files/MistralOCR/crud.js @@ -0,0 +1,230 @@ +// ~/server/services/Files/MistralOCR/crud.js +const fs = require('fs'); +const path = require('path'); +const FormData = require('form-data'); +const { FileSources, envVarRegex, extractEnvVariable } = require('librechat-data-provider'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { logger, createAxiosInstance } = require('~/config'); +const { logAxiosError } = require('~/utils/axios'); + +const axios = createAxiosInstance(); + +/** + * Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory + * + * @param {Object} params Upload parameters + * @param {string} params.filePath The path to the file on disk + * @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath) + * @param {string} params.apiKey Mistral API key + * @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL + * @returns {Promise} The response from Mistral API + */ +async function uploadDocumentToMistral({ + filePath, + fileName = '', + apiKey, + baseURL = 'https://api.mistral.ai/v1', +}) { + const form = new FormData(); + form.append('purpose', 'ocr'); + const actualFileName = fileName || path.basename(filePath); + const fileStream = fs.createReadStream(filePath); + form.append('file', fileStream, { filename: actualFileName }); + + return axios + .post(`${baseURL}/files`, form, { + headers: { + Authorization: `Bearer ${apiKey}`, + ...form.getHeaders(), + }, + maxBodyLength: Infinity, + maxContentLength: Infinity, + }) + .then((res) => res.data) + .catch((error) => { + logger.error('Error uploading document to Mistral:', error.message); + throw error; + }); +} + +async function getSignedUrl({ + apiKey, + fileId, + expiry = 24, + baseURL = 'https://api.mistral.ai/v1', +}) { + 
return axios + .get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, { + headers: { + Authorization: `Bearer ${apiKey}`, + }, + }) + .then((res) => res.data) + .catch((error) => { + logger.error('Error fetching signed URL:', error.message); + throw error; + }); +} + +/** + * @param {Object} params + * @param {string} params.apiKey + * @param {string} params.url - The document or image URL + * @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url' + * @param {string} [params.model] + * @param {string} [params.baseURL] + * @returns {Promise} + */ +async function performOCR({ + apiKey, + url, + documentType = 'document_url', + model = 'mistral-ocr-latest', + baseURL = 'https://api.mistral.ai/v1', +}) { + const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url'; + return axios + .post( + `${baseURL}/ocr`, + { + model, + include_image_base64: false, + document: { + type: documentType, + [documentKey]: url, + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${apiKey}`, + }, + }, + ) + .then((res) => res.data) + .catch((error) => { + logger.error('Error performing OCR:', error.message); + throw error; + }); +} + +function extractVariableName(str) { + const match = str.match(envVarRegex); + return match ? match[1] : null; +} + +/** + * Uploads a file to the Mistral OCR API and processes the OCR result. + * + * @param {Object} params - The params object. + * @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id` + * representing the user + * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should + * have a `mimetype` property that tells us the file type + * @param {string} params.file_id - The file ID. + * @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency. + * @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used), + * along with the `filename` and `bytes` properties. + */ +const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { + try { + /** @type {TCustomConfig['ocr']} */ + const ocrConfig = req.app.locals?.ocr; + + const apiKeyConfig = ocrConfig.apiKey || ''; + const baseURLConfig = ocrConfig.baseURL || ''; + + const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig); + const isBaseURLEnvVar = envVarRegex.test(baseURLConfig); + + const isApiKeyEmpty = !apiKeyConfig.trim(); + const isBaseURLEmpty = !baseURLConfig.trim(); + + let apiKey, baseURL; + + if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) { + const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY'; + const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL'; + + const authValues = await loadAuthValues({ + userId: req.user.id, + authFields: [baseURLVarName, apiKeyVarName], + optional: new Set([baseURLVarName]), + }); + + apiKey = authValues[apiKeyVarName]; + baseURL = authValues[baseURLVarName]; + } else { + apiKey = apiKeyConfig; + baseURL = baseURLConfig; + } + + const mistralFile = await uploadDocumentToMistral({ + filePath: file.path, + fileName: file.originalname, + apiKey, + baseURL, + }); + + const modelConfig = ocrConfig.mistralModel || ''; + const model = envVarRegex.test(modelConfig) + ? 
extractEnvVariable(modelConfig) + : modelConfig.trim() || 'mistral-ocr-latest'; + + const signedUrlResponse = await getSignedUrl({ + apiKey, + baseURL, + fileId: mistralFile.id, + }); + + const mimetype = (file.mimetype || '').toLowerCase(); + const originalname = file.originalname || ''; + const isImage = + mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname); + const documentType = isImage ? 'image_url' : 'document_url'; + + const ocrResult = await performOCR({ + apiKey, + baseURL, + model, + url: signedUrlResponse.url, + documentType, + }); + + let aggregatedText = ''; + const images = []; + ocrResult.pages.forEach((page, index) => { + if (ocrResult.pages.length > 1) { + aggregatedText += `# PAGE ${index + 1}\n`; + } + + aggregatedText += page.markdown + '\n\n'; + + if (page.images && page.images.length > 0) { + page.images.forEach((image) => { + if (image.image_base64) { + images.push(image.image_base64); + } + }); + } + }); + + return { + filename: file.originalname, + bytes: aggregatedText.length * 4, + filepath: FileSources.mistral_ocr, + text: aggregatedText, + images, + }; + } catch (error) { + const message = 'Error uploading document to Mistral OCR API'; + throw new Error(logAxiosError({ error, message })); + } +}; + +module.exports = { + uploadDocumentToMistral, + uploadMistralOCR, + getSignedUrl, + performOCR, +}; diff --git a/api/server/services/Files/MistralOCR/crud.spec.js b/api/server/services/Files/MistralOCR/crud.spec.js new file mode 100644 index 0000000000..c3d2f46c40 --- /dev/null +++ b/api/server/services/Files/MistralOCR/crud.spec.js @@ -0,0 +1,852 @@ +const fs = require('fs'); + +const mockAxios = { + interceptors: { + request: { use: jest.fn(), eject: jest.fn() }, + response: { use: jest.fn(), eject: jest.fn() }, + }, + create: jest.fn().mockReturnValue({ + defaults: { + proxy: null, + }, + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + }), + get: jest.fn().mockResolvedValue({ data: {} }), + post: jest.fn().mockResolvedValue({ data: {} }), + put: jest.fn().mockResolvedValue({ data: {} }), + delete: jest.fn().mockResolvedValue({ data: {} }), + reset: jest.fn().mockImplementation(function () { + this.get.mockClear(); + this.post.mockClear(); + this.put.mockClear(); + this.delete.mockClear(); + this.create.mockClear(); + }), +}; + +jest.mock('axios', () => mockAxios); +jest.mock('fs'); +jest.mock('~/config', () => ({ + logger: { + error: jest.fn(), + }, + createAxiosInstance: () => mockAxios, +})); +jest.mock('~/server/services/Tools/credentials', () => ({ + loadAuthValues: jest.fn(), +})); + +const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud'); + +describe('MistralOCR Service', () => { + afterEach(() => { + mockAxios.reset(); + jest.clearAllMocks(); + }); + + describe('uploadDocumentToMistral', () => { + beforeEach(() => { + // Create a more complete mock for file streams that FormData can work with + const mockReadStream = { + on: jest.fn().mockImplementation(function (event, handler) { + // Simulate immediate 'end' event to make FormData complete processing + if (event === 'end') { + handler(); + } + return this; + }), + pipe: jest.fn().mockImplementation(function () { + return this; + }), + pause: jest.fn(), + resume: jest.fn(), + emit: jest.fn(), + once: jest.fn(), + destroy: jest.fn(), + }; + + fs.createReadStream = 
jest.fn().mockReturnValue(mockReadStream); + + // Mock FormData's append to avoid actual stream processing + jest.mock('form-data', () => { + const mockFormData = function () { + return { + append: jest.fn(), + getHeaders: jest + .fn() + .mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }), + getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')), + getLength: jest.fn().mockReturnValue(100), + }; + }; + return mockFormData; + }); + }); + + it('should upload a document to Mistral API using file streaming', async () => { + const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } }; + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await uploadDocumentToMistral({ + filePath: '/path/to/test.pdf', + fileName: 'test.pdf', + apiKey: 'test-api-key', + }); + + // Check that createReadStream was called with the correct file path + expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf'); + + // Since we're mocking FormData, we'll just check that axios was called correctly + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/files', + expect.anything(), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer test-api-key', + }), + maxBodyLength: Infinity, + maxContentLength: Infinity, + }), + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors during document upload', async () => { + const errorMessage = 'API error'; + mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + uploadDocumentToMistral({ + filePath: '/path/to/test.pdf', + fileName: 'test.pdf', + apiKey: 'test-api-key', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Error uploading document to Mistral:'), + expect.any(String), + ); + }); + }); + + describe('getSignedUrl', () => { + it('should fetch signed URL from Mistral API', async () => { + const mockResponse = { data: { url: 'https://document-url.com' } }; + mockAxios.get.mockResolvedValueOnce(mockResponse); + + const result = await getSignedUrl({ + fileId: 'file-123', + apiKey: 'test-api-key', + }); + + expect(mockAxios.get).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/files/file-123/url?expiry=24', + { + headers: { + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors when fetching signed URL', async () => { + const errorMessage = 'API error'; + mockAxios.get.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + getSignedUrl({ + fileId: 'file-123', + apiKey: 'test-api-key', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage); + }); + }); + + describe('performOCR', () => { + it('should perform OCR using Mistral API (document_url)', async () => { + const mockResponse = { + data: { + pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], + }, + }; + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await performOCR({ + apiKey: 'test-api-key', + url: 'https://document-url.com', + model: 'mistral-ocr-latest', + documentType: 'document_url', + }); + + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/ocr', + { + model: 'mistral-ocr-latest', + include_image_base64: false, + document: { + type: 'document_url', + document_url: 
'https://document-url.com', + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should perform OCR using Mistral API (image_url)', async () => { + const mockResponse = { + data: { + pages: [{ markdown: 'Image OCR content' }], + }, + }; + mockAxios.post.mockResolvedValueOnce(mockResponse); + + const result = await performOCR({ + apiKey: 'test-api-key', + url: 'https://image-url.com/image.png', + model: 'mistral-ocr-latest', + documentType: 'image_url', + }); + + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/ocr', + { + model: 'mistral-ocr-latest', + include_image_base64: false, + document: { + type: 'image_url', + image_url: 'https://image-url.com/image.png', + }, + }, + { + headers: { + 'Content-Type': 'application/json', + Authorization: 'Bearer test-api-key', + }, + }, + ); + expect(result).toEqual(mockResponse.data); + }); + + it('should handle errors during OCR processing', async () => { + const errorMessage = 'OCR processing error'; + mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); + + await expect( + performOCR({ + apiKey: 'test-api-key', + url: 'https://document-url.com', + }), + ).rejects.toThrow(); + + const { logger } = require('~/config'); + expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage); + }); + }); + + describe('uploadMistralOCR', () => { + beforeEach(() => { + const mockReadStream = { + on: jest.fn().mockImplementation(function (event, handler) { + if (event === 'end') { + handler(); + } + return this; + }), + pipe: jest.fn().mockImplementation(function () { + return this; + }), + pause: jest.fn(), + resume: jest.fn(), + emit: jest.fn(), + once: jest.fn(), + destroy: jest.fn(), + }; + + fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); + }); + + it('should process OCR for a file with standard configuration', async () => { + // Setup mocks + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', + }); + + // Mock file upload response + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + + // Mock signed URL response + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + + // Mock OCR response with text and images + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [ + { + markdown: 'Page 1 content', + images: [{ image_base64: 'base64image1' }], + }, + { + markdown: 'Page 2 content', + images: [{ image_base64: 'base64image2' }], + }, + ], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Use environment variable syntax to ensure loadAuthValues is called + apiKey: '${OCR_API_KEY}', + baseURL: '${OCR_BASEURL}', + mistralModel: 'mistral-medium', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + mimetype: 'application/pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Verify OCR result + expect(result).toEqual({ + filename: 'document.pdf', + bytes: 
expect.any(Number), + filepath: 'mistral_ocr', + text: expect.stringContaining('# PAGE 1'), + images: ['base64image1', 'base64image2'], + }); + }); + + it('should process OCR for an image file and use image_url type', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', + }); + + // Mock file upload response + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-456', purpose: 'ocr' }, + }); + + // Mock signed URL response + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com/image.png' }, + }); + + // Mock OCR response for image + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [ + { + markdown: 'Image OCR result', + images: [{ image_base64: 'imgbase64' }], + }, + ], + }, + }); + + const req = { + user: { id: 'user456' }, + app: { + locals: { + ocr: { + apiKey: '${OCR_API_KEY}', + baseURL: '${OCR_BASEURL}', + mistralModel: 'mistral-medium', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/image.png', + originalname: 'image.png', + mimetype: 'image/png', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file456', + entity_id: 'entity456', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png'); + + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user456', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Check that the OCR API was called with image_url type + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://api.mistral.ai/v1/ocr', + expect.objectContaining({ + document: expect.objectContaining({ + type: 'image_url', + image_url: 'https://signed-url.com/image.png', + }), + }), + expect.any(Object), + ); + + expect(result).toEqual({ + filename: 'image.png', + bytes: expect.any(Number), + filepath: 'mistral_ocr', + text: expect.stringContaining('Image OCR result'), + images: ['imgbase64'], + }); + }); + + it('should process variable references in configuration', async () => { + // Setup mocks with environment variables + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + CUSTOM_API_KEY: 'custom-api-key', + CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1', + }); + + // Mock API responses + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [{ markdown: 'Content from custom API' }], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: '${CUSTOM_API_KEY}', + baseURL: '${CUSTOM_BASEURL}', + mistralModel: '${CUSTOM_MODEL}', + }, + }, + }, + }; + + // Set environment variable for model + process.env.CUSTOM_MODEL = 'mistral-large'; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify that custom environment variables were extracted and used + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'], + optional: expect.any(Set), + }); + + // Check that mistral-large was used in the OCR API call + 
expect(mockAxios.post).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + model: 'mistral-large', + }), + expect.anything(), + ); + + expect(result.text).toEqual('Content from custom API\n\n'); + }); + + it('should fall back to default values when variables are not properly formatted', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'default-api-key', + OCR_BASEURL: undefined, // Testing optional parameter + }); + + mockAxios.post.mockResolvedValueOnce({ + data: { id: 'file-123', purpose: 'ocr' }, + }); + mockAxios.get.mockResolvedValueOnce({ + data: { url: 'https://signed-url.com' }, + }); + mockAxios.post.mockResolvedValueOnce({ + data: { + pages: [{ markdown: 'Default API result' }], + }, + }); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Use environment variable syntax to ensure loadAuthValues is called + apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name + baseURL: '${OCR_BASEURL}', // Using valid env var format + mistralModel: 'mistral-ocr-latest', // Plain string value + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Should use the default values + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'INVALID_FORMAT'], + optional: expect.any(Set), + }); + + // Should use the default model when not using environment variable format + expect(mockAxios.post).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + model: 'mistral-ocr-latest', + }), + expect.anything(), + ); + }); + + it('should handle API errors during OCR process', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + }); + + // Mock file upload to fail + mockAxios.post.mockRejectedValueOnce(new Error('Upload failed')); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: 'OCR_API_KEY', + baseURL: 'OCR_BASEURL', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'document.pdf', + }; + + await expect( + uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }), + ).rejects.toThrow('Error uploading document to Mistral OCR API'); + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + }); + + it('should handle single page documents without page numbering', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'test-api-key', + OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included + }); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. 
Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Single page content' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + apiKey: 'OCR_API_KEY', + baseURL: 'OCR_BASEURL', + mistralModel: 'mistral-ocr-latest', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'single-page.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify that single page documents don't include page numbering + expect(result.text).not.toContain('# PAGE'); + expect(result.text).toEqual('Single page content\n\n'); + }); + + it('should use literal values in configuration when provided directly', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + // We'll still mock this but it should not be used for literal values + loadAuthValues.mockResolvedValue({}); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Processed with literal config values' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Direct values that should be used as-is, without variable substitution + apiKey: 'actual-api-key-value', + baseURL: 'https://direct-api-url.mistral.ai/v1', + mistralModel: 'mistral-direct-model', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'direct-values.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify the correct URL was used with the direct baseURL value + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://direct-api-url.mistral.ai/v1/files', + expect.any(Object), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer actual-api-key-value', + }), + }), + ); + + // Check the OCR call was made with the direct model value + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://direct-api-url.mistral.ai/v1/ocr', + expect.objectContaining({ + model: 'mistral-direct-model', + }), + expect.any(Object), + ); + + // Verify the result + expect(result.text).toEqual('Processed with literal config values\n\n'); + + // Verify loadAuthValues was never called since we used direct values + expect(loadAuthValues).not.toHaveBeenCalled(); + }); + + it('should handle empty configuration values and use defaults', async () => { + const { loadAuthValues } = require('~/server/services/Tools/credentials'); + // Set up the mock values to be returned by loadAuthValues + loadAuthValues.mockResolvedValue({ + OCR_API_KEY: 'default-from-env-key', + OCR_BASEURL: 'https://default-from-env.mistral.ai/v1', + }); + + // Clear all previous mocks + mockAxios.post.mockClear(); + mockAxios.get.mockClear(); + + // 1. 
First mock: File upload response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), + ); + + // 2. Second mock: Signed URL response + mockAxios.get.mockImplementationOnce(() => + Promise.resolve({ data: { url: 'https://signed-url.com' } }), + ); + + // 3. Third mock: OCR response + mockAxios.post.mockImplementationOnce(() => + Promise.resolve({ + data: { + pages: [{ markdown: 'Content from default configuration' }], + }, + }), + ); + + const req = { + user: { id: 'user123' }, + app: { + locals: { + ocr: { + // Empty string values - should fall back to defaults + apiKey: '', + baseURL: '', + mistralModel: '', + }, + }, + }, + }; + + const file = { + path: '/tmp/upload/file.pdf', + originalname: 'empty-config.pdf', + }; + + const result = await uploadMistralOCR({ + req, + file, + file_id: 'file123', + entity_id: 'entity123', + }); + + expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); + + // Verify loadAuthValues was called with the default variable names + expect(loadAuthValues).toHaveBeenCalledWith({ + userId: 'user123', + authFields: ['OCR_BASEURL', 'OCR_API_KEY'], + optional: expect.any(Set), + }); + + // Verify the API calls used the default values from loadAuthValues + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://default-from-env.mistral.ai/v1/files', + expect.any(Object), + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: 'Bearer default-from-env-key', + }), + }), + ); + + // Verify the OCR model defaulted to mistral-ocr-latest + expect(mockAxios.post).toHaveBeenCalledWith( + 'https://default-from-env.mistral.ai/v1/ocr', + expect.objectContaining({ + model: 'mistral-ocr-latest', + }), + expect.any(Object), + ); + + // Check result + expect(result.text).toEqual('Content from default configuration\n\n'); + }); + }); +}); diff --git a/api/server/services/Files/MistralOCR/index.js b/api/server/services/Files/MistralOCR/index.js new file mode 100644 index 0000000000..a6223d1ee5 --- /dev/null +++ b/api/server/services/Files/MistralOCR/index.js @@ -0,0 +1,5 @@ +const crud = require('./crud'); + +module.exports = { + ...crud, +}; diff --git a/api/server/services/Files/S3/crud.js b/api/server/services/Files/S3/crud.js new file mode 100644 index 0000000000..10c04106d8 --- /dev/null +++ b/api/server/services/Files/S3/crud.js @@ -0,0 +1,467 @@ +const fs = require('fs'); +const path = require('path'); +const fetch = require('node-fetch'); +const { FileSources } = require('librechat-data-provider'); +const { + PutObjectCommand, + GetObjectCommand, + HeadObjectCommand, + DeleteObjectCommand, +} = require('@aws-sdk/client-s3'); +const { getSignedUrl } = require('@aws-sdk/s3-request-presigner'); +const { initializeS3 } = require('./initialize'); +const { logger } = require('~/config'); + +const bucketName = process.env.AWS_BUCKET_NAME; +const defaultBasePath = 'images'; + +let s3UrlExpirySeconds = 7 * 24 * 60 * 60; +let s3RefreshExpiryMs = null; + +if (process.env.S3_URL_EXPIRY_SECONDS !== undefined) { + const parsed = parseInt(process.env.S3_URL_EXPIRY_SECONDS, 10); + + if (!isNaN(parsed) && parsed > 0) { + s3UrlExpirySeconds = Math.min(parsed, 7 * 24 * 60 * 60); + } else { + logger.warn( + `[S3] Invalid S3_URL_EXPIRY_SECONDS value: "${process.env.S3_URL_EXPIRY_SECONDS}". 
Using 7-day expiry.`, + ); + } +} + +if (process.env.S3_REFRESH_EXPIRY_MS) { + const parsed = parseInt(process.env.S3_REFRESH_EXPIRY_MS, 10); + + if (!isNaN(parsed) && parsed > 0) { + s3RefreshExpiryMs = parsed; + logger.info(`[S3] Using custom refresh expiry time: ${s3RefreshExpiryMs}ms`); + } else { + logger.warn( + `[S3] Invalid S3_REFRESH_EXPIRY_MS value: "${process.env.S3_REFRESH_EXPIRY_MS}". Using default refresh logic.`, + ); + } +} + +/** + * Constructs the S3 key based on the base path, user ID, and file name. + */ +const getS3Key = (basePath, userId, fileName) => `${basePath}/${userId}/${fileName}`; + +/** + * Uploads a buffer to S3 and returns a signed URL. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {Buffer} params.buffer - The buffer containing file data. + * @param {string} params.fileName - The file name to use in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} Signed URL of the uploaded file. + */ +async function saveBufferToS3({ userId, buffer, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key, Body: buffer }; + + try { + const s3 = initializeS3(); + await s3.send(new PutObjectCommand(params)); + return await getS3URL({ userId, fileName, basePath }); + } catch (error) { + logger.error('[saveBufferToS3] Error uploading buffer to S3:', error.message); + throw error; + } +} + +/** + * Retrieves a URL for a file stored in S3. + * Returns a signed URL with an expiration time, or a proxy URL, based on config. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.fileName - The file name in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} A URL to access the S3 object. + */ +async function getS3URL({ userId, fileName, basePath = defaultBasePath }) { + const key = getS3Key(basePath, userId, fileName); + const params = { Bucket: bucketName, Key: key }; + + try { + const s3 = initializeS3(); + return await getSignedUrl(s3, new GetObjectCommand(params), { expiresIn: s3UrlExpirySeconds }); + } catch (error) { + logger.error('[getS3URL] Error getting signed URL from S3:', error.message); + throw error; + } +} + +/** + * Saves a file from a given URL to S3. + * + * @param {Object} params + * @param {string} params.userId - The user's unique identifier. + * @param {string} params.URL - The source URL of the file. + * @param {string} params.fileName - The file name to use in S3. + * @param {string} [params.basePath='images'] - The base path in the bucket. + * @returns {Promise} Signed URL of the uploaded file. + */ +async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath }) { + try { + const response = await fetch(URL); + const buffer = await response.buffer(); + // Optionally, call getBufferMetadata(buffer) here if metadata is needed. + return await saveBufferToS3({ userId, buffer, fileName, basePath }); + } catch (error) { + logger.error('[saveURLToS3] Error uploading file from URL to S3:', error.message); + throw error; + } +} + +/** + * Deletes a file from S3. + * + * @param {ServerRequest} req - The Express request (must include user). + * @param {MongoFile} file - The file object to delete. 
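+ * @example + * // Usage sketch (URL is illustrative); the key's path must contain the requesting user's ID: + * await deleteFileFromS3(req, { filepath: 'https://bucket.s3.amazonaws.com/images/user123/file.png' });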
+ * @returns {Promise} + */ +async function deleteFileFromS3(req, file) { + const key = extractKeyFromS3Url(file.filepath); + const params = { Bucket: bucketName, Key: key }; + if (!key.includes(req.user.id)) { + const message = `[deleteFileFromS3] User ID mismatch: ${req.user.id} vs ${key}`; + logger.error(message); + throw new Error(message); + } + + try { + const s3 = initializeS3(); + + try { + const headCommand = new HeadObjectCommand(params); + await s3.send(headCommand); + logger.debug('[deleteFileFromS3] File exists, proceeding with deletion'); + } catch (headErr) { + if (headErr.name === 'NotFound') { + logger.warn(`[deleteFileFromS3] File does not exist: ${key}`); + return; + } + } + + const deleteResult = await s3.send(new DeleteObjectCommand(params)); + logger.debug('[deleteFileFromS3] Delete command response:', JSON.stringify(deleteResult)); + try { + await s3.send(new HeadObjectCommand(params)); + logger.error('[deleteFileFromS3] File still exists after deletion!'); + } catch (verifyErr) { + if (verifyErr.name === 'NotFound') { + logger.debug(`[deleteFileFromS3] Verified file is deleted: ${key}`); + } else { + logger.error('[deleteFileFromS3] Error verifying deletion:', verifyErr); + } + } + + logger.debug('[deleteFileFromS3] S3 File deletion completed'); + } catch (error) { + logger.error(`[deleteFileFromS3] Error deleting file from S3: ${error.message}`); + logger.error(error.stack); + + // If the file is not found, we can safely return. + if (error.code === 'NoSuchKey') { + return; + } + throw error; + } +} + +/** + * Uploads a local file to S3 by streaming it directly without loading into memory. + * + * @param {Object} params + * @param {import('express').Request} params.req - The Express request (must include user). + * @param {Express.Multer.File} params.file - The file object from Multer. + * @param {string} params.file_id - Unique file identifier. + * @param {string} [params.basePath='images'] - The base path in the bucket. 
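+ * @example + * // Usage sketch with a Multer-style file object (identifiers are illustrative): + * const { filepath, bytes } = await uploadFileToS3({ req, file, file_id: 'file-abc123' });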
+ * @returns {Promise<{ filepath: string, bytes: number }>} + */ +async function uploadFileToS3({ req, file, file_id, basePath = defaultBasePath }) { + try { + const inputFilePath = file.path; + const userId = req.user.id; + const fileName = `${file_id}__${path.basename(inputFilePath)}`; + const key = getS3Key(basePath, userId, fileName); + + const stats = await fs.promises.stat(inputFilePath); + const bytes = stats.size; + const fileStream = fs.createReadStream(inputFilePath); + + const s3 = initializeS3(); + const uploadParams = { + Bucket: bucketName, + Key: key, + Body: fileStream, + }; + + await s3.send(new PutObjectCommand(uploadParams)); + const fileURL = await getS3URL({ userId, fileName, basePath }); + return { filepath: fileURL, bytes }; + } catch (error) { + logger.error('[uploadFileToS3] Error streaming file to S3:', error); + try { + if (file && file.path) { + await fs.promises.unlink(file.path); + } + } catch (unlinkError) { + logger.error( + '[uploadFileToS3] Error deleting temporary file, likely already deleted:', + unlinkError.message, + ); + } + throw error; + } +} + +/** + * Extracts the S3 key from a URL or returns the key if already properly formatted + * + * @param {string} fileUrlOrKey - The file URL or key + * @returns {string} The S3 key + */ +function extractKeyFromS3Url(fileUrlOrKey) { + if (!fileUrlOrKey) { + throw new Error('Invalid input: URL or key is empty'); + } + + try { + const url = new URL(fileUrlOrKey); + return url.pathname.substring(1); + } catch (error) { + const parts = fileUrlOrKey.split('/'); + + if (parts.length >= 3 && !fileUrlOrKey.startsWith('http') && !fileUrlOrKey.startsWith('/')) { + return fileUrlOrKey; + } + + return fileUrlOrKey.startsWith('/') ? fileUrlOrKey.substring(1) : fileUrlOrKey; + } +} + +/** + * Retrieves a readable stream for a file stored in S3. + * + * @param {ServerRequest} req - Server request object. + * @param {string} filePath - The S3 key of the file. + * @returns {Promise} + */ +async function getS3FileStream(_req, filePath) { + try { + const Key = extractKeyFromS3Url(filePath); + const params = { Bucket: bucketName, Key }; + const s3 = initializeS3(); + const data = await s3.send(new GetObjectCommand(params)); + return data.Body; // Returns a Node.js ReadableStream. 
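+ // Note: callers are responsible for consuming or destroying the returned stream + // (e.g., streamToBase64 in images/encode.js destroys it once fully read).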
+ } catch (error) { + logger.error('[getS3FileStream] Error retrieving S3 file stream:', error); + throw error; + } +} + +/** + * Determines if a signed S3 URL is close to expiration + * + * @param {string} signedUrl - The signed S3 URL + * @param {number} bufferSeconds - Buffer time in seconds + * @returns {boolean} True if the URL needs refreshing + */ +function needsRefresh(signedUrl, bufferSeconds) { + try { + // Parse the URL + const url = new URL(signedUrl); + + // Check if it has the signature parameters that indicate it's a signed URL + // X-Amz-Signature is the most reliable indicator for AWS signed URLs + if (!url.searchParams.has('X-Amz-Signature')) { + // Not a signed URL, so no expiration to check (or it's already a proxy URL) + return false; + } + + // Extract the expiration time from the URL + const expiresParam = url.searchParams.get('X-Amz-Expires'); + const dateParam = url.searchParams.get('X-Amz-Date'); + + if (!expiresParam || !dateParam) { + // Missing expiration information, assume it needs refresh to be safe + return true; + } + + // Parse the AWS date format (YYYYMMDDTHHMMSSZ) + const year = dateParam.substring(0, 4); + const month = dateParam.substring(4, 6); + const day = dateParam.substring(6, 8); + const hour = dateParam.substring(9, 11); + const minute = dateParam.substring(11, 13); + const second = dateParam.substring(13, 15); + + const dateObj = new Date(`${year}-${month}-${day}T${hour}:${minute}:${second}Z`); + const expiresAtDate = new Date(dateObj.getTime() + parseInt(expiresParam) * 1000); + + // Check if it's close to expiration + const now = new Date(); + + // If S3_REFRESH_EXPIRY_MS is set, use it to determine if URL is expired + if (s3RefreshExpiryMs !== null) { + const urlCreationTime = dateObj.getTime(); + const urlAge = now.getTime() - urlCreationTime; + return urlAge >= s3RefreshExpiryMs; + } + + // Otherwise use the default buffer-based logic + const bufferTime = new Date(now.getTime() + bufferSeconds * 1000); + return expiresAtDate <= bufferTime; + } catch (error) { + logger.error('Error checking URL expiration:', error); + // If we can't determine, assume it needs refresh to be safe + return true; + } +} + +/** + * Generates a new URL for an expired S3 URL + * @param {string} currentURL - The current file URL + * @returns {Promise} + */ +async function getNewS3URL(currentURL) { + try { + const s3Key = extractKeyFromS3Url(currentURL); + if (!s3Key) { + return; + } + const keyParts = s3Key.split('/'); + if (keyParts.length < 3) { + return; + } + + const basePath = keyParts[0]; + const userId = keyParts[1]; + const fileName = keyParts.slice(2).join('/'); + + return await getS3URL({ + userId, + fileName, + basePath, + }); + } catch (error) { + logger.error('Error getting new S3 URL:', error); + } +} + +/** + * Refreshes S3 URLs for an array of files if they're expired or close to expiring + * + * @param {MongoFile[]} files - Array of file documents + * @param {(files: MongoFile[]) => Promise} batchUpdateFiles - Function to update files in the database + * @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration + * @returns {Promise} The files with refreshed URLs if needed + */ +async function refreshS3FileUrls(files, batchUpdateFiles, bufferSeconds = 3600) { + if (!files || !Array.isArray(files) || files.length === 0) { + return files; + } + + const filesToUpdate = []; + + for (let i = 0; i < files.length; i++) { + const file = files[i]; + if (!file?.file_id) { + continue; + } + if (file.source !== FileSources.s3) { + 
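+ // Only S3-sourced files carry signed URLs that can expire; skip all other sources.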
continue; + } + if (!file.filepath) { + continue; + } + if (!needsRefresh(file.filepath, bufferSeconds)) { + continue; + } + try { + const newURL = await getNewS3URL(file.filepath); + if (!newURL) { + continue; + } + filesToUpdate.push({ + file_id: file.file_id, + filepath: newURL, + }); + files[i].filepath = newURL; + } catch (error) { + logger.error(`Error refreshing S3 URL for file ${file.file_id}:`, error); + } + } + + if (filesToUpdate.length > 0) { + await batchUpdateFiles(filesToUpdate); + } + + return files; +} + +/** + * Refreshes a single S3 URL if it's expired or close to expiring + * + * @param {{ filepath: string, source: string }} fileObj - Simple file object containing filepath and source + * @param {number} [bufferSeconds=3600] - Buffer time in seconds to check for expiration + * @returns {Promise} The refreshed URL or the original URL if no refresh needed + */ +async function refreshS3Url(fileObj, bufferSeconds = 3600) { + if (!fileObj || fileObj.source !== FileSources.s3 || !fileObj.filepath) { + return fileObj?.filepath || ''; + } + + if (!needsRefresh(fileObj.filepath, bufferSeconds)) { + return fileObj.filepath; + } + + try { + const s3Key = extractKeyFromS3Url(fileObj.filepath); + if (!s3Key) { + logger.warn(`Unable to extract S3 key from URL: ${fileObj.filepath}`); + return fileObj.filepath; + } + + const keyParts = s3Key.split('/'); + if (keyParts.length < 3) { + logger.warn(`Invalid S3 key format: ${s3Key}`); + return fileObj.filepath; + } + + const basePath = keyParts[0]; + const userId = keyParts[1]; + const fileName = keyParts.slice(2).join('/'); + + const newUrl = await getS3URL({ + userId, + fileName, + basePath, + }); + + logger.debug(`Refreshed S3 URL for key: ${s3Key}`); + return newUrl; + } catch (error) { + logger.error(`Error refreshing S3 URL: ${error.message}`); + return fileObj.filepath; + } +} + +module.exports = { + saveBufferToS3, + saveURLToS3, + getS3URL, + deleteFileFromS3, + uploadFileToS3, + getS3FileStream, + refreshS3FileUrls, + refreshS3Url, + needsRefresh, + getNewS3URL, +}; diff --git a/api/server/services/Files/S3/images.js b/api/server/services/Files/S3/images.js new file mode 100644 index 0000000000..378212cb5e --- /dev/null +++ b/api/server/services/Files/S3/images.js @@ -0,0 +1,118 @@ +const fs = require('fs'); +const path = require('path'); +const sharp = require('sharp'); +const { resizeImageBuffer } = require('../images/resize'); +const { updateUser } = require('~/models/userMethods'); +const { saveBufferToS3 } = require('./crud'); +const { updateFile } = require('~/models/File'); +const { logger } = require('~/config'); + +const defaultBasePath = 'images'; + +/** + * Resizes, converts, and uploads an image file to S3. + * + * @param {Object} params + * @param {import('express').Request} params.req - Express request (expects user and app.locals.imageOutputType). + * @param {Express.Multer.File} params.file - File object from Multer. + * @param {string} params.file_id - Unique file identifier. + * @param {any} params.endpoint - Endpoint identifier used in image processing. + * @param {string} [params.resolution='high'] - Desired image resolution. + * @param {string} [params.basePath='images'] - Base path in the bucket. 
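+ * @example + * // Usage sketch (endpoint and id values are illustrative): + * const { filepath, width, height } = await uploadImageToS3({ req, file, file_id: 'img-123', endpoint: 'openAI' });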
+ * @returns {Promise<{ filepath: string, bytes: number, width: number, height: number }>} + */ +async function uploadImageToS3({ + req, + file, + file_id, + endpoint, + resolution = 'high', + basePath = defaultBasePath, +}) { + try { + const inputFilePath = file.path; + const inputBuffer = await fs.promises.readFile(inputFilePath); + const { + buffer: resizedBuffer, + width, + height, + } = await resizeImageBuffer(inputBuffer, resolution, endpoint); + const extension = path.extname(inputFilePath); + const userId = req.user.id; + + let processedBuffer; + let fileName = `${file_id}__${path.basename(inputFilePath)}`; + const targetExtension = `.${req.app.locals.imageOutputType}`; + + if (extension.toLowerCase() === targetExtension) { + processedBuffer = resizedBuffer; + } else { + processedBuffer = await sharp(resizedBuffer) + .toFormat(req.app.locals.imageOutputType) + .toBuffer(); + fileName = fileName.replace(new RegExp(path.extname(fileName) + '$'), targetExtension); + if (!path.extname(fileName)) { + fileName += targetExtension; + } + } + + const downloadURL = await saveBufferToS3({ + userId, + buffer: processedBuffer, + fileName, + basePath, + }); + await fs.promises.unlink(inputFilePath); + const bytes = Buffer.byteLength(processedBuffer); + return { filepath: downloadURL, bytes, width, height }; + } catch (error) { + logger.error('[uploadImageToS3] Error uploading image to S3:', error.message); + throw error; + } +} + +/** + * Updates a file record and returns its signed URL. + * + * @param {import('express').Request} req - Express request. + * @param {Object} file - File metadata. + * @returns {Promise<[Promise, string]>} + */ +async function prepareImageURLS3(req, file) { + try { + const updatePromise = updateFile({ file_id: file.file_id }); + return Promise.all([updatePromise, file.filepath]); + } catch (error) { + logger.error('[prepareImageURLS3] Error preparing image URL:', error.message); + throw error; + } +} + +/** + * Processes a user's avatar image by uploading it to S3 and updating the user's avatar URL if required. + * + * @param {Object} params + * @param {Buffer} params.buffer - Avatar image buffer. + * @param {string} params.userId - User's unique identifier. + * @param {string} params.manual - 'true' or 'false' flag for manual update. + * @param {string} [params.basePath='images'] - Base path in the bucket. + * @returns {Promise} Signed URL of the uploaded avatar. 
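+ * @example + * // Usage sketch: persists the avatar URL on the user record only when manual === 'true' + * const avatarURL = await processS3Avatar({ buffer, userId: 'user123', manual: 'true' });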
+ */ +async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) { + try { + const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath }); + if (manual === 'true') { + await updateUser(userId, { avatar: downloadURL }); + } + return downloadURL; + } catch (error) { + logger.error('[processS3Avatar] Error processing S3 avatar:', error.message); + throw error; + } +} + +module.exports = { + uploadImageToS3, + prepareImageURLS3, + processS3Avatar, +}; diff --git a/api/server/services/Files/S3/index.js b/api/server/services/Files/S3/index.js new file mode 100644 index 0000000000..27ad97a852 --- /dev/null +++ b/api/server/services/Files/S3/index.js @@ -0,0 +1,9 @@ +const crud = require('./crud'); +const images = require('./images'); +const initialize = require('./initialize'); + +module.exports = { + ...crud, + ...images, + ...initialize, +}; diff --git a/api/server/services/Files/S3/initialize.js b/api/server/services/Files/S3/initialize.js new file mode 100644 index 0000000000..2daec25235 --- /dev/null +++ b/api/server/services/Files/S3/initialize.js @@ -0,0 +1,53 @@ +const { S3Client } = require('@aws-sdk/client-s3'); +const { logger } = require('~/config'); + +let s3 = null; + +/** + * Initializes and returns an instance of the AWS S3 client. + * + * If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are provided, they will be used. + * Otherwise, the AWS SDK's default credentials chain (including IRSA) is used. + * + * If AWS_ENDPOINT_URL is provided, it will be used as the endpoint. + * + * @returns {S3Client|null} An instance of S3Client if the region is provided; otherwise, null. + */ +const initializeS3 = () => { + if (s3) { + return s3; + } + + const region = process.env.AWS_REGION; + if (!region) { + logger.error('[initializeS3] AWS_REGION is not set. Cannot initialize S3.'); + return null; + } + + // Read the custom endpoint if provided. + const endpoint = process.env.AWS_ENDPOINT_URL; + const accessKeyId = process.env.AWS_ACCESS_KEY_ID; + const secretAccessKey = process.env.AWS_SECRET_ACCESS_KEY; + + const config = { + region, + // Conditionally add the endpoint if it is provided + ...(endpoint ? { endpoint } : {}), + }; + + if (accessKeyId && secretAccessKey) { + s3 = new S3Client({ + ...config, + credentials: { accessKeyId, secretAccessKey }, + }); + logger.info('[initializeS3] S3 initialized with provided credentials.'); + } else { + // When using IRSA, credentials are automatically provided via the IAM Role attached to the ServiceAccount. 
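+ // The SDK's default provider chain also covers environment credentials, + // shared config/credentials files, and EC2/ECS instance roles, so no keys are passed here.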
+ s3 = new S3Client(config); + logger.info('[initializeS3] S3 initialized using default credentials (IRSA).'); + } + + return s3; +}; + +module.exports = { initializeS3 }; diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index d290eea4b1..37a1e81487 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -37,7 +37,14 @@ const deleteVectors = async (req, file) => { error, message: 'Error deleting vectors', }); - throw new Error(error.message || 'An error occurred during file deletion.'); + if ( + error.response && + error.response.status !== 404 && + (error.response.status < 200 || error.response.status >= 300) + ) { + logger.warn('Error deleting vectors, file will not be deleted'); + throw new Error(error.message || 'An error occurred during file deletion.'); + } } }; diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 94153ffc64..154941fd89 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -7,8 +7,47 @@ const { EModelEndpoint, } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { logAxiosError } = require('~/utils'); const { logger } = require('~/config'); +/** + * Converts a readable stream to a base64 encoded string. + * + * @param {NodeJS.ReadableStream} stream - The readable stream to convert. + * @param {boolean} [destroyStream=true] - Whether to destroy the stream after processing. + * @returns {Promise} - Promise resolving to the base64 encoded content. + */ +async function streamToBase64(stream, destroyStream = true) { + return new Promise((resolve, reject) => { + const chunks = []; + + stream.on('data', (chunk) => { + chunks.push(chunk); + }); + + stream.on('end', () => { + try { + const buffer = Buffer.concat(chunks); + const base64Data = buffer.toString('base64'); + chunks.length = 0; // Clear the array + resolve(base64Data); + } catch (err) { + reject(err); + } + }); + + stream.on('error', (error) => { + chunks.length = 0; + reject(error); + }); + }).finally(() => { + // Clean up the stream if required + if (destroyStream && stream.destroy && typeof stream.destroy === 'function') { + stream.destroy(); + } + }); +} + /** * Fetches an image from a URL and returns its base64 representation. * @@ -22,10 +61,12 @@ async function fetchImageToBase64(url) { const response = await axios.get(url, { responseType: 'arraybuffer', }); - return Buffer.from(response.data).toString('base64'); + const base64Data = Buffer.from(response.data).toString('base64'); + response.data = null; + return base64Data; } catch (error) { - logger.error('Error fetching image to convert to base64', error); - throw error; + const message = 'Error fetching image to convert to base64'; + throw new Error(logAxiosError({ message, error })); } } @@ -37,18 +78,23 @@ const base64Only = new Set([ EModelEndpoint.bedrock, ]); +const blobStorageSources = new Set([FileSources.azure_blob, FileSources.s3]); + /** * Encodes and formats the given files. * @param {Express.Request} req - The request object. * @param {Array} files - The array of files to encode and format. * @param {EModelEndpoint} [endpoint] - Optional: The endpoint for the image. * @param {string} [mode] - Optional: The endpoint mode for the image. - * @returns {Promise} - A promise that resolves to the result object containing the encoded images and file details. 
+ * @returns {Promise<{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }>} - A promise that resolves to the result object containing the encoded images and file details. */ async function encodeAndFormat(req, files, endpoint, mode) { const promises = []; + /** @type {Record, 'prepareImagePayload' | 'getDownloadStream'>>} */ const encodingMethods = {}; + /** @type {{ text: string; files: MongoFile[]; image_urls: MessageContentImageUrl[] }} */ const result = { + text: '', files: [], image_urls: [], }; @@ -58,7 +104,11 @@ async function encodeAndFormat(req, files, endpoint, mode) { } for (let file of files) { + /** @type {FileSources} */ const source = file.source ?? FileSources.local; + if (source === FileSources.text && file.text) { + result.text += `${!result.text ? 'Attached document(s):\n```md' : '\n\n---\n\n'}# "${file.filename}"\n${file.text}\n`; + } if (!file.height) { promises.push([file, null]); @@ -66,18 +116,29 @@ async function encodeAndFormat(req, files, endpoint, mode) { } if (!encodingMethods[source]) { - const { prepareImagePayload } = getStrategyFunctions(source); + const { prepareImagePayload, getDownloadStream } = getStrategyFunctions(source); if (!prepareImagePayload) { throw new Error(`Encoding function not implemented for ${source}`); } - encodingMethods[source] = prepareImagePayload; + encodingMethods[source] = { prepareImagePayload, getDownloadStream }; } - const preparePayload = encodingMethods[source]; - - /* Google & Anthropic don't support passing URLs to payload */ - if (source !== FileSources.local && base64Only.has(endpoint)) { + const preparePayload = encodingMethods[source].prepareImagePayload; + /* We need to fetch the image and convert it to base64 if we are using S3/Azure Blob storage. */ + if (blobStorageSources.has(source)) { + try { + const downloadStream = encodingMethods[source].getDownloadStream; + let stream = await downloadStream(req, file.filepath); + let base64Data = await streamToBase64(stream); + stream = null; + promises.push([file, base64Data]); + base64Data = null; + continue; + } catch (error) { + // Error handling code + } + } else if (source !== FileSources.local && base64Only.has(endpoint)) { const [_file, imageURL] = await preparePayload(req, file); promises.push([_file, await fetchImageToBase64(imageURL)]); continue; @@ -85,10 +146,15 @@ async function encodeAndFormat(req, files, endpoint, mode) { promises.push(preparePayload(req, file)); } + if (result.text) { + result.text += '\n```'; + } + const detail = req.body.imageDetail ?? 
ImageDetail.auto; /** @type {Array<[MongoFile, string]>} */ const formattedImages = await Promise.all(promises); + promises.length = 0; for (const [file, imageContent] of formattedImages) { const fileMetadata = { @@ -121,8 +187,8 @@ async function encodeAndFormat(req, files, endpoint, mode) { }; if (mode === VisionModes.agents) { - result.image_urls.push(imagePart); - result.files.push(fileMetadata); + result.image_urls.push({ ...imagePart }); + result.files.push({ ...fileMetadata }); continue; } @@ -144,10 +210,11 @@ async function encodeAndFormat(req, files, endpoint, mode) { delete imagePart.image_url; } - result.image_urls.push(imagePart); - result.files.push(fileMetadata); + result.image_urls.push({ ...imagePart }); + result.files.push({ ...fileMetadata }); } - return result; + formattedImages.length = 0; + return { ...result }; } module.exports = { diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index a5d9c8c1e0..81a4f52855 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -28,8 +28,8 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent'); const { getOpenAIClient } = require('~/server/controllers/assistants/helpers'); const { createFile, updateFileUsage, deleteFiles } = require('~/models/File'); -const { getEndpointsConfig } = require('~/server/services/Config'); -const { loadAuthValues } = require('~/app/clients/tools/util'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { checkCapability } = require('~/server/services/Config'); const { LB_QueueAsyncCall } = require('~/server/utils/queue'); const { getStrategyFunctions } = require('./strategies'); const { determineFileType } = require('~/server/utils'); @@ -162,7 +162,6 @@ const processDeleteRequest = async ({ req, files }) => { for (const file of files) { const source = file.source ?? FileSources.local; - if (req.body.agent_id && req.body.tool_resource) { agentFiles.push({ tool_resource: req.body.tool_resource, @@ -170,6 +169,11 @@ const processDeleteRequest = async ({ req, files }) => { }); } + if (source === FileSources.text) { + resolvedFileIds.push(file.file_id); + continue; + } + if (checkOpenAIStorage(source) && !client[source]) { await initializeClients(); } @@ -347,8 +351,8 @@ const uploadImageBuffer = async ({ req, context, metadata = {}, resize = true }) req.app.locals.imageOutputType }`; } - - const filepath = await saveBuffer({ userId: req.user.id, fileName: filename, buffer }); + const fileName = `${file_id}-${filename}`; + const filepath = await saveBuffer({ userId: req.user.id, fileName, buffer }); return await createFile( { user: req.user.id, @@ -453,17 +457,6 @@ const processFileUpload = async ({ req, res, metadata }) => { res.status(200).json({ message: 'File uploaded and processed successfully', ...result }); }; -/** - * @param {ServerRequest} req - * @param {AgentCapabilities} capability - * @returns {Promise} - */ -const checkCapability = async (req, capability) => { - const endpointsConfig = await getEndpointsConfig(req); - const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []; - return capabilities.includes(capability); -}; - /** * Applies the current strategy for file uploads. * Saves file metadata to the database with an expiry TTL. 
@@ -499,7 +492,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { let fileInfoMetadata; const entity_id = messageAttachment === true ? undefined : agent_id; - + const basePath = mime.getType(file.originalname)?.startsWith('image') ? 'images' : 'uploads'; if (tool_resource === EToolResources.execute_code) { const isCodeEnabled = await checkCapability(req, AgentCapabilities.execute_code); if (!isCodeEnabled) { @@ -521,6 +514,52 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { if (!isFileSearchEnabled) { throw new Error('File search is not enabled for Agents'); } + } else if (tool_resource === EToolResources.ocr) { + const isOCREnabled = await checkCapability(req, AgentCapabilities.ocr); + if (!isOCREnabled) { + throw new Error('OCR capability is not enabled for Agents'); + } + + const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions( + req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr, + ); + const { file_id, temp_file_id } = metadata; + + const { + text, + bytes, + // TODO: OCR images support? + images, + filename, + filepath: ocrFileURL, + } = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath }); + + const fileInfo = removeNullishValues({ + text, + bytes, + file_id, + temp_file_id, + user: req.user.id, + type: 'text/plain', + filepath: ocrFileURL, + source: FileSources.text, + filename: filename ?? file.originalname, + model: messageAttachment ? undefined : req.body.model, + context: messageAttachment ? FileContext.message_attachment : FileContext.agents, + }); + + if (!messageAttachment && tool_resource) { + await addAgentResourceFile({ + req, + file_id, + agent_id, + tool_resource, + }); + } + const result = await createFile(fileInfo, true); + return res + .status(200) + .json({ message: 'Agent file uploaded and processed successfully', ...result }); } const source = @@ -543,6 +582,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { file, file_id, entity_id, + basePath, }); let filepath = _filepath; @@ -801,8 +841,7 @@ async function saveBase64Image( { req, file_id: _file_id, filename: _filename, endpoint, context, resolution = 'high' }, ) { const file_id = _file_id ?? 
v4(); - - let filename = _filename; + let filename = `${file_id}-${_filename}`; const { buffer: inputBuffer, type } = base64ToBuffer(url); if (!path.extname(_filename)) { const extension = mime.getExtension(type); diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index ddfdd57469..c6cfe77069 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -21,9 +21,32 @@ const { processLocalAvatar, getLocalFileStream, } = require('./Local'); +const { + getS3URL, + saveURLToS3, + saveBufferToS3, + getS3FileStream, + uploadImageToS3, + prepareImageURLS3, + deleteFileFromS3, + processS3Avatar, + uploadFileToS3, +} = require('./S3'); +const { + saveBufferToAzure, + saveURLToAzure, + getAzureURL, + deleteFileFromAzure, + uploadFileToAzure, + getAzureFileStream, + uploadImageToAzure, + prepareAzureImageURL, + processAzureAvatar, +} = require('./Azure'); const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI'); const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code'); const { uploadVectors, deleteVectors } = require('./VectorDB'); +const { uploadMistralOCR } = require('./MistralOCR'); /** * Firebase Storage Strategy Functions @@ -57,6 +80,38 @@ const localStrategy = () => ({ getDownloadStream: getLocalFileStream, }); +/** + * S3 Storage Strategy Functions + * + * */ +const s3Strategy = () => ({ + handleFileUpload: uploadFileToS3, + saveURL: saveURLToS3, + getFileURL: getS3URL, + deleteFile: deleteFileFromS3, + saveBuffer: saveBufferToS3, + prepareImagePayload: prepareImageURLS3, + processAvatar: processS3Avatar, + handleImageUpload: uploadImageToS3, + getDownloadStream: getS3FileStream, +}); + +/** + * Azure Blob Storage Strategy Functions + * + * */ +const azureStrategy = () => ({ + handleFileUpload: uploadFileToAzure, + saveURL: saveURLToAzure, + getFileURL: getAzureURL, + deleteFile: deleteFileFromAzure, + saveBuffer: saveBufferToAzure, + prepareImagePayload: prepareAzureImageURL, + processAvatar: processAzureAvatar, + handleImageUpload: uploadImageToAzure, + getDownloadStream: getAzureFileStream, +}); + /** * VectorDB Storage Strategy Functions * @@ -127,6 +182,26 @@ const codeOutputStrategy = () => ({ getDownloadStream: getCodeOutputDownloadStream, }); +const mistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { @@ -137,10 +212,16 @@ const getStrategyFunctions = (fileSource) => { return openAIStrategy(); } else if (fileSource === FileSources.azure) { return openAIStrategy(); + } else if (fileSource === FileSources.azure_blob) { + return azureStrategy(); } else if (fileSource === FileSources.vectordb) { return vectorStrategy(); + } else if (fileSource === FileSources.s3) { + return s3Strategy(); } else if (fileSource === FileSources.execute_code) { return 
codeOutputStrategy(); + } else if (fileSource === FileSources.mistral_ocr) { + return mistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index f934f9d519..1d4fc5112c 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -13,13 +13,13 @@ const { logger, getMCPManager } = require('~/config'); * Creates a general tool for an entire action set. * * @param {Object} params - The parameters for loading action sets. - * @param {ServerRequest} params.req - The name of the tool. + * @param {ServerRequest} params.req - The Express request object, containing user/request info. * @param {string} params.toolKey - The toolKey for the tool. * @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool. * @param {string} params.model - The model for the tool. * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ -async function createMCPTool({ req, toolKey, provider }) { +async function createMCPTool({ req, toolKey, provider: _provider }) { const toolDefinition = req.app.locals.availableTools[toolKey]?.function; if (!toolDefinition) { logger.error(`Tool ${toolKey} not found in available tools`); @@ -27,9 +27,10 @@ async function createMCPTool({ req, toolKey, provider }) { } /** @type {LCTool} */ const { description, parameters } = toolDefinition; - const isGoogle = provider === Providers.VERTEXAI || provider === Providers.GOOGLE; + const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; let schema = convertJsonSchemaToZod(parameters, { allowEmptyObject: !isGoogle, + transformOneOfAnyOf: true, }); if (!schema) { @@ -37,11 +38,31 @@ async function createMCPTool({ req, toolKey, provider }) { } const [toolName, serverName] = toolKey.split(Constants.mcp_delimiter); - /** @type {(toolInput: Object | string) => Promise} */ - const _call = async (toolInput) => { + + if (!req.user?.id) { + logger.error( + `[MCP][${serverName}][${toolName}] User ID not found on request. Cannot create tool.`, + ); + throw new Error(`User ID not found on request. Cannot create tool for ${toolKey}.`); + } + + /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise} */ + const _call = async (toolArguments, config) => { try { - const mcpManager = await getMCPManager(); - const result = await mcpManager.callTool(serverName, toolName, provider, toolInput); + const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; + const mcpManager = getMCPManager(config?.configurable?.user_id); + const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); + const result = await mcpManager.callTool({ + serverName, + toolName, + provider, + toolArguments, + options: { + userId: config?.configurable?.user_id, + signal: derivedSignal, + }, + }); + if (isAssistantsEndpoint(provider) && Array.isArray(result)) { return result[0]; } @@ -50,8 +71,13 @@ async function createMCPTool({ req, toolKey, provider }) { } return result; } catch (error) { - logger.error(`${toolName} MCP server tool call failed`, error); - return `${toolName} MCP server tool call failed.`; + logger.error( + `[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`, + error, + ); + throw new Error( + `"${toolKey}" tool call failed${error?.message ? 
`: ${error?.message}` : '.'}`, + ); } }; diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index 1394a5d697..a1ccd7643b 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -1,9 +1,12 @@ const axios = require('axios'); +const { Providers } = require('@librechat/agents'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider'); const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils'); const { OllamaClient } = require('~/app/clients/OllamaClient'); +const { isUserProvided } = require('~/server/utils'); const getLogStores = require('~/cache/getLogStores'); +const { logger } = require('~/config'); /** * Splits a string by commas and trims each resulting value. @@ -41,7 +44,7 @@ const fetchModels = async ({ user, apiKey, baseURL, - name = 'OpenAI', + name = EModelEndpoint.openAI, azure = false, userIdQuery = false, createTokenConfig = true, @@ -57,18 +60,25 @@ const fetchModels = async ({ return models; } - if (name && name.toLowerCase().startsWith('ollama')) { + if (name && name.toLowerCase().startsWith(Providers.OLLAMA)) { return await OllamaClient.fetchModels(baseURL); } try { const options = { - headers: { - Authorization: `Bearer ${apiKey}`, - }, + headers: {}, timeout: 5000, }; + if (name === EModelEndpoint.anthropic) { + options.headers = { + 'x-api-key': apiKey, + 'anthropic-version': process.env.ANTHROPIC_VERSION || '2023-06-01', + }; + } else { + options.headers.Authorization = `Bearer ${apiKey}`; + } + if (process.env.PROXY) { options.httpsAgent = new HttpsProxyAgent(process.env.PROXY); } @@ -128,9 +138,6 @@ const fetchOpenAIModels = async (opts, _models = []) => { // .split('/deployments')[0] // .concat(`/models?api-version=${azure.azureOpenAIApiVersion}`); // apiKey = azureOpenAIApiKey; - } else if (process.env.OPENROUTER_API_KEY) { - reverseProxyUrl = 'https://openrouter.ai/api/v1'; - apiKey = process.env.OPENROUTER_API_KEY; } if (reverseProxyUrl) { @@ -150,7 +157,7 @@ const fetchOpenAIModels = async (opts, _models = []) => { baseURL, azure: opts.azure, user: opts.user, - name: baseURL, + name: EModelEndpoint.openAI, }); } @@ -159,7 +166,7 @@ const fetchOpenAIModels = async (opts, _models = []) => { } if (baseURL === openaiBaseURL) { - const regex = /(text-davinci-003|gpt-|o\d+-)/; + const regex = /(text-davinci-003|gpt-|o\d+)/; const excludeRegex = /audio|realtime/; models = models.filter((model) => regex.test(model) && !excludeRegex.test(model)); const instructModels = models.filter((model) => model.includes('instruct')); @@ -217,7 +224,7 @@ const getOpenAIModels = async (opts) => { return models; } - if (userProvidedOpenAI && !process.env.OPENROUTER_API_KEY) { + if (userProvidedOpenAI) { return models; } @@ -233,13 +240,71 @@ const getChatGPTBrowserModels = () => { return models; }; -const getAnthropicModels = () => { +/** + * Fetches models from the Anthropic API. + * @async + * @function + * @param {object} opts - The options for fetching the models. + * @param {string} opts.user - The user ID to send to the API. + * @param {string[]} [_models=[]] - The models to use as a fallback. + */ +const fetchAnthropicModels = async (opts, _models = []) => { + let models = _models.slice() ?? 
[]; + let apiKey = process.env.ANTHROPIC_API_KEY; + const anthropicBaseURL = 'https://api.anthropic.com/v1'; + let baseURL = anthropicBaseURL; + let reverseProxyUrl = process.env.ANTHROPIC_REVERSE_PROXY; + + if (reverseProxyUrl) { + baseURL = extractBaseURL(reverseProxyUrl); + } + + if (!apiKey) { + return models; + } + + const modelsCache = getLogStores(CacheKeys.MODEL_QUERIES); + + const cachedModels = await modelsCache.get(baseURL); + if (cachedModels) { + return cachedModels; + } + + if (baseURL) { + models = await fetchModels({ + apiKey, + baseURL, + user: opts.user, + name: EModelEndpoint.anthropic, + tokenKey: EModelEndpoint.anthropic, + }); + } + + if (models.length === 0) { + return _models; + } + + await modelsCache.set(baseURL, models); + return models; +}; + +const getAnthropicModels = async (opts = {}) => { let models = defaultModels[EModelEndpoint.anthropic]; if (process.env.ANTHROPIC_MODELS) { models = splitAndTrim(process.env.ANTHROPIC_MODELS); + return models; } - return models; + if (isUserProvided(process.env.ANTHROPIC_API_KEY)) { + return models; + } + + try { + return await fetchAnthropicModels(opts, models); + } catch (error) { + logger.error('Error fetching Anthropic models:', error); + return models; + } }; const getGoogleModels = () => { diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index a383db1e3c..fb4481f840 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -161,22 +161,6 @@ describe('getOpenAIModels', () => { expect(models).toEqual(expect.arrayContaining(['openai-model', 'openai-model-2'])); }); - it('attempts to use OPENROUTER_API_KEY if set', async () => { - process.env.OPENROUTER_API_KEY = 'test-router-key'; - const expectedModels = ['model-router-1', 'model-router-2']; - - axios.get.mockResolvedValue({ - data: { - data: expectedModels.map((id) => ({ id })), - }, - }); - - const models = await getOpenAIModels({ user: 'user456' }); - - expect(models).toEqual(expect.arrayContaining(expectedModels)); - expect(axios.get).toHaveBeenCalled(); - }); - it('utilizes proxy configuration when PROXY is set', async () => { axios.get.mockResolvedValue({ data: { @@ -368,15 +352,15 @@ describe('splitAndTrim', () => { }); describe('getAnthropicModels', () => { - it('returns default models when ANTHROPIC_MODELS is not set', () => { + it('returns default models when ANTHROPIC_MODELS is not set', async () => { delete process.env.ANTHROPIC_MODELS; - const models = getAnthropicModels(); + const models = await getAnthropicModels(); expect(models).toEqual(defaultModels[EModelEndpoint.anthropic]); }); - it('returns models from ANTHROPIC_MODELS when set', () => { + it('returns models from ANTHROPIC_MODELS when set', async () => { process.env.ANTHROPIC_MODELS = 'claude-1, claude-2 '; - const models = getAnthropicModels(); + const models = await getAnthropicModels(); expect(models).toEqual(['claude-1', 'claude-2']); }); }); diff --git a/api/server/services/Runs/methods.js b/api/server/services/Runs/methods.js index c6dfcbedde..3c18e9969b 100644 --- a/api/server/services/Runs/methods.js +++ b/api/server/services/Runs/methods.js @@ -55,8 +55,7 @@ async function retrieveRun({ thread_id, run_id, timeout, openai }) { return response.data; } catch (error) { const message = '[retrieveRun] Failed to retrieve run data:'; - logAxiosError({ message, error }); - throw error; + throw new Error(logAxiosError({ message, error })); } } diff --git a/api/server/services/Threads/manage.js 
b/api/server/services/Threads/manage.js index f99dca7534..5eace214c3 100644 --- a/api/server/services/Threads/manage.js +++ b/api/server/services/Threads/manage.js @@ -132,6 +132,8 @@ async function saveUserMessage(req, params) { * @param {string} params.endpoint - The conversation endpoint * @param {string} params.parentMessageId - The latest user message that triggered this response. * @param {string} [params.instructions] - Optional: from preset for `instructions` field. + * @param {string} [params.spec] - Optional: Model spec identifier. + * @param {string} [params.iconURL] * Overrides the instructions of the assistant. * @param {string} [params.promptPrefix] - Optional: from preset for `additional_instructions` field. * @return {Promise} A promise that resolves to the created run object. @@ -154,6 +156,8 @@ async function saveAssistantMessage(req, params) { text: params.text, unfinished: false, // tokenCount, + iconURL: params.iconURL, + spec: params.spec, }); await saveConvo( @@ -165,6 +169,8 @@ async function saveAssistantMessage(req, params) { instructions: params.instructions, assistant_id: params.assistant_id, model: params.model, + iconURL: params.iconURL, + spec: params.spec, }, { context: 'api/server/services/Threads/manage.js #saveAssistantMessage' }, ); diff --git a/api/server/services/TokenService.js b/api/server/services/TokenService.js index ec0f990a47..3dd2e79ffa 100644 --- a/api/server/services/TokenService.js +++ b/api/server/services/TokenService.js @@ -93,11 +93,12 @@ const refreshAccessToken = async ({ return response.data; } catch (error) { const message = 'Error refreshing OAuth tokens'; - logAxiosError({ - message, - error, - }); - throw new Error(message); + throw new Error( + logAxiosError({ + message, + error, + }), + ); } }; @@ -156,11 +157,12 @@ const getAccessToken = async ({ return response.data; } catch (error) { const message = 'Error exchanging OAuth code'; - logAxiosError({ - message, - error, - }); - throw new Error(message); + throw new Error( + logAxiosError({ + message, + error, + }), + ); } }; diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index f3e4efb6e3..b71e97f742 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -8,6 +8,7 @@ const { ErrorTypes, ContentTypes, imageGenTools, + EToolResources, EModelEndpoint, actionDelimiter, ImageVisionTool, @@ -15,9 +16,20 @@ const { AgentCapabilities, validateAndParseOpenAPISpec, } = require('librechat-data-provider'); +const { + createActionTool, + decryptMetadata, + loadActionSets, + domainParser, +} = require('./ActionService'); +const { + createOpenAIImageTools, + createYouTubeTools, + manifestToolMap, + toolkits, +} = require('~/app/clients/tools'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); -const { createYouTubeTools, manifestToolMap, toolkits } = require('~/app/clients/tools'); -const { loadActionSets, createActionTool, domainParser } = require('./ActionService'); +const { isActionDomainAllowed } = require('~/server/services/domains'); const { getEndpointsConfig } = require('~/server/services/Config'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); @@ -25,6 +37,30 @@ const { redactMessage } = require('~/config/parsers'); const { sleep } = require('~/server/utils'); const { logger } = require('~/config'); +/** + * @param {string} toolName + * @returns {string | undefined} toolKey + */ +function 
getToolkitKey(toolName) { + /** @type {string|undefined} */ + let toolkitKey; + for (const toolkit of toolkits) { + if (toolName.startsWith(EToolResources.image_edit)) { + const splitMatches = toolkit.pluginKey.split('_'); + const suffix = splitMatches[splitMatches.length - 1]; + if (toolName.endsWith(suffix)) { + toolkitKey = toolkit.pluginKey; + break; + } + } + if (toolName.startsWith(toolkit.pluginKey)) { + toolkitKey = toolkit.pluginKey; + break; + } + } + return toolkitKey; +} + /** * Loads and formats tools from the specified tool directory. * @@ -97,14 +133,16 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] }) tools.push(formattedTool); } - /** Basic Tools; schema: { input: string } */ - const basicToolInstances = [new Calculator(), ...createYouTubeTools({ override: true })]; + /** Basic Tools & Toolkits; schema: { input: string } */ + const basicToolInstances = [ + new Calculator(), + ...createOpenAIImageTools({ override: true }), + ...createYouTubeTools({ override: true }), + ]; for (const toolInstance of basicToolInstances) { const formattedTool = formatToOpenAIAssistantTool(toolInstance); let toolName = formattedTool[Tools.function].name; - toolName = toolkits.some((toolkit) => toolName.startsWith(toolkit.pluginKey)) - ? toolName.split('_')[0] - : toolName; + toolName = getToolkitKey(toolName) ?? toolName; if (filter.has(toolName) && included.size === 0) { continue; } @@ -315,54 +353,96 @@ async function processRequiredActions(client, requiredActions) { if (!tool) { // throw new Error(`Tool ${currentAction.tool} not found.`); + // Load all action sets once if not already loaded if (!actionSets.length) { actionSets = (await loadActionSets({ assistant_id: client.req.body.assistant_id, })) ?? []; + + // Process all action sets once + // Map domains to their processed action sets + const processedDomains = new Map(); + const domainMap = new Map(); + + for (const action of actionSets) { + const domain = await domainParser(action.metadata.domain, true); + domainMap.set(domain, action); + + // Check if domain is allowed + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain); + if (!isDomainAllowed) { + continue; + } + + // Validate and parse OpenAPI spec + const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec); + if (!validationResult.spec) { + throw new Error( + `Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, + ); + } + + // Process the OpenAPI spec + const { requestBuilders } = openapiToFunction(validationResult.spec); + + // Store encrypted values for OAuth flow + const encrypted = { + oauth_client_id: action.metadata.oauth_client_id, + oauth_client_secret: action.metadata.oauth_client_secret, + }; + + // Decrypt metadata + const decryptedAction = { ...action }; + decryptedAction.metadata = await decryptMetadata(action.metadata); + + processedDomains.set(domain, { + action: decryptedAction, + requestBuilders, + encrypted, + }); + + // Store builders for reuse + ActionBuildersMap[action.metadata.domain] = requestBuilders; + } + + // Update actionSets reference to use the domain map + actionSets = { domainMap, processedDomains }; } - let actionSet = null; + // Find the matching domain for this tool let currentDomain = ''; - for (let action of actionSets) { - const domain = await domainParser(client.req, action.metadata.domain, true); + for (const domain of actionSets.domainMap.keys()) { if 
(currentAction.tool.includes(domain)) { currentDomain = domain; - actionSet = action; break; } } - if (!actionSet) { + if (!currentDomain || !actionSets.processedDomains.has(currentDomain)) { // TODO: try `function` if no action set is found // throw new Error(`Tool ${currentAction.tool} not found.`); continue; } - let builders = ActionBuildersMap[actionSet.metadata.domain]; - - if (!builders) { - const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec); - if (!validationResult.spec) { - throw new Error( - `Invalid spec: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, - ); - } - const { requestBuilders } = openapiToFunction(validationResult.spec); - ActionToolMap[actionSet.metadata.domain] = requestBuilders; - builders = requestBuilders; - } - + const { action, requestBuilders, encrypted } = actionSets.processedDomains.get(currentDomain); const functionName = currentAction.tool.replace(`${actionDelimiter}${currentDomain}`, ''); - - const requestBuilder = builders[functionName]; + const requestBuilder = requestBuilders[functionName]; if (!requestBuilder) { // throw new Error(`Tool ${currentAction.tool} not found.`); continue; } - tool = await createActionTool({ action: actionSet, requestBuilder }); + // We've already decrypted the metadata, so we can pass it directly + tool = await createActionTool({ + userId: client.req.user.id, + res: client.res, + action, + requestBuilder, + // Note: intentionally not passing zodSchema, name, and description for assistants API + encrypted, // Pass the encrypted values for OAuth flow + }); if (!tool) { logger.warn( `Invalid action: user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id} | toolName: ${currentAction.tool}`, @@ -410,7 +490,7 @@ async function processRequiredActions(client, requiredActions) { * @param {Object} params - Run params containing user and request information. * @param {ServerRequest} params.req - The request object. * @param {ServerResponse} params.res - The request object. - * @param {Agent} params.agent - The agent to load tools for. + * @param {Pick} The agent tools. */ @@ -420,21 +500,16 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) } const endpointsConfig = await getEndpointsConfig(req); - const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []; - const areToolsEnabled = capabilities.includes(AgentCapabilities.tools); - if (!areToolsEnabled) { - logger.debug('Tools are not enabled for this agent.'); - return {}; - } - - const isFileSearchEnabled = capabilities.includes(AgentCapabilities.file_search); - const isCodeEnabled = capabilities.includes(AgentCapabilities.execute_code); - const areActionsEnabled = capabilities.includes(AgentCapabilities.actions); + const enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? 
[]); + const checkCapability = (capability) => enabledCapabilities.has(capability); + const areToolsEnabled = checkCapability(AgentCapabilities.tools); const _agentTools = agent.tools?.filter((tool) => { - if (tool === Tools.file_search && !isFileSearchEnabled) { - return false; - } else if (tool === Tools.execute_code && !isCodeEnabled) { + if (tool === Tools.file_search) { + return checkCapability(AgentCapabilities.file_search); + } else if (tool === Tools.execute_code) { + return checkCapability(AgentCapabilities.execute_code); + } else if (!areToolsEnabled && !tool.includes(actionDelimiter)) { return false; } return true; @@ -468,6 +543,10 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) continue; } + if (!areToolsEnabled) { + continue; + } + if (tool.mcp === true) { agentTools.push(tool); continue; @@ -500,14 +579,69 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) return map; }, {}); - if (!areActionsEnabled) { + if (!checkCapability(AgentCapabilities.actions)) { return { tools: agentTools, toolContextMap, }; } - let actionSets = []; + const actionSets = (await loadActionSets({ agent_id: agent.id })) ?? []; + if (actionSets.length === 0) { + if (_agentTools.length > 0 && agentTools.length === 0) { + logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`); + } + return { + tools: agentTools, + toolContextMap, + }; + } + + // Process each action set once (validate spec, decrypt metadata) + const processedActionSets = new Map(); + const domainMap = new Map(); + + for (const action of actionSets) { + const domain = await domainParser(action.metadata.domain, true); + domainMap.set(domain, action); + + // Check if domain is allowed (do this once per action set) + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain); + if (!isDomainAllowed) { + continue; + } + + // Validate and parse OpenAPI spec once per action set + const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec); + if (!validationResult.spec) { + continue; + } + + const encrypted = { + oauth_client_id: action.metadata.oauth_client_id, + oauth_client_secret: action.metadata.oauth_client_secret, + }; + + // Decrypt metadata once per action set + const decryptedAction = { ...action }; + decryptedAction.metadata = await decryptMetadata(action.metadata); + + // Process the OpenAPI spec once per action set + const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction( + validationResult.spec, + true, + ); + + processedActionSets.set(domain, { + action: decryptedAction, + requestBuilders, + functionSignatures, + zodSchemas, + encrypted, + }); + } + + // Now map tools to the processed action sets const ActionToolMap = {}; for (const toolName of _agentTools) { @@ -515,55 +649,47 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) continue; } - if (!actionSets.length) { - actionSets = (await loadActionSets({ agent_id: agent.id })) ?? 
[]; - } - - let actionSet = null; + // Find the matching domain for this tool let currentDomain = ''; - for (let action of actionSets) { - const domain = await domainParser(req, action.metadata.domain, true); + for (const domain of domainMap.keys()) { if (toolName.includes(domain)) { currentDomain = domain; - actionSet = action; break; } } - if (!actionSet) { + if (!currentDomain || !processedActionSets.has(currentDomain)) { continue; } - const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec); - if (validationResult.spec) { - const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction( - validationResult.spec, - true, - ); - const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, ''); - const functionSig = functionSignatures.find((sig) => sig.name === functionName); - const requestBuilder = requestBuilders[functionName]; - const zodSchema = zodSchemas[functionName]; + const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = + processedActionSets.get(currentDomain); + const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, ''); + const functionSig = functionSignatures.find((sig) => sig.name === functionName); + const requestBuilder = requestBuilders[functionName]; + const zodSchema = zodSchemas[functionName]; - if (requestBuilder) { - const tool = await createActionTool({ - req, - res, - action: actionSet, - requestBuilder, - zodSchema, - name: toolName, - description: functionSig.description, - }); - if (!tool) { - logger.warn( - `Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`, - ); - throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`); - } - agentTools.push(tool); - ActionToolMap[toolName] = tool; + if (requestBuilder) { + const tool = await createActionTool({ + userId: req.user.id, + res, + action, + requestBuilder, + zodSchema, + encrypted, + name: toolName, + description: functionSig.description, + }); + + if (!tool) { + logger.warn( + `Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`, + ); + throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`); } + + agentTools.push(tool); + ActionToolMap[toolName] = tool; } } @@ -579,6 +705,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) } module.exports = { + getToolkitKey, loadAgentTools, loadAndFormatTools, processRequiredActions, diff --git a/api/server/services/Tools/credentials.js b/api/server/services/Tools/credentials.js new file mode 100644 index 0000000000..b50a2460d4 --- /dev/null +++ b/api/server/services/Tools/credentials.js @@ -0,0 +1,56 @@ +const { getUserPluginAuthValue } = require('~/server/services/PluginService'); + +/** + * + * @param {Object} params + * @param {string} params.userId + * @param {string[]} params.authFields + * @param {Set} [params.optional] + * @param {boolean} [params.throwError] + * @returns + */ +const loadAuthValues = async ({ userId, authFields, optional, throwError = true }) => { + let authValues = {}; + + /** + * Finds the first non-empty value for the given authentication field, supporting alternate fields. + * @param {string[]} fields Array of strings representing the authentication fields. Supports alternate fields delimited by "||". + * @returns {Promise<{ authField: string, authValue: string} | null>} An object containing the authentication field and value, or null if not found. 
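+ * @example + * // Sketch, values illustrative: given the authField 'SERPAPI_API_KEY||SERPER_API_KEY', + * // the caller splits on '||' and this helper returns the first non-empty match from + * // process.env or the user's stored plugin auth: + * // await findAuthValue(['SERPAPI_API_KEY', 'SERPER_API_KEY']) + * //   => { authField: 'SERPER_API_KEY', authValue: 'abc123' }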
+ */ + const findAuthValue = async (fields) => { + for (const field of fields) { + let value = process.env[field]; + if (value) { + return { authField: field, authValue: value }; + } + try { + value = await getUserPluginAuthValue(userId, field, throwError); + } catch (err) { + if (optional && optional.has(field)) { + return { authField: field, authValue: undefined }; + } + if (field === fields[fields.length - 1] && !value) { + throw err; + } + } + if (value) { + return { authField: field, authValue: value }; + } + } + return null; + }; + + for (let authField of authFields) { + const fields = authField.split('||'); + const result = await findAuthValue(fields); + if (result) { + authValues[result.authField] = result.authValue; + } + } + + return authValues; +}; + +module.exports = { + loadAuthValues, +}; diff --git a/api/server/services/start/checks.js b/api/server/services/start/checks.js index 100424d35a..fe9cd79edf 100644 --- a/api/server/services/start/checks.js +++ b/api/server/services/start/checks.js @@ -13,6 +13,24 @@ const secretDefaults = { JWT_REFRESH_SECRET: 'eaa5191f2914e30b9387fd84e254e4ba6fc51b4654968a9b0803b456a54b8418', }; +const deprecatedVariables = [ + { + key: 'CHECK_BALANCE', + description: + 'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview', + }, + { + key: 'START_BALANCE', + description: + 'Please use the `balance` field in the `librechat.yaml` config file instead.\nMore info: https://librechat.ai/docs/configuration/librechat_yaml/object_structure/balance#overview', + }, + { + key: 'GOOGLE_API_KEY', + description: + 'Please use the `GOOGLE_SEARCH_API_KEY` environment variable for the Google Search Tool instead.', + }, +]; + /** * Checks environment variables for default secrets and deprecated variables. * Logs warnings for any default secret values being used and for usage of deprecated `GOOGLE_API_KEY`. @@ -37,19 +55,11 @@ function checkVariables() { \u200B`); } - if (process.env.GOOGLE_API_KEY) { - logger.warn( - 'The `GOOGLE_API_KEY` environment variable is deprecated.\nPlease use the `GOOGLE_SEARCH_API_KEY` environment variable instead.', - ); - } - - if (process.env.OPENROUTER_API_KEY) { - logger.warn( - `The \`OPENROUTER_API_KEY\` environment variable is deprecated and its functionality will be removed soon. - Use of this environment variable is highly discouraged as it can lead to unexpected errors when using custom endpoints. - Please use the config (\`librechat.yaml\`) file for setting up OpenRouter, and use \`OPENROUTER_KEY\` or another environment variable instead.`, - ); - } + deprecatedVariables.forEach(({ key, description }) => { + if (process.env[key]) { + logger.warn(`The \`${key}\` environment variable is deprecated. ${description}`); + } + }); checkPasswordReset(); } diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index 98bcb6858e..d9f171ca4e 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -18,12 +18,15 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol const { interface: interfaceConfig } = config ?? 
{}; const { interface: defaults } = configDefaults; const hasModelSpecs = config?.modelSpecs?.list?.length > 0; + const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0; /** @type {TCustomConfig['interface']} */ const loadedInterface = removeNullishValues({ endpointsMenu: interfaceConfig?.endpointsMenu ?? (hasModelSpecs ? false : defaults.endpointsMenu), - modelSelect: interfaceConfig?.modelSelect ?? (hasModelSpecs ? false : defaults.modelSelect), + modelSelect: + interfaceConfig?.modelSelect ?? + (hasModelSpecs ? includesAddedEndpoints : defaults.modelSelect), parameters: interfaceConfig?.parameters ?? (hasModelSpecs ? false : defaults.parameters), presets: interfaceConfig?.presets ?? (hasModelSpecs ? false : defaults.presets), sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel, @@ -34,6 +37,8 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo, agents: interfaceConfig?.agents ?? defaults.agents, temporaryChat: interfaceConfig?.temporaryChat ?? defaults.temporaryChat, + runCode: interfaceConfig?.runCode ?? defaults.runCode, + customWelcome: interfaceConfig?.customWelcome ?? defaults.customWelcome, }); await updateAccessPermissions(roleName, { @@ -41,12 +46,16 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode }, }); await updateAccessPermissions(SystemRoles.ADMIN, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: loadedInterface.runCode }, }); let i = 0; diff --git a/api/server/services/start/interface.spec.js b/api/server/services/start/interface.spec.js index 0041246433..7e248d3dfe 100644 --- a/api/server/services/start/interface.spec.js +++ b/api/server/services/start/interface.spec.js @@ -14,6 +14,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: true, agents: true, + temporaryChat: true, + runCode: true, }, }; const configDefaults = { interface: {} }; @@ -25,6 +27,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: true }, }); }); @@ -35,6 +39,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, multiConvo: false, agents: false, + temporaryChat: false, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -46,6 +52,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { 
[Permissions.USE]: false }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -60,6 +68,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -70,6 +80,8 @@ describe('loadDefaultInterface', () => { bookmarks: undefined, multiConvo: undefined, agents: undefined, + temporaryChat: undefined, + runCode: undefined, }, }; const configDefaults = { interface: {} }; @@ -81,6 +93,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -91,6 +105,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, multiConvo: undefined, agents: true, + temporaryChat: undefined, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -102,6 +118,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -113,6 +131,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: true, agents: true, + temporaryChat: true, + runCode: true, }, }; @@ -123,6 +143,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: true }, }); }); @@ -137,6 +159,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -151,6 +175,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -165,6 +191,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); @@ -175,6 +203,8 @@ describe('loadDefaultInterface', () => { bookmarks: false, multiConvo: true, agents: false, + 
temporaryChat: true, + runCode: false, }, }; const configDefaults = { interface: {} }; @@ -186,6 +216,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: false }, }); }); @@ -197,6 +229,8 @@ describe('loadDefaultInterface', () => { bookmarks: true, multiConvo: false, agents: undefined, + temporaryChat: undefined, + runCode: undefined, }, }; @@ -207,6 +241,8 @@ describe('loadDefaultInterface', () => { [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, + [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, + [PermissionTypes.RUN_CODE]: { [Permissions.USE]: undefined }, }); }); }); diff --git a/api/server/services/start/modelSpecs.js b/api/server/services/start/modelSpecs.js index f249a9c90b..4adc89cc3a 100644 --- a/api/server/services/start/modelSpecs.js +++ b/api/server/services/start/modelSpecs.js @@ -6,9 +6,10 @@ const { logger } = require('~/config'); * Sets up Model Specs from the config (`librechat.yaml`) file. * @param {TCustomConfig['endpoints']} [endpoints] - The loaded custom configuration for endpoints. * @param {TCustomConfig['modelSpecs'] | undefined} [modelSpecs] - The loaded custom configuration for model specs. + * @param {TCustomConfig['interface'] | undefined} [interfaceConfig] - The loaded interface configuration. * @returns {TCustomConfig['modelSpecs'] | undefined} The processed model specs, if any. */ -function processModelSpecs(endpoints, _modelSpecs) { +function processModelSpecs(endpoints, _modelSpecs, interfaceConfig) { if (!_modelSpecs) { return undefined; } @@ -20,6 +21,19 @@ function processModelSpecs(endpoints, _modelSpecs) { const customEndpoints = endpoints?.[EModelEndpoint.custom] ?? []; + if (interfaceConfig.modelSelect !== true && (_modelSpecs.addedEndpoints?.length ?? 0) > 0) { + logger.warn( + `To utilize \`addedEndpoints\`, which allows provider/model selections alongside model specs, set \`modelSelect: true\` in the interface configuration. + + Example: + \`\`\`yaml + interface: + modelSelect: true + \`\`\` + `, + ); + } + for (const spec of list) { if (EModelEndpoint[spec.preset.endpoint] && spec.preset.endpoint !== EModelEndpoint.custom) { modelSpecs.push(spec); diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js new file mode 100644 index 0000000000..d000c8fcfc --- /dev/null +++ b/api/server/services/twoFactorService.js @@ -0,0 +1,224 @@ +const { webcrypto } = require('node:crypto'); +const { decryptV3, decryptV2 } = require('../utils/crypto'); +const { hashBackupCode } = require('~/server/utils/crypto'); + +// Base32 alphabet for TOTP secret encoding. +const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'; + +/** + * Encodes a Buffer into a Base32 string. 
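+ * (Sketch, illustrative: the RFC 4648 alphabet with no '=' padding, so + * encodeBase32(Buffer.from('hi')) === 'NBUQ' and decodeBase32('NBUQ') round-trips to 'hi'.)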
+ * @param {Buffer} buffer + * @returns {string} + */ +const encodeBase32 = (buffer) => { + let bits = 0; + let value = 0; + let output = ''; + for (const byte of buffer) { + value = (value << 8) | byte; + bits += 8; + while (bits >= 5) { + output += BASE32_ALPHABET[(value >>> (bits - 5)) & 31]; + bits -= 5; + } + } + if (bits > 0) { + output += BASE32_ALPHABET[(value << (5 - bits)) & 31]; + } + return output; +}; + +/** + * Decodes a Base32 string into a Buffer. + * @param {string} base32Str + * @returns {Buffer} + */ +const decodeBase32 = (base32Str) => { + const cleaned = base32Str.replace(/=+$/, '').toUpperCase(); + let bits = 0; + let value = 0; + const output = []; + for (const char of cleaned) { + const idx = BASE32_ALPHABET.indexOf(char); + if (idx === -1) { + continue; + } + value = (value << 5) | idx; + bits += 5; + if (bits >= 8) { + output.push((value >>> (bits - 8)) & 0xff); + bits -= 8; + } + } + return Buffer.from(output); +}; + +/** + * Generates a new TOTP secret (Base32 encoded). + * @returns {string} + */ +const generateTOTPSecret = () => { + const randomArray = new Uint8Array(10); + webcrypto.getRandomValues(randomArray); + return encodeBase32(Buffer.from(randomArray)); +}; + +/** + * Generates a TOTP code based on the secret and time. + * Uses a 30-second time step and produces a 6-digit code. + * @param {string} secret + * @param {number} [forTime=Date.now()] + * @returns {Promise<string>} + */ +const generateTOTP = async (secret, forTime = Date.now()) => { + const timeStep = 30; // seconds + const counter = Math.floor(forTime / 1000 / timeStep); + const counterBuffer = new ArrayBuffer(8); + const counterView = new DataView(counterBuffer); + counterView.setUint32(4, counter, false); + + const keyBuffer = decodeBase32(secret); + const keyArrayBuffer = keyBuffer.buffer.slice( + keyBuffer.byteOffset, + keyBuffer.byteOffset + keyBuffer.byteLength, + ); + + const cryptoKey = await webcrypto.subtle.importKey( + 'raw', + keyArrayBuffer, + { name: 'HMAC', hash: 'SHA-1' }, + false, + ['sign'], + ); + const signatureBuffer = await webcrypto.subtle.sign('HMAC', cryptoKey, counterBuffer); + const hmac = new Uint8Array(signatureBuffer); + + // Dynamic truncation per RFC 4226. + const offset = hmac[hmac.length - 1] & 0xf; + const slice = hmac.slice(offset, offset + 4); + const view = new DataView(slice.buffer, slice.byteOffset, slice.byteLength); + const binaryCode = view.getUint32(0, false) & 0x7fffffff; + const code = (binaryCode % 1000000).toString().padStart(6, '0'); + return code; +}; + +/** + * Verifies a TOTP token by checking a ±1 time step window. + * @param {string} secret + * @param {string} token + * @returns {Promise<boolean>} + */ +const verifyTOTP = async (secret, token) => { + const timeStepMS = 30 * 1000; + const currentTime = Date.now(); + for (let offset = -1; offset <= 1; offset++) { + const expected = await generateTOTP(secret, currentTime + offset * timeStepMS); + if (expected === token) { + return true; + } + } + return false; +}; + +/** + * Generates backup codes (default count: 10). + * Each code is an 8-character hexadecimal string and stored with its SHA-256 hash. 
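+ * (Sketch, values illustrative: `const { plainCodes, codeObjects } = await generateBackupCodes(2)` + * might yield plainCodes ['a1b2c3d4', '9f8e7d6c']; the plain codes are shown to the user once, + * and only the hashed codeObjects should be persisted.)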
+ * @param {number} [count=10] + * @returns {Promise<{ plainCodes: string[], codeObjects: Array<{ codeHash: string, used: boolean, usedAt: Date | null }> }>} + */ +const generateBackupCodes = async (count = 10) => { + const plainCodes = []; + const codeObjects = []; + const encoder = new TextEncoder(); + + for (let i = 0; i < count; i++) { + const randomArray = new Uint8Array(4); + webcrypto.getRandomValues(randomArray); + const code = Array.from(randomArray) + .map((b) => b.toString(16).padStart(2, '0')) + .join(''); + plainCodes.push(code); + + const codeBuffer = encoder.encode(code); + const hashBuffer = await webcrypto.subtle.digest('SHA-256', codeBuffer); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + const codeHash = hashArray.map((b) => b.toString(16).padStart(2, '0')).join(''); + codeObjects.push({ codeHash, used: false, usedAt: null }); + } + return { plainCodes, codeObjects }; +}; + +/** + * Verifies a backup code and, if valid, marks it as used. + * @param {Object} params + * @param {Object} params.user + * @param {string} params.backupCode + * @returns {Promise<boolean>} + */ +const verifyBackupCode = async ({ user, backupCode }) => { + if (!backupCode || !user || !Array.isArray(user.backupCodes)) { + return false; + } + + const hashedInput = await hashBackupCode(backupCode.trim()); + const matchingCode = user.backupCodes.find( + (codeObj) => codeObj.codeHash === hashedInput && !codeObj.used, + ); + + if (matchingCode) { + const updatedBackupCodes = user.backupCodes.map((codeObj) => + codeObj.codeHash === hashedInput && !codeObj.used + ? { ...codeObj, used: true, usedAt: new Date() } + : codeObj, + ); + // Update the user record with the marked backup code. + const { updateUser } = require('~/models'); + await updateUser(user._id, { backupCodes: updatedBackupCodes }); + return true; + } + return false; +}; + +/** + * Retrieves and decrypts a stored TOTP secret. + * - Uses decryptV3 if the secret has a "v3:" prefix. + * - Falls back to decryptV2 for colon-delimited values. + * - Assumes a 16-character secret is already plain. + * @param {string|null} storedSecret + * @returns {Promise<string | null>} + */ +const getTOTPSecret = async (storedSecret) => { + if (!storedSecret) { + return null; + } + if (storedSecret.startsWith('v3:')) { + return decryptV3(storedSecret); + } + if (storedSecret.includes(':')) { + return await decryptV2(storedSecret); + } + if (storedSecret.length === 16) { + return storedSecret; + } + return storedSecret; +}; + +/** + * Generates a temporary JWT token for 2FA verification that expires in 5 minutes. 
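+ * (Sketch: the temp token gates the second login step; jwt.verify(tempToken, process.env.JWT_SECRET) + * yields { userId, twoFAPending: true, iat, exp }, and the twoFAPending claim should be checked + * before accepting a TOTP or backup code.)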
+ * @param {string} userId + * @returns {string} + */ +const generate2FATempToken = (userId) => { + const { sign } = require('jsonwebtoken'); + return sign({ userId, twoFAPending: true }, process.env.JWT_SECRET, { expiresIn: '5m' }); +}; + +module.exports = { + generateTOTPSecret, + generateTOTP, + verifyTOTP, + generateBackupCodes, + verifyBackupCode, + getTOTPSecret, + generate2FATempToken, +}; diff --git a/api/server/socialLogins.js b/api/server/socialLogins.js index f39d1da596..0eb44514d3 100644 --- a/api/server/socialLogins.js +++ b/api/server/socialLogins.js @@ -1,4 +1,4 @@ -const Redis = require('ioredis'); +const { Keyv } = require('keyv'); const passport = require('passport'); const session = require('express-session'); const MemoryStore = require('memorystore')(session); @@ -12,6 +12,7 @@ const { appleLogin, } = require('~/strategies'); const { isEnabled } = require('~/server/utils'); +const keyvRedis = require('~/cache/keyvRedis'); const { logger } = require('~/config'); /** @@ -19,6 +20,8 @@ const { logger } = require('~/config'); * @param {Express.Application} app */ const configureSocialLogins = (app) => { + logger.info('Configuring social logins...'); + if (process.env.GOOGLE_CLIENT_ID && process.env.GOOGLE_CLIENT_SECRET) { passport.use(googleLogin()); } @@ -41,18 +44,17 @@ const configureSocialLogins = (app) => { process.env.OPENID_SCOPE && process.env.OPENID_SESSION_SECRET ) { + logger.info('Configuring OpenID Connect...'); const sessionOptions = { secret: process.env.OPENID_SESSION_SECRET, resave: false, saveUninitialized: false, }; if (isEnabled(process.env.USE_REDIS)) { - const client = new Redis(process.env.REDIS_URI); - client - .on('error', (err) => logger.error('ioredis error:', err)) - .on('ready', () => logger.info('ioredis successfully initialized.')) - .on('reconnecting', () => logger.info('ioredis reconnecting...')); - sessionOptions.store = new RedisStore({ client, prefix: 'librechat' }); + logger.debug('Using Redis for session storage in OpenID...'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.client; + sessionOptions.store = new RedisStore({ client, prefix: 'openid_session' }); } else { sessionOptions.store = new MemoryStore({ checkPeriod: 86400000, // prune expired entries every 24h @@ -61,7 +63,9 @@ const configureSocialLogins = (app) => { app.use(session(sessionOptions)); app.use(passport.session()); setupOpenId(); + + logger.info('OpenID Connect configured.'); } }; -module.exports = configureSocialLogins; \ No newline at end of file +module.exports = configureSocialLogins; diff --git a/api/server/utils/crypto.js b/api/server/utils/crypto.js index ea71df51ad..333cd7573a 100644 --- a/api/server/utils/crypto.js +++ b/api/server/utils/crypto.js @@ -1,27 +1,25 @@ require('dotenv').config(); +const crypto = require('node:crypto'); +const { webcrypto } = crypto; -const { webcrypto } = require('node:crypto'); +// Use hex decoding for both key and IV for legacy methods. 
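+// (Sketch, assuming the documented env shape: a 64-hex-char CREDS_KEY decodes to the 32 bytes +// AES-256 requires for v3 below, and CREDS_IV to the 16 bytes AES-CBC expects. Compatible values +// can be generated with, e.g.: +// node -e "console.log(require('crypto').randomBytes(32).toString('hex'))" // CREDS_KEY +// node -e "console.log(require('crypto').randomBytes(16).toString('hex'))" // CREDS_IV)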
const key = Buffer.from(process.env.CREDS_KEY, 'hex'); const iv = Buffer.from(process.env.CREDS_IV, 'hex'); const algorithm = 'AES-CBC'; +// --- Legacy v1/v2 Setup: AES-CBC with fixed key and IV --- + async function encrypt(value) { const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'encrypt', ]); - const encoder = new TextEncoder(); const data = encoder.encode(value); - const encryptedBuffer = await webcrypto.subtle.encrypt( - { - name: algorithm, - iv: iv, - }, + { name: algorithm, iv: iv }, cryptoKey, data, ); - return Buffer.from(encryptedBuffer).toString('hex'); } @@ -29,73 +27,85 @@ async function decrypt(encryptedValue) { const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'decrypt', ]); - const encryptedBuffer = Buffer.from(encryptedValue, 'hex'); - const decryptedBuffer = await webcrypto.subtle.decrypt( - { - name: algorithm, - iv: iv, - }, + { name: algorithm, iv: iv }, cryptoKey, encryptedBuffer, ); - const decoder = new TextDecoder(); return decoder.decode(decryptedBuffer); } -// Programmatically generate iv +// --- v2: AES-CBC with a random IV per encryption --- + async function encryptV2(value) { const gen_iv = webcrypto.getRandomValues(new Uint8Array(16)); - const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'encrypt', ]); - const encoder = new TextEncoder(); const data = encoder.encode(value); - const encryptedBuffer = await webcrypto.subtle.encrypt( - { - name: algorithm, - iv: gen_iv, - }, + { name: algorithm, iv: gen_iv }, cryptoKey, data, ); - return Buffer.from(gen_iv).toString('hex') + ':' + Buffer.from(encryptedBuffer).toString('hex'); } async function decryptV2(encryptedValue) { const parts = encryptedValue.split(':'); - // Already decrypted from an earlier invocation if (parts.length === 1) { return parts[0]; } const gen_iv = Buffer.from(parts.shift(), 'hex'); const encrypted = parts.join(':'); - const cryptoKey = await webcrypto.subtle.importKey('raw', key, { name: algorithm }, false, [ 'decrypt', ]); - const encryptedBuffer = Buffer.from(encrypted, 'hex'); - const decryptedBuffer = await webcrypto.subtle.decrypt( - { - name: algorithm, - iv: gen_iv, - }, + { name: algorithm, iv: gen_iv }, cryptoKey, encryptedBuffer, ); - const decoder = new TextDecoder(); return decoder.decode(decryptedBuffer); } +// --- v3: AES-256-CTR using Node's crypto functions --- +const algorithm_v3 = 'aes-256-ctr'; + +/** + * Encrypts a value using AES-256-CTR. + * Note: AES-256 requires a 32-byte key. Ensure that process.env.CREDS_KEY is a 64-character hex string. + * + * @param {string} value - The plaintext to encrypt. + * @returns {string} The encrypted string with a "v3:" prefix. 
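+ * @example + * // Sketch, ciphertext illustrative: + * // const sealed = encryptV3('my-api-key'); // => 'v3:<32-hex-char iv>:<hex ciphertext>' + * // decryptV3(sealed); // => 'my-api-key'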
+ */ +function encryptV3(value) { + if (key.length !== 32) { + throw new Error(`Invalid key length: expected 32 bytes, got ${key.length} bytes`); + } + const iv_v3 = crypto.randomBytes(16); + const cipher = crypto.createCipheriv(algorithm_v3, key, iv_v3); + const encrypted = Buffer.concat([cipher.update(value, 'utf8'), cipher.final()]); + return `v3:${iv_v3.toString('hex')}:${encrypted.toString('hex')}`; +} + +function decryptV3(encryptedValue) { + const parts = encryptedValue.split(':'); + if (parts[0] !== 'v3') { + throw new Error('Not a v3 encrypted value'); + } + const iv_v3 = Buffer.from(parts[1], 'hex'); + const encryptedText = Buffer.from(parts.slice(2).join(':'), 'hex'); + const decipher = crypto.createDecipheriv(algorithm_v3, key, iv_v3); + const decrypted = Buffer.concat([decipher.update(encryptedText), decipher.final()]); + return decrypted.toString('utf8'); +} + async function hashToken(str) { const data = new TextEncoder().encode(str); const hashBuffer = await webcrypto.subtle.digest('SHA-256', data); @@ -106,10 +116,32 @@ async function getRandomValues(length) { if (!Number.isInteger(length) || length <= 0) { throw new Error('Length must be a positive integer'); } - const randomValues = new Uint8Array(length); webcrypto.getRandomValues(randomValues); return Buffer.from(randomValues).toString('hex'); } -module.exports = { encrypt, decrypt, encryptV2, decryptV2, hashToken, getRandomValues }; +/** + * Computes SHA-256 hash for the given input. + * @param {string} input + * @returns {Promise} + */ +async function hashBackupCode(input) { + const encoder = new TextEncoder(); + const data = encoder.encode(input); + const hashBuffer = await webcrypto.subtle.digest('SHA-256', data); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + return hashArray.map((b) => b.toString(16).padStart(2, '0')).join(''); +} + +module.exports = { + encrypt, + decrypt, + encryptV2, + decryptV2, + encryptV3, + decryptV3, + hashToken, + hashBackupCode, + getRandomValues, +}; diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 8c681d8f4e..f593d6c866 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -203,6 +203,8 @@ function generateConfig(key, baseURL, endpoint) { AgentCapabilities.artifacts, AgentCapabilities.actions, AgentCapabilities.tools, + AgentCapabilities.ocr, + AgentCapabilities.chain, ]; } diff --git a/api/server/utils/staticCache.js b/api/server/utils/staticCache.js index a8001c7e0a..23713ddf6f 100644 --- a/api/server/utils/staticCache.js +++ b/api/server/utils/staticCache.js @@ -1,4 +1,4 @@ -const express = require('express'); +const expressStaticGzip = require('express-static-gzip'); const oneDayInSeconds = 24 * 60 * 60; @@ -6,13 +6,13 @@ const sMaxAge = process.env.STATIC_CACHE_S_MAX_AGE || oneDayInSeconds; const maxAge = process.env.STATIC_CACHE_MAX_AGE || oneDayInSeconds * 2; const staticCache = (staticPath) => - express.static(staticPath, { - setHeaders: (res) => { - if (process.env.NODE_ENV?.toLowerCase() !== 'production') { - return; + expressStaticGzip(staticPath, { + enableBrotli: false, // disable Brotli, only using gzip + orderPreference: ['gz'], + setHeaders: (res, _path) => { + if (process.env.NODE_ENV?.toLowerCase() === 'production') { + res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`); } - - res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`); }, }); diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js index 
0f042339a9..bb8d63b229 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/utils/streamResponse.js @@ -70,7 +70,13 @@ const sendError = async (req, res, options, callback) => { } if (shouldSaveMessage) { - await saveMessage(req, { ...errorMessage, user }); + await saveMessage( + req, + { ...errorMessage, user }, + { + context: 'api/server/utils/streamResponse.js - sendError', + }, + ); } if (!errorMessage.error) { diff --git a/api/strategies/googleStrategy.js b/api/strategies/googleStrategy.js index ab8a268953..fd65823327 100644 --- a/api/strategies/googleStrategy.js +++ b/api/strategies/googleStrategy.js @@ -6,7 +6,7 @@ const getProfileDetails = ({ profile }) => ({ id: profile.id, avatarUrl: profile.photos[0].value, username: profile.name.givenName, - name: `${profile.name.givenName} ${profile.name.familyName}`, + name: `${profile.name.givenName}${profile.name.familyName ? ` ${profile.name.familyName}` : ''}`, emailVerified: profile.emails[0].verified, }); diff --git a/api/strategies/jwtStrategy.js b/api/strategies/jwtStrategy.js index e65b284950..ac19e92ac3 100644 --- a/api/strategies/jwtStrategy.js +++ b/api/strategies/jwtStrategy.js @@ -12,7 +12,7 @@ const jwtLogin = async () => }, async (payload, done) => { try { - const user = await getUserById(payload?.id, '-password -__v'); + const user = await getUserById(payload?.id, '-password -__v -totpSecret'); if (user) { user.id = user._id.toString(); if (!user.role) { diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index 4a2c1b827b..5ec279b982 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -18,6 +18,7 @@ const { LDAP_USERNAME, LDAP_EMAIL, LDAP_TLS_REJECT_UNAUTHORIZED, + LDAP_STARTTLS, } = process.env; // Check required environment variables @@ -50,6 +51,7 @@ if (LDAP_EMAIL) { searchAttributes.push(LDAP_EMAIL); } const rejectUnauthorized = isEnabled(LDAP_TLS_REJECT_UNAUTHORIZED); +const startTLS = isEnabled(LDAP_STARTTLS); const ldapOptions = { server: { @@ -72,6 +74,7 @@ const ldapOptions = { })(), }, }), + ...(startTLS && { starttls: true }), }, usernameField: 'email', passwordField: 'password', diff --git a/api/test/__mocks__/KeyvMongo.js b/api/test/__mocks__/KeyvMongo.js deleted file mode 100644 index f88bc144be..0000000000 --- a/api/test/__mocks__/KeyvMongo.js +++ /dev/null @@ -1,30 +0,0 @@ -const mockGet = jest.fn(); -const mockSet = jest.fn(); - -jest.mock('@keyv/mongo', () => { - const EventEmitter = require('events'); - class KeyvMongo extends EventEmitter { - constructor(url = 'mongodb://127.0.0.1:27017', options) { - super(); - this.ttlSupport = false; - url = url ?? 
{}; - if (typeof url === 'string') { - url = { url }; - } - if (url.uri) { - url = { url: url.uri, ...url }; - } - this.opts = { - url, - collection: 'keyv', - ...url, - ...options, - }; - } - - get = mockGet; - set = mockSet; - } - - return KeyvMongo; -}); diff --git a/api/test/__mocks__/logger.js b/api/test/__mocks__/logger.js index caeb004e39..549c57d5a4 100644 --- a/api/test/__mocks__/logger.js +++ b/api/test/__mocks__/logger.js @@ -39,7 +39,10 @@ jest.mock('winston-daily-rotate-file', () => { }); jest.mock('~/config', () => { + const actualModule = jest.requireActual('~/config'); return { + sendEvent: actualModule.sendEvent, + createAxiosInstance: actualModule.createAxiosInstance, logger: { info: jest.fn(), warn: jest.fn(), diff --git a/api/typedefs.js b/api/typedefs.js index bd97bd93fa..d65d8c9191 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -7,6 +7,11 @@ * @typedef {import('openai').OpenAI} OpenAI * @memberof typedefs */ +/** + * @exports OpenAIImagesResponse + * @typedef {Promise} OpenAIImagesResponse + * @memberof typedefs + */ /** * @exports ServerRequest @@ -14,12 +19,84 @@ * @memberof typedefs */ +/** + * @template T + * @typedef {ReadableStream | NodeJS.ReadableStream} NodeStream + * @memberof typedefs + */ + +/** + * @template T + * @typedef {(req: ServerRequest, filepath: string) => Promise>} NodeStreamDownloader + * @memberof typedefs + */ + /** * @exports ServerResponse * @typedef {import('express').Response} ServerResponse * @memberof typedefs */ +/** + * @exports NextFunction + * @typedef {import('express').NextFunction} NextFunction + * @memberof typedefs + */ + +/** + * @exports Graph + * @typedef {import('@librechat/agents').Graph} Graph + * @memberof typedefs + */ + +/** + * @exports StandardGraph + * @typedef {import('@librechat/agents').StandardGraph} StandardGraph + * @memberof typedefs + */ + +/** + * @exports EventHandler + * @typedef {import('@librechat/agents').EventHandler} EventHandler + * @memberof typedefs + */ + +/** + * @exports ModelEndData + * @typedef {import('@librechat/agents').ModelEndData} ModelEndData + * @memberof typedefs + */ + +/** + * @exports ToolEndData + * @typedef {import('@librechat/agents').ToolEndData} ToolEndData + * @memberof typedefs + */ + +/** + * @exports ToolEndCallback + * @typedef {import('@librechat/agents').ToolEndCallback} ToolEndCallback + * @memberof typedefs + */ + +/** + * @exports ChatModelStreamHandler + * @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler + * @memberof typedefs + */ + +/** + * @exports ContentAggregator + * @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator + * @memberof typedefs + */ + +/** + * @exports GraphEvents + * @typedef {import('@librechat/agents').GraphEvents} GraphEvents + * @memberof typedefs + */ + /** * @exports AgentRun * @typedef {import('@librechat/agents').Run} AgentRun @@ -74,12 +151,6 @@ * @memberof typedefs */ -/** - * @exports ToolEndData - * @typedef {import('@librechat/agents').ToolEndData} ToolEndData - * @memberof typedefs - */ - /** * @exports BaseMessage * @typedef {import('@langchain/core/messages').BaseMessage} BaseMessage @@ -397,6 +468,12 @@ * @memberof typedefs */ +/** + * @exports MessageContentImageUrl + * @typedef {import('librechat-data-provider').Agents.MessageContentImageUrl} MessageContentImageUrl + * @memberof typedefs + */ + /** Prompts */ /** * @exports TPrompt @@ -754,39 +831,26 @@ * @memberof typedefs */ -/** - * @exports ObjectId - * @typedef 
{import('mongoose').Types.ObjectId} ObjectId - * @memberof typedefs - */ - /** * @exports MongoFile - * @typedef {import('~/models/schema/fileSchema.js').MongoFile} MongoFile + * @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile * @memberof typedefs */ - /** - * @exports ToolCallData - * @typedef {import('~/models/schema/toolCallSchema.js').ToolCallData} ToolCallData + * @exports IBalance + * @typedef {import('@librechat/data-schemas').IBalance} IBalance * @memberof typedefs */ /** * @exports MongoUser - * @typedef {import('~/models/schema/userSchema.js').MongoUser} MongoUser + * @typedef {import('@librechat/data-schemas').IUser} MongoUser * @memberof typedefs */ /** - * @exports MongoProject - * @typedef {import('~/models/schema/projectSchema.js').MongoProject} MongoProject - * @memberof typedefs - */ - -/** - * @exports MongoPromptGroup - * @typedef {import('~/models/schema/promptSchema.js').MongoPromptGroup} MongoPromptGroup + * @exports ObjectId + * @typedef {import('mongoose').Types.ObjectId} ObjectId * @memberof typedefs */ @@ -817,8 +881,9 @@ /** * @typedef {Partial & { * message?: string, - * signal?: AbortSignal - * memory?: ConversationSummaryBufferMemory + * signal?: AbortSignal, + * memory?: ConversationSummaryBufferMemory, + * tool_resources?: AgentToolResources, * }} LoadToolOptions * @memberof typedefs */ @@ -829,6 +894,12 @@ * @memberof typedefs */ +/** + * @exports TEndpointOption + * @typedef {import('librechat-data-provider').TEndpointOption} TEndpointOption + * @memberof typedefs + */ + /** * @exports TAttachment * @typedef {import('librechat-data-provider').TAttachment} TAttachment @@ -1811,3 +1882,51 @@ * @typedef {Promise<{ message: TMessage, conversation: TConversation }> | undefined} ClientDatabaseSavePromise * @memberof typedefs */ + +/** + * @exports OCRImage + * @typedef {Object} OCRImage + * @property {string} id - The identifier of the image. + * @property {number} top_left_x - X-coordinate of the top left corner of the image. + * @property {number} top_left_y - Y-coordinate of the top left corner of the image. + * @property {number} bottom_right_x - X-coordinate of the bottom right corner of the image. + * @property {number} bottom_right_y - Y-coordinate of the bottom right corner of the image. + * @property {string} image_base64 - Base64-encoded image data. + * @memberof typedefs + */ + +/** + * @exports PageDimensions + * @typedef {Object} PageDimensions + * @property {number} dpi - The dots per inch resolution of the page. + * @property {number} height - The height of the page in pixels. + * @property {number} width - The width of the page in pixels. + * @memberof typedefs + */ + +/** + * @exports OCRPage + * @typedef {Object} OCRPage + * @property {number} index - The index of the page in the document. + * @property {string} markdown - The extracted text content of the page in markdown format. + * @property {OCRImage[]} images - Array of images found on the page. + * @property {PageDimensions} dimensions - The dimensions of the page. + * @memberof typedefs + */ + +/** + * @exports OCRUsageInfo + * @typedef {Object} OCRUsageInfo + * @property {number} pages_processed - Number of pages processed in the document. + * @property {number} doc_size_bytes - Size of the document in bytes. + * @memberof typedefs + */ + +/** + * @exports OCRResult + * @typedef {Object} OCRResult + * @property {OCRPage[]} pages - Array of pages extracted from the document. + * @property {string} model - The model used for OCR processing. 
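+ * (Illustrative shape, assuming a Mistral-style OCR response: { pages: [{ index: 0, markdown: '# Title…', images: [], dimensions: { dpi: 200, height: 2200, width: 1700 } }], model: 'mistral-ocr-latest', usage_info: { pages_processed: 1, doc_size_bytes: 12345 } })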
+ * @property {OCRUsageInfo} usage_info - Usage information for the OCR operation. + * @memberof typedefs + */ diff --git a/api/utils/axios.js b/api/utils/axios.js index 8b12a5ca99..2beff55e1f 100644 --- a/api/utils/axios.js +++ b/api/utils/axios.js @@ -5,41 +5,42 @@ const { logger } = require('~/config'); * * @param {Object} options - The options object. * @param {string} options.message - The custom message to be logged. - * @param {Error} options.error - The Axios error object. + * @param {import('axios').AxiosError} options.error - The Axios error object. + * @returns {string} The log message. */ const logAxiosError = ({ message, error }) => { - const timedOutMessage = 'Cannot read properties of undefined (reading \'status\')'; - if (error.response) { - logger.error( - `${message} The request was made and the server responded with a status code that falls out of the range of 2xx: ${ - error.message ? error.message : '' - }. Error response data:\n`, - { - headers: error.response?.headers, - status: error.response?.status, - data: error.response?.data, - }, - ); - } else if (error.request) { - logger.error( - `${message} The request was made but no response was received: ${ - error.message ? error.message : '' - }. Error Request:\n`, - { - request: error.request, - }, - ); - } else if (error?.message?.includes(timedOutMessage)) { - logger.error( - `${message}\nThe request either timed out or was unsuccessful. Error message:\n`, - error, - ); - } else { - logger.error( - `${message}\nSomething happened in setting up the request. Error message:\n`, - error, - ); + let logMessage = message; + try { + const stack = error.stack || 'No stack trace available'; + + if (error.response?.status) { + const { status, headers, data } = error.response; + logMessage = `${message} The server responded with status ${status}: ${error.message}`; + logger.error(logMessage, { + status, + headers, + data, + stack, + }); + } else if (error.request) { + const { method, url } = error.config || {}; + logMessage = `${message} No response received for ${method ? 
method.toUpperCase() : ''} ${url || ''}: ${error.message}`; + logger.error(logMessage, { + requestInfo: { method, url }, + stack, + }); + } else if (error?.message?.includes('Cannot read properties of undefined (reading \'status\')')) { + logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`; + logger.error(logMessage, { stack }); + } else { + logMessage = `${message} An error occurred while setting up the request: ${error.message}`; + logger.error(logMessage, { stack }); + } + } catch (err) { + logMessage = `Error in logAxiosError: ${err.message}`; + logger.error(logMessage, { stack: err.stack || 'No stack trace available' }); } + return logMessage; }; module.exports = { logAxiosError }; diff --git a/api/utils/debug.js b/api/utils/debug.js deleted file mode 100644 index 68599eea38..0000000000 --- a/api/utils/debug.js +++ /dev/null @@ -1,56 +0,0 @@ -const levels = { - NONE: 0, - LOW: 1, - MEDIUM: 2, - HIGH: 3, -}; - -let level = levels.HIGH; - -module.exports = { - levels, - setLevel: (l) => (level = l), - log: { - parameters: (parameters) => { - if (levels.HIGH > level) { - return; - } - console.group(); - parameters.forEach((p) => console.log(`${p.name}:`, p.value)); - console.groupEnd(); - }, - functionName: (name) => { - if (levels.MEDIUM > level) { - return; - } - console.log(`\nEXECUTING: ${name}\n`); - }, - flow: (flow) => { - if (levels.LOW > level) { - return; - } - console.log(`\n\n\nBEGIN FLOW: ${flow}\n\n\n`); - }, - variable: ({ name, value }) => { - if (levels.HIGH > level) { - return; - } - console.group(); - console.group(); - console.log(`VARIABLE ${name}:`, value); - console.groupEnd(); - console.groupEnd(); - }, - request: () => (req, res, next) => { - if (levels.HIGH > level) { - return next(); - } - console.log('Hit URL', req.url, 'with following:'); - console.group(); - console.log('Query:', req.query); - console.log('Body:', req.body); - console.groupEnd(); - return next(); - }, - }, -}; diff --git a/api/utils/tokens.js b/api/utils/tokens.js index 0541f4f301..7ff59acfdd 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -2,7 +2,9 @@ const z = require('zod'); const { EModelEndpoint } = require('librechat-data-provider'); const openAIModels = { + 'o4-mini': 200000, 'o3-mini': 195000, // -5000 from max + o3: 200000, o1: 195000, // -5000 from max 'o1-mini': 127500, // -500 from max 'o1-preview': 127500, // -500 from max @@ -13,6 +15,10 @@ const openAIModels = { 'gpt-4-32k-0613': 32758, // -10 from max 'gpt-4-1106': 127500, // -500 from max 'gpt-4-0125': 127500, // -500 from max + 'gpt-4.5': 127500, // -500 from max + 'gpt-4.1': 1047576, + 'gpt-4.1-mini': 1047576, + 'gpt-4.1-nano': 1047576, 'gpt-4o': 127500, // -500 from max 'gpt-4o-mini': 127500, // -500 from max 'gpt-4o-2024-05-13': 127500, // -500 from max @@ -33,8 +39,14 @@ const mistralModels = { 'mistral-7b': 31990, // -10 from max 'mistral-small': 31990, // -10 from max 'mixtral-8x7b': 31990, // -10 from max + 'mistral-large': 131000, 'mistral-large-2402': 127500, 'mistral-large-2407': 127500, + 'pixtral-large': 131000, + 'mistral-saba': 32000, + codestral: 256000, + 'ministral-8b': 131000, + 'ministral-3b': 131000, }; const cohereModels = { @@ -48,9 +60,16 @@ const cohereModels = { const googleModels = { /* Max I/O is combined so we subtract the amount from max response tokens for actual total */ + gemma: 8196, + 'gemma-2': 32768, + 'gemma-3': 32768, + 'gemma-3-27b': 131072, gemini: 30720, // -2048 from max 'gemini-pro-vision': 12288, 'gemini-exp': 2000000, + 
'gemini-2.5': 1000000, // 1M input tokens, 64k output tokens + 'gemini-2.5-pro': 1000000, + 'gemini-2.5-flash': 1000000, 'gemini-2.0': 2000000, 'gemini-2.0-flash': 1000000, 'gemini-2.0-flash-lite': 1000000, @@ -74,6 +93,7 @@ const anthropicModels = { 'claude-instant': 100000, 'claude-2': 100000, 'claude-2.1': 200000, + 'claude-3': 200000, 'claude-3-haiku': 200000, 'claude-3-sonnet': 200000, 'claude-3-opus': 200000, @@ -81,6 +101,8 @@ const anthropicModels = { 'claude-3-5-haiku': 200000, 'claude-3-5-sonnet': 200000, 'claude-3.5-sonnet': 200000, + 'claude-3-7-sonnet': 200000, + 'claude-3.7-sonnet': 200000, 'claude-3-5-sonnet-latest': 200000, 'claude-3.5-sonnet-latest': 200000, }; @@ -88,6 +110,7 @@ const anthropicModels = { const deepseekModels = { 'deepseek-reasoner': 63000, // -1000 from max (API) deepseek: 63000, // -1000 from max (API) + 'deepseek.r1': 127500, }; const metaModels = { @@ -183,7 +206,23 @@ const bedrockModels = { ...amazonModels, }; -const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels }; +const xAIModels = { + grok: 131072, + 'grok-beta': 131072, + 'grok-vision-beta': 8192, + 'grok-2': 131072, + 'grok-2-latest': 131072, + 'grok-2-1212': 131072, + 'grok-2-vision': 32768, + 'grok-2-vision-latest': 32768, + 'grok-2-vision-1212': 32768, + 'grok-3': 131072, + 'grok-3-fast': 131072, + 'grok-3-mini': 131072, + 'grok-3-mini-fast': 131072, +}; + +const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels, ...xAIModels }; const maxTokensMap = { [EModelEndpoint.azureOpenAI]: openAIModels, @@ -202,12 +241,15 @@ const modelMaxOutputs = { system_default: 1024, }; +/** Outputs from https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-names */ const anthropicMaxOutputs = { 'claude-3-haiku': 4096, 'claude-3-sonnet': 4096, 'claude-3-opus': 4096, 'claude-3.5-sonnet': 8192, 'claude-3-5-sonnet': 8192, + 'claude-3.7-sonnet': 128000, + 'claude-3-7-sonnet': 128000, }; const maxOutputTokensMap = { diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index eb1fd85495..57a9f72e89 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -103,6 +103,53 @@ describe('getModelMaxTokens', () => { ); }); + test('should return correct tokens for gpt-4.5 matches', () => { + expect(getModelMaxTokens('gpt-4.5')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4.5']); + expect(getModelMaxTokens('gpt-4.5-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'], + ); + expect(getModelMaxTokens('openai/gpt-4.5-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.5'], + ); + }); + + test('should return correct tokens for gpt-4.1 matches', () => { + expect(getModelMaxTokens('gpt-4.1')).toBe(maxTokensMap[EModelEndpoint.openAI]['gpt-4.1']); + expect(getModelMaxTokens('gpt-4.1-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1'], + ); + expect(getModelMaxTokens('openai/gpt-4.1')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1'], + ); + expect(getModelMaxTokens('gpt-4.1-2024-08-06')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1'], + ); + }); + + test('should return correct tokens for gpt-4.1-mini matches', () => { + expect(getModelMaxTokens('gpt-4.1-mini')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1-mini'], + ); + expect(getModelMaxTokens('gpt-4.1-mini-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1-mini'], + ); + expect(getModelMaxTokens('openai/gpt-4.1-mini')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1-mini'], + ); + }); + + test('should return correct 
tokens for gpt-4.1-nano matches', () => { + expect(getModelMaxTokens('gpt-4.1-nano')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1-nano'], + ); + expect(getModelMaxTokens('gpt-4.1-nano-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1-nano'], + ); + expect(getModelMaxTokens('openai/gpt-4.1-nano')).toBe( + maxTokensMap[EModelEndpoint.openAI]['gpt-4.1-nano'], + ); + }); + test('should return correct tokens for Anthropic models', () => { const models = [ 'claude-2.1', @@ -116,6 +163,7 @@ describe('getModelMaxTokens', () => { 'claude-3-sonnet', 'claude-3-opus', 'claude-3-5-sonnet', + 'claude-3-7-sonnet', ]; const maxTokens = { @@ -292,6 +340,15 @@ describe('getModelMaxTokens', () => { expect(getModelMaxTokens('o1-preview-something')).toBe(o1PreviewTokens); expect(getModelMaxTokens('openai/o1-preview-something')).toBe(o1PreviewTokens); }); + + test('should return correct max context tokens for o4-mini and o3', () => { + const o4MiniTokens = maxTokensMap[EModelEndpoint.openAI]['o4-mini']; + const o3Tokens = maxTokensMap[EModelEndpoint.openAI]['o3']; + expect(getModelMaxTokens('o4-mini')).toBe(o4MiniTokens); + expect(getModelMaxTokens('openai/o4-mini')).toBe(o4MiniTokens); + expect(getModelMaxTokens('o3')).toBe(o3Tokens); + expect(getModelMaxTokens('openai/o3')).toBe(o3Tokens); + }); }); describe('matchModelName', () => { @@ -344,6 +401,25 @@ describe('matchModelName', () => { expect(matchModelName('gpt-4-0125-vision-preview')).toBe('gpt-4-0125'); }); + it('should return the closest matching key for gpt-4.1 matches', () => { + expect(matchModelName('openai/gpt-4.1')).toBe('gpt-4.1'); + expect(matchModelName('gpt-4.1-preview')).toBe('gpt-4.1'); + expect(matchModelName('gpt-4.1-2024-08-06')).toBe('gpt-4.1'); + expect(matchModelName('gpt-4.1-2024-08-06-0718')).toBe('gpt-4.1'); + }); + + it('should return the closest matching key for gpt-4.1-mini matches', () => { + expect(matchModelName('openai/gpt-4.1-mini')).toBe('gpt-4.1-mini'); + expect(matchModelName('gpt-4.1-mini-preview')).toBe('gpt-4.1-mini'); + expect(matchModelName('gpt-4.1-mini-2024-08-06')).toBe('gpt-4.1-mini'); + }); + + it('should return the closest matching key for gpt-4.1-nano matches', () => { + expect(matchModelName('openai/gpt-4.1-nano')).toBe('gpt-4.1-nano'); + expect(matchModelName('gpt-4.1-nano-preview')).toBe('gpt-4.1-nano'); + expect(matchModelName('gpt-4.1-nano-2024-08-06')).toBe('gpt-4.1-nano'); + }); + // Tests for Google models it('should return the exact model name if it exists in maxTokensMap - Google models', () => { expect(matchModelName('text-bison-32k', EModelEndpoint.google)).toBe('text-bison-32k'); @@ -412,6 +488,9 @@ describe('Meta Models Tests', () => { expect(getModelMaxTokens('deepseek-reasoner')).toBe( maxTokensMap[EModelEndpoint.openAI]['deepseek-reasoner'], ); + expect(getModelMaxTokens('deepseek.r1')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek.r1'], + ); }); }); @@ -483,3 +562,90 @@ describe('Meta Models Tests', () => { }); }); }); + +describe('Grok Model Tests - Tokens', () => { + describe('getModelMaxTokens', () => { + test('should return correct tokens for Grok vision models', () => { + expect(getModelMaxTokens('grok-2-vision-1212')).toBe(32768); + expect(getModelMaxTokens('grok-2-vision')).toBe(32768); + expect(getModelMaxTokens('grok-2-vision-latest')).toBe(32768); + }); + + test('should return correct tokens for Grok beta models', () => { + expect(getModelMaxTokens('grok-vision-beta')).toBe(8192); + expect(getModelMaxTokens('grok-beta')).toBe(131072); + }); + + 
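+ // (The cases below mirror the vision/beta checks above: bare and 'xai/'-prefixed names are expected to resolve to the same context limits.)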
test('should return correct tokens for Grok text models', () => { + expect(getModelMaxTokens('grok-2-1212')).toBe(131072); + expect(getModelMaxTokens('grok-2')).toBe(131072); + expect(getModelMaxTokens('grok-2-latest')).toBe(131072); + }); + + test('should return correct tokens for Grok 3 series models', () => { + expect(getModelMaxTokens('grok-3')).toBe(131072); + expect(getModelMaxTokens('grok-3-fast')).toBe(131072); + expect(getModelMaxTokens('grok-3-mini')).toBe(131072); + expect(getModelMaxTokens('grok-3-mini-fast')).toBe(131072); + }); + + test('should handle partial matches for Grok models with prefixes', () => { + // Vision models should match before general models + expect(getModelMaxTokens('xai/grok-2-vision-1212')).toBe(32768); + expect(getModelMaxTokens('xai/grok-2-vision')).toBe(32768); + expect(getModelMaxTokens('xai/grok-2-vision-latest')).toBe(32768); + // Beta models + expect(getModelMaxTokens('xai/grok-vision-beta')).toBe(8192); + expect(getModelMaxTokens('xai/grok-beta')).toBe(131072); + // Text models + expect(getModelMaxTokens('xai/grok-2-1212')).toBe(131072); + expect(getModelMaxTokens('xai/grok-2')).toBe(131072); + expect(getModelMaxTokens('xai/grok-2-latest')).toBe(131072); + // Grok 3 models + expect(getModelMaxTokens('xai/grok-3')).toBe(131072); + expect(getModelMaxTokens('xai/grok-3-fast')).toBe(131072); + expect(getModelMaxTokens('xai/grok-3-mini')).toBe(131072); + expect(getModelMaxTokens('xai/grok-3-mini-fast')).toBe(131072); + }); + }); + + describe('matchModelName', () => { + test('should match exact Grok model names', () => { + // Vision models + expect(matchModelName('grok-2-vision-1212')).toBe('grok-2-vision-1212'); + expect(matchModelName('grok-2-vision')).toBe('grok-2-vision'); + expect(matchModelName('grok-2-vision-latest')).toBe('grok-2-vision-latest'); + // Beta models + expect(matchModelName('grok-vision-beta')).toBe('grok-vision-beta'); + expect(matchModelName('grok-beta')).toBe('grok-beta'); + // Text models + expect(matchModelName('grok-2-1212')).toBe('grok-2-1212'); + expect(matchModelName('grok-2')).toBe('grok-2'); + expect(matchModelName('grok-2-latest')).toBe('grok-2-latest'); + // Grok 3 models + expect(matchModelName('grok-3')).toBe('grok-3'); + expect(matchModelName('grok-3-fast')).toBe('grok-3-fast'); + expect(matchModelName('grok-3-mini')).toBe('grok-3-mini'); + expect(matchModelName('grok-3-mini-fast')).toBe('grok-3-mini-fast'); + }); + + test('should match Grok model variations with prefixes', () => { + // Vision models should match before general models + expect(matchModelName('xai/grok-2-vision-1212')).toBe('grok-2-vision-1212'); + expect(matchModelName('xai/grok-2-vision')).toBe('grok-2-vision'); + expect(matchModelName('xai/grok-2-vision-latest')).toBe('grok-2-vision-latest'); + // Beta models + expect(matchModelName('xai/grok-vision-beta')).toBe('grok-vision-beta'); + expect(matchModelName('xai/grok-beta')).toBe('grok-beta'); + // Text models + expect(matchModelName('xai/grok-2-1212')).toBe('grok-2-1212'); + expect(matchModelName('xai/grok-2')).toBe('grok-2'); + expect(matchModelName('xai/grok-2-latest')).toBe('grok-2-latest'); + // Grok 3 models + expect(matchModelName('xai/grok-3')).toBe('grok-3'); + expect(matchModelName('xai/grok-3-fast')).toBe('grok-3-fast'); + expect(matchModelName('xai/grok-3-mini')).toBe('grok-3-mini'); + expect(matchModelName('xai/grok-3-mini-fast')).toBe('grok-3-mini-fast'); + }); + }); +}); diff --git a/bun.lockb b/bun.lockb index e85113bbce..61118178fd 100755 Binary files a/bun.lockb and b/bun.lockb 
diff --git a/bun.lockb b/bun.lockb
index e85113bbce..61118178fd 100755
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/client/index.html b/client/index.html
index 9bd0363fab..9e300e7365 100644
--- a/client/index.html
+++ b/client/index.html
@@ -6,6 +6,7 @@
+
     LibreChat
@@ -53,6 +54,5 @@
-
diff --git a/client/package.json b/client/package.json
index 993cf30071..5fd9729a74 100644
--- a/client/package.json
+++ b/client/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@librechat/frontend",
-  "version": "v0.7.7-rc1",
+  "version": "v0.7.8",
   "description": "",
   "type": "module",
   "scripts": {
@@ -28,10 +28,11 @@
   },
   "homepage": "https://librechat.ai",
   "dependencies": {
-    "@ariakit/react": "^0.4.11",
+    "@ariakit/react": "^0.4.15",
+    "@ariakit/react-core": "^0.4.17",
     "@codesandbox/sandpack-react": "^2.19.10",
-    "@dicebear/collection": "^7.0.4",
-    "@dicebear/core": "^7.0.4",
+    "@dicebear/collection": "^9.2.2",
+    "@dicebear/core": "^9.2.2",
     "@headlessui/react": "^2.1.2",
     "@radix-ui/react-accordion": "^1.1.2",
     "@radix-ui/react-alert-dialog": "^1.0.2",
@@ -43,6 +44,7 @@
     "@radix-ui/react-icons": "^1.3.0",
     "@radix-ui/react-label": "^2.0.0",
     "@radix-ui/react-popover": "^1.0.7",
+    "@radix-ui/react-progress": "^1.1.2",
     "@radix-ui/react-radio-group": "^1.1.3",
     "@radix-ui/react-select": "^2.0.0",
     "@radix-ui/react-separator": "^1.0.3",
@@ -50,6 +52,7 @@
     "@radix-ui/react-switch": "^1.0.3",
     "@radix-ui/react-tabs": "^1.0.3",
     "@radix-ui/react-toast": "^1.1.5",
+    "@react-spring/web": "^9.7.5",
     "@tanstack/react-query": "^4.28.0",
     "@tanstack/react-table": "^8.11.7",
     "class-variance-authority": "^0.6.0",
@@ -63,12 +66,13 @@
     "framer-motion": "^11.5.4",
     "html-to-image": "^1.11.11",
     "i18next": "^24.2.2",
+    "i18next-browser-languagedetector": "^8.0.3",
+    "input-otp": "^1.4.2",
     "js-cookie": "^3.0.5",
     "librechat-data-provider": "*",
     "lodash": "^4.17.21",
     "lucide-react": "^0.394.0",
     "match-sorter": "^6.3.4",
-    "msedge-tts": "^1.3.4",
     "qrcode.react": "^4.2.0",
     "rc-input-number": "^7.4.2",
     "react": "^18.2.0",
@@ -82,7 +86,7 @@
     "react-i18next": "^15.4.0",
     "react-lazy-load-image-component": "^1.6.0",
     "react-markdown": "^9.0.1",
-    "react-resizable-panels": "^2.1.1",
+    "react-resizable-panels": "^2.1.8",
     "react-router-dom": "^6.11.2",
     "react-speech-recognition": "^3.10.0",
     "react-textarea-autosize": "^8.4.0",
@@ -114,10 +118,11 @@
     "@testing-library/user-event": "^14.4.3",
     "@types/jest": "^29.5.14",
     "@types/js-cookie": "^3.0.6",
+    "@types/lodash": "^4.17.15",
     "@types/node": "^20.3.0",
     "@types/react": "^18.2.11",
     "@types/react-dom": "^18.2.4",
-    "@vitejs/plugin-react": "^4.2.1",
+    "@vitejs/plugin-react": "^4.3.4",
     "autoprefixer": "^10.4.20",
     "babel-plugin-replace-ts-export-assignment": "^0.0.2",
     "babel-plugin-root-import": "^6.6.0",
@@ -136,8 +141,9 @@
     "tailwindcss": "^3.4.1",
     "ts-jest": "^29.2.5",
     "typescript": "^5.3.3",
-    "vite": "^6.1.0",
-    "vite-plugin-node-polyfills": "^0.17.0",
-    "vite-plugin-pwa": "^0.21.1"
+    "vite": "^6.3.4",
+    "vite-plugin-compression2": "^1.3.3",
+    "vite-plugin-node-polyfills": "^0.23.0",
+    "vite-plugin-pwa": "^0.21.2"
   }
 }
diff --git a/client/public/assets/apple-touch-icon-180x180.png b/client/public/assets/apple-touch-icon-180x180.png
index 91dde5d139..57c4637c93 100644
Binary files a/client/public/assets/apple-touch-icon-180x180.png and b/client/public/assets/apple-touch-icon-180x180.png differ
diff --git a/client/public/assets/icon-192x192.png b/client/public/assets/icon-192x192.png
new file mode 100644
index 0000000000..b8dfe0eae5
Binary files /dev/null and b/client/public/assets/icon-192x192.png differ
diff --git a/client/public/assets/image_gen_oai.png b/client/public/assets/image_gen_oai.png
new file mode 100644
index 0000000000..e1762e7091
Binary files /dev/null and b/client/public/assets/image_gen_oai.png differ
diff --git a/client/public/assets/maskable-icon.png b/client/public/assets/maskable-icon.png
index 18305a2446..90e48f870b 100644
Binary files a/client/public/assets/maskable-icon.png and b/client/public/assets/maskable-icon.png differ
diff --git a/client/public/assets/xai.svg b/client/public/assets/xai.svg
deleted file mode 100644
index 2aca45ed4f..0000000000
--- a/client/public/assets/xai.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/client/public/robots.txt b/client/public/robots.txt
new file mode 100644
index 0000000000..376d535e0a
--- /dev/null
+++ b/client/public/robots.txt
@@ -0,0 +1,3 @@
+User-agent: *
+Disallow: /api/
+Allow: /
\ No newline at end of file
diff --git a/client/src/@types/i18next.d.ts b/client/src/@types/i18next.d.ts
new file mode 100644
index 0000000000..2d50f5a3cd
--- /dev/null
+++ b/client/src/@types/i18next.d.ts
@@ -0,0 +1,9 @@
+import { defaultNS, resources } from '~/locales/i18n';
+
+declare module 'i18next' {
+  interface CustomTypeOptions {
+    defaultNS: typeof defaultNS;
+    resources: typeof resources.en;
+    strictKeyChecks: true
+  }
+}
\ No newline at end of file
diff --git a/client/src/Providers/SearchContext.tsx b/client/src/Providers/SearchContext.tsx
deleted file mode 100644
index 678818aa18..0000000000
--- a/client/src/Providers/SearchContext.tsx
+++ /dev/null
@@ -1,6 +0,0 @@
-import { createContext, useContext } from 'react';
-import useSearch from '~/hooks/Conversations/useSearch';
-type SearchContextType = ReturnType;
-
-export const SearchContext = createContext({} as SearchContextType);
-export const useSearchContext = () => useContext(SearchContext);
diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts
index 7363c97d41..43da0d346b 100644
--- a/client/src/Providers/index.ts
+++ b/client/src/Providers/index.ts
@@ -4,7 +4,6 @@ export { default as AgentsProvider } from './AgentsContext';
 export * from './ChatContext';
 export * from './ShareContext';
 export * from './ToastContext';
-export * from './SearchContext';
 export * from './FileMapContext';
 export * from './AddedChatContext';
 export * from './EditorContext';
diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts
index a9c24106bc..982cbfdb17 100644
--- a/client/src/common/agents-types.ts
+++ b/client/src/common/agents-types.ts
@@ -5,6 +5,7 @@ import type { OptionWithIcon, ExtendedFile } from './types';
 export type TAgentOption = OptionWithIcon &
   Agent & {
     knowledge_files?: Array<[string, ExtendedFile]>;
+    context_files?: Array<[string, ExtendedFile]>;
     code_files?: Array<[string, ExtendedFile]>;
   };
@@ -27,4 +28,5 @@ export type AgentForm = {
   provider?: AgentProvider | OptionWithIcon;
   agent_ids?: string[];
   [AgentCapabilities.artifacts]?: ArtifactModes | string;
+  recursion_limit?: number;
 } & TAgentCapabilities;
diff --git a/client/src/common/index.ts b/client/src/common/index.ts
index 3452818fce..e1a3ab0a05 100644
--- a/client/src/common/index.ts
+++ b/client/src/common/index.ts
@@ -3,5 +3,6 @@ export * from './artifacts';
 export * from './types';
 export * from './menus';
 export * from './tools';
+export * from './selector';
 export * from './assistants-types';
 export * from './agents-types';
diff --git a/client/src/common/selector.ts b/client/src/common/selector.ts
new file mode 100644
index 0000000000..619d8e8f80
--- /dev/null
+++ b/client/src/common/selector.ts
@@ -0,0 +1,23 @@
+import React from 'react';
+import { TModelSpec, TStartupConfig } from 'librechat-data-provider';
+
+export interface Endpoint {
+  value: string;
+  label: string;
+  hasModels: boolean;
+  models?: Array<{ name: string; isGlobal?: boolean }>;
+  icon: React.ReactNode;
+  agentNames?: Record;
+  assistantNames?: Record;
+  modelIcons?: Record;
+}
+
+export interface SelectedValues {
+  endpoint: string | null;
+  model: string | null;
+  modelSpec: string | null;
+}
+
+export interface ModelSelectorProps {
+  startupConfig: TStartupConfig | undefined;
+}
diff --git a/client/src/common/types.ts b/client/src/common/types.ts
index 3d61eccb1c..cd8b45f6b7 100644
--- a/client/src/common/types.ts
+++ b/client/src/common/types.ts
@@ -1,10 +1,10 @@
 import { RefObject } from 'react';
-import { FileSources } from 'librechat-data-provider';
-import type * as InputNumberPrimitive from 'rc-input-number';
-import type { ColumnDef } from '@tanstack/react-table';
-import type { SetterOrUpdater } from 'recoil';
-import type * as t from 'librechat-data-provider';
+import { FileSources, EModelEndpoint } from 'librechat-data-provider';
 import type { UseMutationResult } from '@tanstack/react-query';
+import type * as InputNumberPrimitive from 'rc-input-number';
+import type { SetterOrUpdater, RecoilState } from 'recoil';
+import type { ColumnDef } from '@tanstack/react-table';
+import type * as t from 'librechat-data-provider';
 import type { LucideIcon } from 'lucide-react';
 import type { TranslationKeys } from '~/hooks';
@@ -29,7 +29,6 @@ export enum STTEndpoints {
 export enum TTSEndpoints {
   browser = 'browser',
-  edge = 'edge',
   external = 'external',
 }
@@ -48,6 +47,14 @@ export type AudioChunk = {
   };
 };

+export type BadgeItem = {
+  id: string;
+  icon: React.ComponentType;
+  label: string;
+  atom: RecoilState;
+  isAvailable: boolean;
+};
+
 export type AssistantListItem = {
   id: string;
   name: string;
@@ -106,7 +113,7 @@ export type IconsRecord = {
 export type AgentIconMapProps = IconMapProps & { agentName?: string };

 export type NavLink = {
-  title: string;
+  title: TranslationKeys;
   label?: string;
   icon: LucideIcon | React.FC;
   Component?: React.ComponentType;
@@ -131,6 +138,7 @@ export interface DataColumnMeta {
 }

 export enum Panel {
+  advanced = 'advanced',
   builder = 'builder',
   actions = 'actions',
   model = 'model',
@@ -181,6 +189,7 @@ export type AgentPanelProps = {
   activePanel?: string;
   action?: t.Action;
   actions?: t.Action[];
+  createMutation: UseMutationResult;
   setActivePanel: React.Dispatch>;
   setAction: React.Dispatch>;
   endpointsConfig?: t.TEndpointsConfig;
@@ -297,11 +306,14 @@ export type TAskProps = {
 export type TOptions = {
   editedMessageId?: string | null;
   editedText?: string | null;
-  resubmitFiles?: boolean;
   isRegenerate?: boolean;
   isContinued?: boolean;
   isEdited?: boolean;
   overrideMessages?: t.TMessage[];
+  /** This value is only true when the user submits a message with "Save & Submit" for a user-created message */
+  isResubmission?: boolean;
+  /** Currently only utilized when `isResubmission === true`, uses that message's currently attached files */
+  overrideFiles?: t.TMessage['files'];
 };

 export type TAskFunction = (props: TAskProps, options?: TOptions) => void;
@@ -370,12 +382,12 @@ export type TDangerButtonProps = {
   showText?: boolean;
   mutation?: UseMutationResult;
   onClick: () => void;
-  infoTextCode: string;
-  actionTextCode: string;
+  infoTextCode: TranslationKeys;
+  actionTextCode: TranslationKeys;
   dataTestIdInitial: string;
   dataTestIdConfirm: string;
-  infoDescriptionCode?: string;
-  confirmActionTextCode?: string;
+  infoDescriptionCode?: TranslationKeys;
+  confirmActionTextCode?: TranslationKeys;
 };

 export type TDialogProps = {
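The common/types.ts changes above tighten several plain string fields to TranslationKeys, which pairs with the strictKeyChecks: true i18next augmentation added earlier in this diff: an unknown locale key becomes a compile-time error instead of an untranslated string at runtime. A sketch of the effect; the key names in this union are invented examples, not the real set generated from the en resources:

// Hypothetical three-key union; the real one is derived from resources.en.
type ExampleTranslationKeys = 'com_ui_code' | 'com_auth_continue' | 'com_ui_redirecting_to_provider';

type ExampleNavLink = {
  title: ExampleTranslationKeys; // previously `string`
  label?: string;
};

const ok: ExampleNavLink = { title: 'com_auth_continue' };
// const bad: ExampleNavLink = { title: 'com_auth_continue_typo' };
// ^ now rejected by tsc instead of rendering a raw key in the UI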
@@ -399,7 +411,7 @@ export type TAuthContext = {
   isAuthenticated: boolean;
   error: string | undefined;
   login: (data: t.TLoginUser) => void;
-  logout: () => void;
+  logout: (redirect?: string) => void;
   setError: React.Dispatch>;
   roles?: Record;
 };
@@ -483,9 +495,23 @@ export interface ExtendedFile {
   attached?: boolean;
   embedded?: boolean;
   tool_resource?: string;
+  metadata?: t.TFile['metadata'];
 }

-export type ContextType = { navVisible: boolean; setNavVisible: (visible: boolean) => void };
+export interface ModelItemProps {
+  modelName: string;
+  endpoint: EModelEndpoint;
+  isSelected: boolean;
+  onSelect: () => void;
+  onNavigateBack: () => void;
+  icon?: JSX.Element;
+  className?: string;
+}
+
+export type ContextType = {
+  navVisible: boolean;
+  setNavVisible: React.Dispatch>;
+};

 export interface SwitcherProps {
   endpoint?: t.EModelEndpoint | null;
@@ -528,7 +554,8 @@ export type TResData = TBaseResData & {
   responseMessage: t.TMessage;
 };

-export type TFinalResData = TBaseResData & {
+export type TFinalResData = Omit & {
+  conversation: Partial & Pick;
   requestMessage?: t.TMessage;
   responseMessage?: t.TMessage;
 };
diff --git a/client/src/components/Artifacts/Artifact.tsx b/client/src/components/Artifacts/Artifact.tsx
index 5081d9cc59..2b06a2ccc0 100644
--- a/client/src/components/Artifacts/Artifact.tsx
+++ b/client/src/components/Artifacts/Artifact.tsx
@@ -2,6 +2,7 @@ import React, { useEffect, useCallback, useRef, useState } from 'react';
 import throttle from 'lodash/throttle';
 import { visit } from 'unist-util-visit';
 import { useSetRecoilState } from 'recoil';
+import { useLocation } from 'react-router-dom';
 import type { Pluggable } from 'unified';
 import type { Artifact } from '~/common';
 import { useMessageContext, useArtifactContext } from '~/Providers';
@@ -11,7 +12,16 @@ import ArtifactButton from './ArtifactButton';

 export const artifactPlugin: Pluggable = () => {
   return (tree) => {
-    visit(tree, ['textDirective', 'leafDirective', 'containerDirective'], (node) => {
+    visit(tree, ['textDirective', 'leafDirective', 'containerDirective'], (node, index, parent) => {
+      if (node.type === 'textDirective') {
+        const replacementText = `:${node.name}`;
+        if (parent && Array.isArray(parent.children) && typeof index === 'number') {
+          parent.children[index] = {
+            type: 'text',
+            value: replacementText,
+          };
+        }
+      }
       if (node.name !== 'artifact') {
         return;
       }
@@ -25,14 +35,18 @@
   };
 };

+const defaultTitle = 'untitled';
+const defaultType = 'unknown';
+const defaultIdentifier = 'lc-no-identifier';
+
 export function Artifact({
-  // eslint-disable-next-line @typescript-eslint/no-unused-vars
   node,
   ...props
 }: Artifact & {
   children: React.ReactNode | { props: { children: React.ReactNode } };
   node: unknown;
 }) {
+  const location = useLocation();
   const { messageId } = useMessageContext();
   const { getNextIndex, resetCounter } = useArtifactContext();
   const artifactIndex = useRef(getNextIndex(false)).current;
@@ -50,15 +64,18 @@
     const content = extractContent(props.children);
     logger.log('artifacts', 'updateArtifact: content.length', content.length);

-    const title = props.title ?? 'Untitled Artifact';
-    const type = props.type ?? 'unknown';
-    const identifier = props.identifier ?? 'no-identifier';
+    const title = props.title ?? defaultTitle;
+    const type = props.type ?? defaultType;
+    const identifier = props.identifier ?? defaultIdentifier;
     const artifactKey = `${identifier}_${type}_${title}_${messageId}`
       .replace(/\s+/g, '_')
       .toLowerCase();

     throttledUpdateRef.current(() => {
       const now = Date.now();
+      if (artifactKey === `${defaultIdentifier}_${defaultType}_${defaultTitle}_${messageId}`) {
+        return;
+      }

       const currentArtifact: Artifact = {
         id: artifactKey,
@@ -71,6 +88,10 @@
         lastUpdateTime: now,
       };

+      if (!location.pathname.includes('/c/')) {
+        return setArtifact(currentArtifact);
+      }
+
       setArtifacts((prevArtifacts) => {
         if (
           prevArtifacts?.[artifactKey] != null &&
@@ -95,6 +116,7 @@
     props.identifier,
     messageId,
     artifactIndex,
+    location.pathname,
   ]);

   useEffect(() => {
diff --git a/client/src/components/Artifacts/ArtifactButton.tsx b/client/src/components/Artifacts/ArtifactButton.tsx
index d8fa557700..162e7d717c 100644
--- a/client/src/components/Artifacts/ArtifactButton.tsx
+++ b/client/src/components/Artifacts/ArtifactButton.tsx
@@ -1,14 +1,52 @@
-import { useSetRecoilState } from 'recoil';
+import { useEffect, useRef } from 'react';
+import debounce from 'lodash/debounce';
+import { useLocation } from 'react-router-dom';
+import { useRecoilState, useSetRecoilState, useResetRecoilState } from 'recoil';
 import type { Artifact } from '~/common';
 import FilePreview from '~/components/Chat/Input/Files/FilePreview';
+import { getFileType, logger } from '~/utils';
 import { useLocalize } from '~/hooks';
-import { getFileType } from '~/utils';
 import store from '~/store';

 const ArtifactButton = ({ artifact }: { artifact: Artifact | null }) => {
   const localize = useLocalize();
-  const setVisible = useSetRecoilState(store.artifactsVisible);
-  const setArtifactId = useSetRecoilState(store.currentArtifactId);
+  const location = useLocation();
+  const setVisible = useSetRecoilState(store.artifactsVisibility);
+  const [artifacts, setArtifacts] = useRecoilState(store.artifactsState);
+  const setCurrentArtifactId = useSetRecoilState(store.currentArtifactId);
+  const resetCurrentArtifactId = useResetRecoilState(store.currentArtifactId);
+  const [visibleArtifacts, setVisibleArtifacts] = useRecoilState(store.visibleArtifacts);
+
+  const debouncedSetVisibleRef = useRef(
+    debounce((artifactToSet: Artifact) => {
+      logger.log(
+        'artifacts_visibility',
+        'Setting artifact to visible state from Artifact button',
+        artifactToSet,
+      );
+      setVisibleArtifacts((prev) => ({
+        ...prev,
+        [artifactToSet.id]: artifactToSet,
+      }));
+    }, 750),
+  );
+
+  useEffect(() => {
+    if (artifact == null || artifact?.id == null || artifact.id === '') {
+      return;
+    }
+
+    if (!location.pathname.includes('/c/')) {
+      return;
+    }
+
+    const debouncedSetVisible = debouncedSetVisibleRef.current;
+    debouncedSetVisible(artifact);
+    return () => {
+      debouncedSetVisible.cancel();
+    };
+  }, [artifact, location.pathname]);
+
   if (artifact === null || artifact === undefined) {
     return null;
   }
@@ -19,12 +57,21 @@ const ArtifactButton = ({ artifact }: { artifact: Artifact | null }) => {
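ArtifactButton above keeps a single debounced callback in a ref for the component's lifetime and cancels it in the effect cleanup, so rapid re-renders during streaming collapse into one visibility update after 750 ms of quiet. The same pattern in isolation, as a generic sketch rather than LibreChat code:

import { useEffect, useRef } from 'react';
import debounce from 'lodash/debounce';

export function useDebouncedNotify<T>(value: T | null, notify: (v: T) => void, waitMs = 750) {
  // One debounced instance per mount; recreating it each render would defeat debouncing.
  const debouncedRef = useRef(debounce(notify, waitMs));

  useEffect(() => {
    if (value == null) {
      return;
    }
    const debounced = debouncedRef.current;
    debounced(value);
    // Cancel any pending call when the value changes again or on unmount.
    return () => {
      debounced.cancel();
    };
  }, [value]);
}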

{currentArtifact.title}

@@ -118,22 +107,8 @@ export default function Artifacts() {
           {localize('com_ui_code')}
-
@@ -149,29 +124,13 @@
{`${currentIndex + 1} / ${orderedArtifactIds.length}`}
diff --git a/client/src/components/Artifacts/Code.tsx b/client/src/components/Artifacts/Code.tsx
index de92c4c0da..21db2055d7 100644
--- a/client/src/components/Artifacts/Code.tsx
+++ b/client/src/components/Artifacts/Code.tsx
@@ -35,7 +35,7 @@ export const CodeMarkdown = memo(
     const [userScrolled, setUserScrolled] = useState(false);
     const currentContent = content;
     const rehypePlugins = [
-      [rehypeKatex, { output: 'mathml' }],
+      [rehypeKatex],
       [
         rehypeHighlight,
         {
diff --git a/client/src/components/Audio/TTS.tsx b/client/src/components/Audio/TTS.tsx
index 14c6346b0f..3ceacb7f8d 100644
--- a/client/src/components/Audio/TTS.tsx
+++ b/client/src/components/Audio/TTS.tsx
@@ -2,9 +2,8 @@ import { useEffect, useMemo } from 'react';
 import { useRecoilValue } from 'recoil';

 import type { TMessageAudio } from '~/common';
-import { useLocalize, useTTSBrowser, useTTSEdge, useTTSExternal } from '~/hooks';
-import { VolumeIcon, VolumeMuteIcon, Spinner } from '~/components/svg';
-import { useToastContext } from '~/Providers/ToastContext';
+import { useLocalize, useTTSBrowser, useTTSExternal } from '~/hooks';
+import { VolumeIcon, VolumeMuteIcon, Spinner } from '~/components';
 import { logger } from '~/utils';
 import store from '~/store';
@@ -85,97 +84,6 @@ export function BrowserTTS({ isLast, index, messageId, content, className }: TMessageAudio) {
   );
 }

-export function EdgeTTS({ isLast, index, messageId, content, className }: TMessageAudio) {
-  const localize = useLocalize();
-  const playbackRate = useRecoilValue(store.playbackRate);
-  const isBrowserSupported = useMemo(
-    () => typeof MediaSource !== 'undefined' && MediaSource.isTypeSupported('audio/mpeg'),
-    [],
-  );
-
-  const { showToast } = useToastContext();
-  const { toggleSpeech, isSpeaking, isLoading, audioRef } = useTTSEdge({
-    isLast,
-    index,
-    messageId,
-    content,
-  });
-
-  const renderIcon = (size: string) => {
-    if (isLoading === true) {
-      return ;
-    }
-
-    if (isSpeaking === true) {
-      return ;
-    }
-
-    return ;
-  };
-
-  useEffect(() => {
-    const messageAudio = document.getElementById(`audio-${messageId}`) as HTMLAudioElement | null;
-    if (!messageAudio) {
-      return;
-    }
-    if (playbackRate != null && playbackRate > 0 && messageAudio.playbackRate !== playbackRate) {
-      messageAudio.playbackRate = playbackRate;
-    }
-  }, [audioRef, isSpeaking, playbackRate, messageId]);
-
-  logger.log(
-    'MessageAudio: audioRef.current?.src, audioRef.current',
-    audioRef.current?.src,
-    audioRef.current,
-  );
-
-  return (
-    <>
-
-      {isBrowserSupported ? (
-
   );
@@ -85,6 +58,7 @@ export function ExternalVoiceDropdown() {
         onChange={handleVoiceChange}
         sizeClasses="min-w-[200px] !max-w-[400px] [--anchor-max-width:400px]"
         testId="ExternalVoiceDropdown"
+        className="z-50"
       />
   );
diff --git a/client/src/components/Auth/AuthLayout.tsx b/client/src/components/Auth/AuthLayout.tsx
index a7e890517a..d90f0d3dfe 100644
--- a/client/src/components/Auth/AuthLayout.tsx
+++ b/client/src/components/Auth/AuthLayout.tsx
@@ -85,7 +85,8 @@
       )}
       {children}
-      {(pathname.includes('login') || pathname.includes('register')) && (
+      {!pathname.includes('2fa') &&
+        (pathname.includes('login') || pathname.includes('register')) && (
       )}
diff --git a/client/src/components/Auth/Login.tsx b/client/src/components/Auth/Login.tsx
index a332553701..48cbfe1a91 100644
--- a/client/src/components/Auth/Login.tsx
+++ b/client/src/components/Auth/Login.tsx
@@ -1,16 +1,78 @@
-import { useOutletContext } from 'react-router-dom';
+import { useOutletContext, useSearchParams } from 'react-router-dom';
+import { useEffect, useState } from 'react';
 import { useAuthContext } from '~/hooks/AuthContext';
 import type { TLoginLayoutContext } from '~/common';
 import { ErrorMessage } from '~/components/Auth/ErrorMessage';
 import { getLoginError } from '~/utils';
 import { useLocalize } from '~/hooks';
 import LoginForm from './LoginForm';
+import SocialButton from '~/components/Auth/SocialButton';
+import { OpenIDIcon } from '~/components';

 function Login() {
   const localize = useLocalize();
   const { error, setError, login } = useAuthContext();
   const { startupConfig } = useOutletContext();
+  const [searchParams, setSearchParams] = useSearchParams();
+  // Determine if auto-redirect should be disabled based on the URL parameter
+  const disableAutoRedirect = searchParams.get('redirect') === 'false';
+
+  // Persist the disable flag locally so that once detected, auto-redirect stays disabled.
+  const [isAutoRedirectDisabled, setIsAutoRedirectDisabled] = useState(disableAutoRedirect);
+
+  // Once the disable flag is detected, update local state and remove the parameter from the URL.
+  useEffect(() => {
+    if (disableAutoRedirect) {
+      setIsAutoRedirectDisabled(true);
+      const newParams = new URLSearchParams(searchParams);
+      newParams.delete('redirect');
+      setSearchParams(newParams, { replace: true });
+    }
+  }, [disableAutoRedirect, searchParams, setSearchParams]);
+
+  // Determine whether we should auto-redirect to OpenID.
+  const shouldAutoRedirect =
+    startupConfig?.openidLoginEnabled &&
+    startupConfig?.openidAutoRedirect &&
+    startupConfig?.serverDomain &&
+    !isAutoRedirectDisabled;
+
+  useEffect(() => {
+    if (shouldAutoRedirect) {
+      console.log('Auto-redirecting to OpenID provider...');
+      window.location.href = `${startupConfig.serverDomain}/oauth/openid`;
+    }
+  }, [shouldAutoRedirect, startupConfig]);
+
+  // Render fallback UI if auto-redirect is active.
+  if (shouldAutoRedirect) {
+    return (
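The markup for the fallback screen was lost in this diff, but the control flow above is complete: appending ?redirect=false to the login URL suppresses the OpenID auto-redirect for that mount, and the component immediately strips the parameter from the address bar. A sketch of the observable behavior; the domain is a placeholder:

// Illustrative only; mirrors the query-param handling in Login.tsx above.
const url = new URL('https://chat.example.com/login?redirect=false');

const disableAutoRedirect = url.searchParams.get('redirect') === 'false';
if (disableAutoRedirect) {
  // The flag moves into component state, then the parameter is removed with
  // { replace: true } so it never lingers in history; a hard reload therefore
  // re-enables auto-redirect, since the parameter is gone.
  url.searchParams.delete('redirect');
}
console.log(disableAutoRedirect, url.pathname + url.search); // -> true "/login"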
+

+ {localize('com_ui_redirecting_to_provider', { 0: startupConfig.openidLabel })} +

+
+ + startupConfig.openidImageUrl ? ( + OpenID Logo + ) : ( + + ) + } + label={startupConfig.openidLabel} + id="openid" + /> +
+
+    );
+  }
+
   return (
     <>
       {error != null && <ErrorMessage>{localize(getLoginError(error))}</ErrorMessage>}
diff --git a/client/src/components/Auth/LoginForm.tsx b/client/src/components/Auth/LoginForm.tsx
index 9f5bb46039..2cd62d08b9 100644
--- a/client/src/components/Auth/LoginForm.tsx
+++ b/client/src/components/Auth/LoginForm.tsx
@@ -166,9 +166,7 @@ const LoginForm: React.FC = ({ onSubmit, startupConfig, error,
       type="submit"
       className="
         w-full rounded-2xl bg-green-600 px-4 py-3 text-sm font-medium text-white
-        transition-colors hover:bg-green-700 focus:outline-none focus:ring-2
-        focus:ring-green-500 focus:ring-offset-2 disabled:opacity-50
-        disabled:hover:bg-green-600 dark:bg-green-600 dark:hover:bg-green-700
+        transition-colors hover:bg-green-700 dark:bg-green-600 dark:hover:bg-green-700
       "
     >
       {localize('com_auth_continue')}
diff --git a/client/src/components/Auth/TwoFactorScreen.tsx b/client/src/components/Auth/TwoFactorScreen.tsx
new file mode 100644
index 0000000000..04f89d7cea
--- /dev/null
+++ b/client/src/components/Auth/TwoFactorScreen.tsx
@@ -0,0 +1,176 @@
+import React, { useState, useCallback } from 'react';
+import { useSearchParams } from 'react-router-dom';
+import { useForm, Controller } from 'react-hook-form';
+import { REGEXP_ONLY_DIGITS, REGEXP_ONLY_DIGITS_AND_CHARS } from 'input-otp';
+import { InputOTP, InputOTPGroup, InputOTPSeparator, InputOTPSlot, Label } from '~/components';
+import { useVerifyTwoFactorTempMutation } from '~/data-provider';
+import { useToastContext } from '~/Providers';
+import { useLocalize } from '~/hooks';
+
+interface VerifyPayload {
+  tempToken: string;
+  token?: string;
+  backupCode?: string;
+}
+
+type TwoFactorFormInputs = {
+  token?: string;
+  backupCode?: string;
+};
+
+const TwoFactorScreen: React.FC = React.memo(() => {
+  const [searchParams] = useSearchParams();
+  const tempTokenRaw = searchParams.get('tempToken');
+  const tempToken = tempTokenRaw !== null && tempTokenRaw !== '' ? tempTokenRaw : '';
+
+  const {
+    control,
+    handleSubmit,
+    formState: { errors },
+  } = useForm();
+  const localize = useLocalize();
+  const { showToast } = useToastContext();
+  const [useBackup, setUseBackup] = useState(false);
+  const [isLoading, setIsLoading] = useState(false);
+  const { mutate: verifyTempMutate } = useVerifyTwoFactorTempMutation({
+    onSuccess: (result) => {
+      if (result.token != null && result.token !== '') {
+        window.location.href = '/';
+      }
+    },
+    onMutate: () => {
+      setIsLoading(true);
+    },
+    onError: (error: unknown) => {
+      setIsLoading(false);
+      const err = error as { response?: { data?: { message?: unknown } } };
+      const errorMsg =
+        typeof err.response?.data?.message === 'string'
+          ? err.response.data.message
+          : 'Error verifying 2FA';
+      showToast({ message: errorMsg, status: 'error' });
+    },
+  });
+
+  const onSubmit = useCallback(
+    (data: TwoFactorFormInputs) => {
+      const payload: VerifyPayload = { tempToken };
+      if (useBackup && data.backupCode != null && data.backupCode !== '') {
+        payload.backupCode = data.backupCode;
+      } else if (data.token != null && data.token !== '') {
+        payload.token = data.token;
+      }
+      verifyTempMutate(payload);
+    },
+    [tempToken, useBackup, verifyTempMutate],
+  );
+
+  const toggleBackupOn = useCallback(() => {
+    setUseBackup(true);
+  }, []);
+
+  const toggleBackupOff = useCallback(() => {
+    setUseBackup(false);
+  }, []);
+
+  return (
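TwoFactorScreen's render tree is mangled above, but its submission logic survives intact: exactly one credential is sent per attempt, either a digits-only TOTP code (REGEXP_ONLY_DIGITS) or an alphanumeric backup code (REGEXP_ONLY_DIGITS_AND_CHARS), alongside the short-lived tempToken carried over from the login redirect. A hedged restatement of that branching; the helper name is invented:

interface VerifyPayload {
  tempToken: string;
  token?: string; // TOTP code from the authenticator app
  backupCode?: string; // one-time recovery code
}

// Hypothetical helper mirroring onSubmit above.
function buildVerifyPayload(
  tempToken: string,
  useBackup: boolean,
  input: { token?: string; backupCode?: string },
): VerifyPayload {
  const payload: VerifyPayload = { tempToken };
  if (useBackup && input.backupCode != null && input.backupCode !== '') {
    payload.backupCode = input.backupCode; // backup path wins when toggled on
  } else if (input.token != null && input.token !== '') {
    payload.token = input.token;
  }
  return payload;
}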
+
+ + {!useBackup && ( +
+ ( + + + + + + + + + + + + + + )} + /> + {errors.token && {errors.token.message}} +
+ )} + {useBackup && ( +
+ ( + + + + + + + + + + + + + )} + /> + {errors.backupCode && ( + {errors.backupCode.message} + )} +
+ )} +
+ +
+
+ {!useBackup ? ( + + ) : ( + + )} +
+
+
+  );
+});
+
+export default TwoFactorScreen;
diff --git a/client/src/components/Auth/index.ts b/client/src/components/Auth/index.ts
index cd1ac1adce..afde148015 100644
--- a/client/src/components/Auth/index.ts
+++ b/client/src/components/Auth/index.ts
@@ -4,3 +4,4 @@ export { default as ResetPassword } from './ResetPassword';
 export { default as VerifyEmail } from './VerifyEmail';
 export { default as ApiErrorWatcher } from './ApiErrorWatcher';
 export { default as RequestPasswordReset } from './RequestPasswordReset';
+export { default as TwoFactorScreen } from './TwoFactorScreen';
diff --git a/client/src/components/Bookmarks/BookmarkItem.tsx b/client/src/components/Bookmarks/BookmarkItem.tsx
index 92a6df0b54..60698a3165 100644
--- a/client/src/components/Bookmarks/BookmarkItem.tsx
+++ b/client/src/components/Bookmarks/BookmarkItem.tsx
@@ -34,19 +34,22 @@ const BookmarkItem: FC = ({ tag, selected, handleSubmit, icon, ..
   if (icon != null) {
     return icon;
   }
+
   if (isLoading) {
     return ;
   }
+
   if (selected) {
     return ;
   }
+
   return ;
 };

 return (
 {
-  // eslint-disable-next-line @typescript-eslint/no-unused-vars
   const { title: _t, ...convo } = conversation ?? ({} as TConversation);
   setAddedConvo({
     ...convo,
@@ -42,7 +41,7 @@ function AddMultiConvo() {
       role="button"
       onClick={clickHandler}
       data-testid="parameters-button"
-      className="inline-flex size-10 items-center justify-center rounded-lg border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
+      className="inline-flex size-10 flex-shrink-0 items-center justify-center rounded-xl border border-border-light bg-transparent text-text-primary transition-all ease-in-out hover:bg-surface-tertiary disabled:pointer-events-none disabled:opacity-50 radix-state-open:bg-surface-tertiary"
     >
diff --git a/client/src/components/Chat/ChatView.tsx b/client/src/components/Chat/ChatView.tsx
index dbf39ee845..a554c5f7d1 100644
--- a/client/src/components/Chat/ChatView.tsx
+++ b/client/src/components/Chat/ChatView.tsx
@@ -2,25 +2,38 @@ import { memo, useCallback } from 'react';
 import { useRecoilValue } from 'recoil';
 import { useForm } from 'react-hook-form';
 import { useParams } from 'react-router-dom';
-import { useGetMessagesByConvoId } from 'librechat-data-provider/react-query';
+import { Constants } from 'librechat-data-provider';
 import type { TMessage } from 'librechat-data-provider';
 import type { ChatFormValues } from '~/common';
 import { ChatContext, AddedChatContext, useFileMapContext, ChatFormProvider } from '~/Providers';
 import { useChatHelpers, useAddedResponse, useSSE } from '~/hooks';
+import ConversationStarters from './Input/ConversationStarters';
+import { useGetMessagesByConvoId } from '~/data-provider';
 import MessagesView from './Messages/MessagesView';
 import { Spinner } from '~/components/svg';
 import Presentation from './Presentation';
+import { buildTree, cn } from '~/utils';
 import ChatForm from './Input/ChatForm';
-import { buildTree } from '~/utils';
 import Landing from './Landing';
 import Header from './Header';
 import Footer from './Footer';
 import store from '~/store';

+function LoadingSpinner() {
+  return (
+
+ +
+
+  );
+}
+
 function ChatView({ index = 0 }: { index?: number }) {
   const { conversationId } = useParams();
   const rootSubmission = useRecoilValue(store.submissionByIndex(index));
   const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
+  const centerFormOnLanding = useRecoilValue(store.centerFormOnLanding);

   const fileMap = useFileMapContext();
@@ -46,16 +59,19 @@ function ChatView({ index = 0 }: { index?: number }) {
   });

   let content: JSX.Element | null | undefined;
-  if (isLoading && conversationId !== 'new') {
-    content = (
-
-
-    );
-  } else if (messagesTree && messagesTree.length !== 0) {
-    content = } />;
+  const isLandingPage =
+    (!messagesTree || messagesTree.length === 0) &&
+    (conversationId === Constants.NEW_CONVO || !conversationId);
+  const isNavigating = (!messagesTree || messagesTree.length === 0) && conversationId != null;
+
+  if (isLoading && conversationId !== Constants.NEW_CONVO) {
+    content = ;
+  } else if ((isLoading || isNavigating) && !isLandingPage) {
+    content = ;
+  } else if (!isLandingPage) {
+    content = ;
   } else {
-    content = } />;
+    content = ;
   }

   return (
@@ -63,10 +79,30 @@
       {content}
-
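With the JSX expressions stripped out of the hunk above, the new ChatView branching is easier to read as a decision function over loading state, fetched messages, and the route's conversation id. A restatement under those assumptions; isLandingPage and isNavigating come from the diff, while the function itself is illustrative rather than the component code:

import { Constants } from 'librechat-data-provider';

type ViewState = 'loading' | 'navigating' | 'messages' | 'landing';

// Hypothetical distillation of the branch order above.
function selectViewState(
  isLoading: boolean,
  messageCount: number,
  conversationId: string | undefined,
): ViewState {
  const hasMessages = messageCount > 0;
  const isLandingPage =
    !hasMessages && (conversationId === Constants.NEW_CONVO || !conversationId);
  const isNavigating = !hasMessages && conversationId != null;

  if (isLoading && conversationId !== Constants.NEW_CONVO) {
    return 'loading'; // existing conversation still fetching
  }
  if ((isLoading || isNavigating) && !isLandingPage) {
    return 'navigating'; // route switch before messages arrive
  }
  if (!isLandingPage) {
    return 'messages';
  }
  return 'landing'; // new chat: centered form plus conversation starters
}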
- -