diff --git a/.env.example b/.env.example index fcf017c327..876535b345 100644 --- a/.env.example +++ b/.env.example @@ -443,7 +443,6 @@ OPENID_IMAGE_URL= # Set to true to automatically redirect to the OpenID provider when a user visits the login page # This will bypass the login form completely for users, only use this if OpenID is your only authentication method OPENID_AUTO_REDIRECT=false - # Set to true to use PKCE (Proof Key for Code Exchange) for OpenID authentication OPENID_USE_PKCE=false #Set to true to reuse openid tokens for authentication management instead of using the mongodb session and the custom refresh token. @@ -459,6 +458,33 @@ OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed f # Set to true to use the OpenID Connect end session endpoint for logout OPENID_USE_END_SESSION_ENDPOINT= + +# SAML +# Note: If OpenID is enabled, SAML authentication will be automatically disabled. +SAML_ENTRY_POINT= +SAML_ISSUER= +SAML_CERT= +SAML_CALLBACK_URL=/oauth/saml/callback +SAML_SESSION_SECRET= + +# Attribute mappings (optional) +SAML_EMAIL_CLAIM= +SAML_USERNAME_CLAIM= +SAML_GIVEN_NAME_CLAIM= +SAML_FAMILY_NAME_CLAIM= +SAML_PICTURE_CLAIM= +SAML_NAME_CLAIM= + +# Login button settings (optional) +SAML_BUTTON_LABEL= +SAML_IMAGE_URL= + +# Whether the SAML Response should be signed. +# - If "true", the entire `SAML Response` will be signed. +# - If "false" or unset, only the `SAML Assertion` will be signed (default behavior). 
+# SAML_USE_AUTHN_RESPONSE_SIGNED= + + # LDAP LDAP_URL= LDAP_BIND_DN= @@ -489,6 +515,18 @@ EMAIL_PASSWORD= EMAIL_FROM_NAME= EMAIL_FROM=noreply@librechat.ai +#========================# +# Mailgun API # +#========================# + +# MAILGUN_API_KEY=your-mailgun-api-key +# MAILGUN_DOMAIN=mg.yourdomain.com +# EMAIL_FROM=noreply@yourdomain.com +# EMAIL_FROM_NAME="LibreChat" + +# # Optional: For EU region +# MAILGUN_HOST=https://api.eu.mailgun.net + #========================# # Firebase CDN # #========================# diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 09444a1b44..207aa17e66 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -30,8 +30,8 @@ Project maintainers have the right and responsibility to remove, edit, or reject 2. Install typescript globally: `npm i -g typescript`. 3. Run `npm ci` to install dependencies. 4. Build the data provider: `npm run build:data-provider`. -5. Build MCP: `npm run build:mcp`. -6. Build data schemas: `npm run build:data-schemas`. +5. Build data schemas: `npm run build:data-schemas`. +6. Build API methods: `npm run build:api`. 7. Setup and run unit tests: - Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`. - Run backend unit tests: `npm run test:api`. 
diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index b7bccecae8..7637b8cdc0 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -7,6 +7,7 @@ on: - release/* paths: - 'api/**' + - 'packages/api/**' jobs: tests_Backend: name: Run Backend unit tests @@ -36,12 +37,12 @@ jobs: - name: Install Data Provider Package run: npm run build:data-provider - - name: Install MCP Package - run: npm run build:mcp - - name: Install Data Schemas Package run: npm run build:data-schemas + - name: Install API Package + run: npm run build:api + - name: Create empty auth.json file run: | mkdir -p api/data @@ -66,5 +67,8 @@ jobs: - name: Run librechat-data-provider unit tests run: cd packages/data-provider && npm run test:ci - - name: Run librechat-mcp unit tests - run: cd packages/mcp && npm run test:ci \ No newline at end of file + - name: Run @librechat/data-schemas unit tests + run: cd packages/data-schemas && npm run test:ci + + - name: Run @librechat/api unit tests + run: cd packages/api && npm run test:ci \ No newline at end of file diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index fc1c02db69..a255932e3e 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -2,7 +2,7 @@ name: Update Test Server on: workflow_run: - workflows: ["Docker Dev Images Build"] + workflows: ["Docker Dev Branch Images Build"] types: - completed workflow_dispatch: @@ -12,7 +12,8 @@ jobs: runs-on: ubuntu-latest if: | github.repository == 'danny-avila/LibreChat' && - (github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success') + (github.event_name == 'workflow_dispatch' || + (github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'dev')) steps: - name: Checkout repository uses: actions/checkout@v4 @@ -29,13 +30,17 @@ jobs: DO_USER: ${{ secrets.DO_USER }} run: | ssh -o 
StrictHostKeyChecking=no ${DO_USER}@${DO_HOST} << EOF - sudo -i -u danny bash << EEOF + sudo -i -u danny bash << 'EEOF' cd ~/LibreChat && \ git fetch origin main && \ - npm run update:deployed && \ + sudo npm run stop:deployed && \ + sudo docker images --format "{{.Repository}}:{{.ID}}" | grep -E "lc-dev|librechat" | cut -d: -f2 | xargs -r sudo docker rmi -f || true && \ + sudo npm run update:deployed && \ + git checkout dev && \ + git pull origin dev && \ git checkout do-deploy && \ - git rebase main && \ - npm run start:deployed && \ + git rebase dev && \ + sudo npm run start:deployed && \ echo "Update completed. Application should be running now." EEOF EOF diff --git a/.github/workflows/dev-branch-images.yml b/.github/workflows/dev-branch-images.yml new file mode 100644 index 0000000000..b7ad470314 --- /dev/null +++ b/.github/workflows/dev-branch-images.yml @@ -0,0 +1,72 @@ +name: Docker Dev Branch Images Build + +on: + workflow_dispatch: + push: + branches: + - dev + paths: + - 'api/**' + - 'client/**' + - 'packages/**' + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - target: api-build + file: Dockerfile.multi + image_name: lc-dev-api + - target: node + file: Dockerfile + image_name: lc-dev + + steps: + # Check out the repository + - name: Checkout + uses: actions/checkout@v4 + + # Set up QEMU + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + # Set up Docker Buildx + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + # Log in to GitHub Container Registry + - name: Log in to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Login to Docker Hub + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # Prepare the environment + - name: Prepare environment + run: | + cp .env.example .env 
+ + # Build and push Docker images for each target + - name: Build and push Docker images + uses: docker/build-push-action@v5 + with: + context: . + file: ${{ matrix.file }} + push: true + tags: | + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ github.sha }} + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ github.sha }} + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest + platforms: linux/amd64,linux/arm64 + target: ${{ matrix.target }} \ No newline at end of file diff --git a/.github/workflows/i18n-unused-keys.yml b/.github/workflows/i18n-unused-keys.yml index 6bcf824946..07cc77a1ae 100644 --- a/.github/workflows/i18n-unused-keys.yml +++ b/.github/workflows/i18n-unused-keys.yml @@ -5,12 +5,13 @@ on: paths: - "client/src/**" - "api/**" + - "packages/data-provider/src/**" jobs: detect-unused-i18n-keys: runs-on: ubuntu-latest permissions: - pull-requests: write # Required for posting PR comments + pull-requests: write steps: - name: Checkout repository uses: actions/checkout@v3 diff --git a/.gitignore b/.gitignore index c24bc76b19..f49594afdf 100644 --- a/.gitignore +++ b/.gitignore @@ -122,3 +122,5 @@ helm/**/.values.yaml !/client/src/@types/i18next.d.ts +# SAML Idp cert +*.cert diff --git a/Dockerfile.multi b/Dockerfile.multi index 991f805bec..17a9847323 100644 --- a/Dockerfile.multi +++ b/Dockerfile.multi @@ -14,7 +14,7 @@ RUN npm config set fetch-retry-maxtimeout 600000 && \ npm config set fetch-retry-mintimeout 15000 COPY package*.json ./ COPY packages/data-provider/package*.json ./packages/data-provider/ -COPY packages/mcp/package*.json ./packages/mcp/ +COPY packages/api/package*.json ./packages/api/ COPY packages/data-schemas/package*.json ./packages/data-schemas/ COPY client/package*.json ./client/ COPY api/package*.json ./api/ @@ -24,26 +24,27 @@ FROM base-min AS base WORKDIR /app RUN npm ci -# Build data-provider +# Build 
`data-provider` package FROM base AS data-provider-build WORKDIR /app/packages/data-provider COPY packages/data-provider ./ RUN npm run build -# Build mcp package -FROM base AS mcp-build -WORKDIR /app/packages/mcp -COPY packages/mcp ./ -COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist -RUN npm run build - -# Build data-schemas +# Build `data-schemas` package FROM base AS data-schemas-build WORKDIR /app/packages/data-schemas COPY packages/data-schemas ./ COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist RUN npm run build +# Build `api` package +FROM base AS api-package-build +WORKDIR /app/packages/api +COPY packages/api ./ +COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist +COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist +RUN npm run build + # Client build FROM base AS client-build WORKDIR /app/client @@ -63,8 +64,8 @@ RUN npm ci --omit=dev COPY api ./api COPY config ./config COPY --from=data-provider-build /app/packages/data-provider/dist ./packages/data-provider/dist -COPY --from=mcp-build /app/packages/mcp/dist ./packages/mcp/dist COPY --from=data-schemas-build /app/packages/data-schemas/dist ./packages/data-schemas/dist +COPY --from=api-package-build /app/packages/api/dist ./packages/api/dist COPY --from=client-build /app/client/dist ./client/dist WORKDIR /app/api EXPOSE 3080 diff --git a/README.md b/README.md index cc9533b2d2..d6bd19ab43 100644 --- a/README.md +++ b/README.md @@ -150,8 +150,8 @@ Click on the thumbnail to open the video☝️ **Other:** - **Website:** [librechat.ai](https://librechat.ai) - - **Documentation:** [docs.librechat.ai](https://docs.librechat.ai) - - **Blog:** [blog.librechat.ai](https://blog.librechat.ai) + - **Documentation:** [librechat.ai/docs](https://librechat.ai/docs) + - **Blog:** [librechat.ai/blog](https://librechat.ai/blog) --- diff --git 
a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 0da331ced5..037f1e7c46 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -10,6 +10,7 @@ const { validateVisionModel, } = require('librechat-data-provider'); const { SplitStreamHandler: _Handler } = require('@librechat/agents'); +const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api'); const { truncateText, formatMessage, @@ -26,8 +27,6 @@ const { const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { createFetch, createStreamEventHandlers } = require('./generators'); -const Tokenizer = require('~/server/services/Tokenizer'); const { sleep } = require('~/server/utils'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js index 07b2fa97bb..555028dc3f 100644 --- a/api/app/clients/ChatGPTClient.js +++ b/api/app/clients/ChatGPTClient.js @@ -2,6 +2,7 @@ const { Keyv } = require('keyv'); const crypto = require('crypto'); const { CohereClient } = require('cohere-ai'); const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); +const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { ImageDetail, @@ -10,9 +11,9 @@ const { CohereConstants, mapModelToAzureConfig, } = require('librechat-data-provider'); -const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils'); const { createContextHandlers } = require('./prompts'); const { createCoherePayload } = require('./llm'); +const { extractBaseURL } = require('~/utils'); const BaseClient = 
require('./BaseClient'); const { logger } = require('~/config'); @@ -244,9 +245,9 @@ class ChatGPTClient extends BaseClient { baseURL = this.langchainProxy ? constructAzureURL({ - baseURL: this.langchainProxy, - azureOptions: this.azure, - }) + baseURL: this.langchainProxy, + azureOptions: this.azure, + }) : this.azureEndpoint.split(/(? { try { let done = false; diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index c9102e9ae2..817239d14f 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -1,4 +1,5 @@ const { google } = require('googleapis'); +const { Tokenizer } = require('@librechat/api'); const { concat } = require('@langchain/core/utils/stream'); const { ChatVertexAI } = require('@langchain/google-vertexai'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); @@ -19,7 +20,6 @@ const { } = require('librechat-data-provider'); const { getSafetySettings } = require('~/server/services/Endpoints/google/llm'); const { encodeAndFormat } = require('~/server/services/Files/images'); -const Tokenizer = require('~/server/services/Tokenizer'); const { spendTokens } = require('~/models/spendTokens'); const { getModelMaxTokens } = require('~/utils'); const { sleep } = require('~/server/utils'); @@ -34,7 +34,8 @@ const BaseClient = require('./BaseClient'); const loc = process.env.GOOGLE_LOC || 'us-central1'; const publisher = 'google'; -const endpointPrefix = `${loc}-aiplatform.googleapis.com`; +const endpointPrefix = + loc === 'global' ? 'aiplatform.googleapis.com' : `${loc}-aiplatform.googleapis.com`; const settings = endpointSettings[EModelEndpoint.google]; const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/; @@ -236,11 +237,11 @@ class GoogleClient extends BaseClient { msg.content = ( !Array.isArray(msg.content) ? 
[ - { - type: ContentTypes.TEXT, - [ContentTypes.TEXT]: msg.content, - }, - ] + { + type: ContentTypes.TEXT, + [ContentTypes.TEXT]: msg.content, + }, + ] : msg.content ).concat(message.image_urls); diff --git a/api/app/clients/OllamaClient.js b/api/app/clients/OllamaClient.js index 77d007580c..032781f1f1 100644 --- a/api/app/clients/OllamaClient.js +++ b/api/app/clients/OllamaClient.js @@ -1,10 +1,11 @@ const { z } = require('zod'); const axios = require('axios'); const { Ollama } = require('ollama'); +const { sleep } = require('@librechat/agents'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Constants } = require('librechat-data-provider'); -const { deriveBaseURL, logAxiosError } = require('~/utils'); -const { sleep } = require('~/server/utils'); -const { logger } = require('~/config'); +const { deriveBaseURL } = require('~/utils'); const ollamaPayloadSchema = z.object({ mirostat: z.number().optional(), @@ -67,7 +68,7 @@ class OllamaClient { return models; } catch (error) { const logMessage = - 'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).'; + "Failed to fetch models from Ollama API. 
If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn't start with `ollama` (case-insensitive)."; logAxiosError({ message: logMessage, error }); return []; } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 280db89284..f3a7e67c12 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,6 +1,14 @@ const { OllamaClient } = require('./OllamaClient'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { SplitStreamHandler, CustomOpenAIClient: OpenAI } = require('@librechat/agents'); +const { + isEnabled, + Tokenizer, + createFetch, + constructAzureURL, + genAzureChatCompletion, + createStreamEventHandlers, +} = require('@librechat/api'); const { Constants, ImageDetail, @@ -16,13 +24,6 @@ const { validateVisionModel, mapModelToAzureConfig, } = require('librechat-data-provider'); -const { - extractBaseURL, - constructAzureURL, - getModelMaxTokens, - genAzureChatCompletion, - getModelMaxOutputTokens, -} = require('~/utils'); const { truncateText, formatMessage, @@ -30,10 +31,9 @@ const { titleInstruction, createContextHandlers, } = require('./prompts'); +const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { createFetch, createStreamEventHandlers } = require('./generators'); -const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils'); -const Tokenizer = require('~/server/services/Tokenizer'); +const { addSpaceIfNeeded, sleep } = require('~/server/utils'); const { spendTokens } = require('~/models/spendTokens'); const { handleOpenAIErrors } = require('./tools/util'); const { createLLM, RunManager } = require('./llm'); diff --git a/api/app/clients/generators.js b/api/app/clients/generators.js deleted file mode 100644 index 9814cac7a5..0000000000 
--- a/api/app/clients/generators.js +++ /dev/null @@ -1,71 +0,0 @@ -const fetch = require('node-fetch'); -const { GraphEvents } = require('@librechat/agents'); -const { logger, sendEvent } = require('~/config'); -const { sleep } = require('~/server/utils'); - -/** - * Makes a function to make HTTP request and logs the process. - * @param {Object} params - * @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint. - * @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request. - * @returns {Promise} - A promise that resolves to the response of the fetch request. - */ -function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) { - /** - * Makes an HTTP request and logs the process. - * @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object. - * @param {RequestInit} [init] - Optional init options for the request. - * @returns {Promise} - A promise that resolves to the response of the fetch request. 
- */ - return async (_url, init) => { - let url = _url; - if (directEndpoint) { - url = reverseProxyUrl; - } - logger.debug(`Making request to ${url}`); - if (typeof Bun !== 'undefined') { - return await fetch(url, init); - } - return await fetch(url, init); - }; -} - -// Add this at the module level outside the class -/** - * Creates event handlers for stream events that don't capture client references - * @param {Object} res - The response object to send events to - * @returns {Object} Object containing handler functions - */ -function createStreamEventHandlers(res) { - return { - [GraphEvents.ON_RUN_STEP]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - [GraphEvents.ON_MESSAGE_DELTA]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - [GraphEvents.ON_REASONING_DELTA]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - }; -} - -function createHandleLLMNewToken(streamRate) { - return async () => { - if (streamRate) { - await sleep(streamRate); - } - }; -} - -module.exports = { - createFetch, - createHandleLLMNewToken, - createStreamEventHandlers, -}; diff --git a/api/app/clients/llm/createLLM.js b/api/app/clients/llm/createLLM.js index c8d6666bce..846c4d8e9c 100644 --- a/api/app/clients/llm/createLLM.js +++ b/api/app/clients/llm/createLLM.js @@ -1,6 +1,5 @@ const { ChatOpenAI } = require('@langchain/openai'); -const { sanitizeModelName, constructAzureURL } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); +const { isEnabled, sanitizeModelName, constructAzureURL } = require('@librechat/api'); /** * Creates a new instance of a language model (LLM) for chat interactions. 
diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index d620d5f647..6d44915804 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -1,7 +1,7 @@ const { Constants } = require('librechat-data-provider'); const { initializeFakeClient } = require('./FakeClient'); -jest.mock('~/lib/db/connectDb'); +jest.mock('~/db/connect'); jest.mock('~/models', () => ({ User: jest.fn(), Key: jest.fn(), @@ -33,7 +33,9 @@ jest.mock('~/models', () => ({ const { getConvo, saveConvo } = require('~/models'); jest.mock('@librechat/agents', () => { + const { Providers } = jest.requireActual('@librechat/agents'); return { + Providers, ChatOpenAI: jest.fn().mockImplementation(() => { return {}; }), @@ -52,7 +54,7 @@ const messageHistory = [ { role: 'user', isCreatedByUser: true, - text: 'What\'s up', + text: "What's up", messageId: '3', parentMessageId: '2', }, @@ -456,7 +458,7 @@ describe('BaseClient', () => { const chatMessages2 = await TestClient.loadHistory(conversationId, '3'); expect(TestClient.currentMessages).toHaveLength(3); - expect(chatMessages2[chatMessages2.length - 1].text).toEqual('What\'s up'); + expect(chatMessages2[chatMessages2.length - 1].text).toEqual("What's up"); }); /* Most of the new sendMessage logic revolving around edited/continued AI messages diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 579f636eef..cc4aa84d5d 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -5,7 +5,7 @@ const getLogStores = require('~/cache/getLogStores'); const OpenAIClient = require('../OpenAIClient'); jest.mock('meilisearch'); -jest.mock('~/lib/db/connectDb'); +jest.mock('~/db/connect'); jest.mock('~/models', () => ({ User: jest.fn(), Key: jest.fn(), @@ -462,17 +462,17 @@ describe('OpenAIClient', () => { role: 'system', name: 'example_user', content: - 'Let\'s circle back when we 
have more bandwidth to touch base on opportunities for increased leverage.', + "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.", }, { role: 'system', name: 'example_assistant', - content: 'Let\'s talk later when we\'re less busy about how to do better.', + content: "Let's talk later when we're less busy about how to do better.", }, { role: 'user', content: - 'This late pivot means we don\'t have time to boil the ocean for the client deliverable.', + "This late pivot means we don't have time to boil the ocean for the client deliverable.", }, ]; diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js index fd7bee5043..4928acefd1 100644 --- a/api/app/clients/specs/PluginsClient.test.js +++ b/api/app/clients/specs/PluginsClient.test.js @@ -3,7 +3,7 @@ const { Constants } = require('librechat-data-provider'); const { HumanMessage, AIMessage } = require('@langchain/core/messages'); const PluginsClient = require('../PluginsClient'); -jest.mock('~/lib/db/connectDb'); +jest.mock('~/db/connect'); jest.mock('~/models/Conversation', () => { return function () { return { diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.js deleted file mode 100644 index acc3a64d32..0000000000 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.js +++ /dev/null @@ -1,184 +0,0 @@ -require('dotenv').config(); -const fs = require('fs'); -const { z } = require('zod'); -const path = require('path'); -const yaml = require('js-yaml'); -const { createOpenAPIChain } = require('langchain/chains'); -const { DynamicStructuredTool } = require('@langchain/core/tools'); -const { ChatPromptTemplate, HumanMessagePromptTemplate } = require('@langchain/core/prompts'); -const { logger } = require('~/config'); - -function addLinePrefix(text, prefix = '// ') { - return text - .split('\n') - .map((line) => prefix + line) - .join('\n'); -} - -function 
createPrompt(name, functions) { - const prefix = `// The ${name} tool has the following functions. Determine the desired or most optimal function for the user's query:`; - const functionDescriptions = functions - .map((func) => `// - ${func.name}: ${func.description}`) - .join('\n'); - return `${prefix}\n${functionDescriptions} -// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function. -// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}`; -} - -const AuthBearer = z - .object({ - type: z.string().includes('service_http'), - authorization_type: z.string().includes('bearer'), - verification_tokens: z.object({ - openai: z.string(), - }), - }) - .catch(() => false); - -const AuthDefinition = z - .object({ - type: z.string(), - authorization_type: z.string(), - verification_tokens: z.object({ - openai: z.string(), - }), - }) - .catch(() => false); - -async function readSpecFile(filePath) { - try { - const fileContents = await fs.promises.readFile(filePath, 'utf8'); - if (path.extname(filePath) === '.json') { - return JSON.parse(fileContents); - } - return yaml.load(fileContents); - } catch (e) { - logger.error('[readSpecFile] error', e); - return false; - } -} - -async function getSpec(url) { - const RegularUrl = z - .string() - .url() - .catch(() => false); - - if (RegularUrl.parse(url) && path.extname(url) === '.json') { - const response = await fetch(url); - return await response.json(); - } - - const ValidSpecPath = z - .string() - .url() - .catch(async () => { - const spec = path.join(__dirname, '..', '.well-known', 'openapi', url); - if (!fs.existsSync(spec)) { - return false; - } - - return await readSpecFile(spec); - }); - - return ValidSpecPath.parse(url); -} - -async function createOpenAPIPlugin({ data, llm, user, message, memory, signal }) { - let spec; - try { - spec = await getSpec(data.api.url); - } catch (error) { - logger.error('[createOpenAPIPlugin] getSpec 
error', error); - return null; - } - - if (!spec) { - logger.warn('[createOpenAPIPlugin] No spec found'); - return null; - } - - const headers = {}; - const { auth, name_for_model, description_for_model, description_for_human } = data; - if (auth && AuthDefinition.parse(auth)) { - logger.debug('[createOpenAPIPlugin] auth detected', auth); - const { openai } = auth.verification_tokens; - if (AuthBearer.parse(auth)) { - headers.authorization = `Bearer ${openai}`; - logger.debug('[createOpenAPIPlugin] added auth bearer', headers); - } - } - - const chainOptions = { llm }; - - if (data.headers && data.headers['librechat_user_id']) { - logger.debug('[createOpenAPIPlugin] id detected', headers); - headers[data.headers['librechat_user_id']] = user; - } - - if (Object.keys(headers).length > 0) { - logger.debug('[createOpenAPIPlugin] headers detected', headers); - chainOptions.headers = headers; - } - - if (data.params) { - logger.debug('[createOpenAPIPlugin] params detected', data.params); - chainOptions.params = data.params; - } - - let history = ''; - if (memory) { - logger.debug('[createOpenAPIPlugin] openAPI chain: memory detected', memory); - const { history: chat_history } = await memory.loadMemoryVariables({}); - history = chat_history?.length > 0 ? 
`\n\n## Chat History:\n${chat_history}\n` : ''; - } - - chainOptions.prompt = ChatPromptTemplate.fromMessages([ - HumanMessagePromptTemplate.fromTemplate( - `# Use the provided API's to respond to this query:\n\n{query}\n\n## Instructions:\n${addLinePrefix( - description_for_model, - )}${history}`, - ), - ]); - - const chain = await createOpenAPIChain(spec, chainOptions); - - const { functions } = chain.chains[0].lc_kwargs.llmKwargs; - - return new DynamicStructuredTool({ - name: name_for_model, - description_for_model: `${addLinePrefix(description_for_human)}${createPrompt( - name_for_model, - functions, - )}`, - description: `${description_for_human}`, - schema: z.object({ - func: z - .string() - .describe( - `The function to invoke. The functions available are: ${functions - .map((func) => func.name) - .join(', ')}`, - ), - intent: z - .string() - .describe('Describe your intent with the function and your expected result'), - }), - func: async ({ func = '', intent = '' }) => { - const filteredFunctions = functions.filter((f) => f.name === func); - chain.chains[0].lc_kwargs.llmKwargs.functions = filteredFunctions; - const query = `${message}${func?.length > 0 ? 
`\n// Intent: ${intent}` : ''}`; - const result = await chain.call({ - query, - signal, - }); - return result.response; - }, - }); -} - -module.exports = { - getSpec, - readSpecFile, - createOpenAPIPlugin, -}; diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js deleted file mode 100644 index 83bc5e9397..0000000000 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js +++ /dev/null @@ -1,72 +0,0 @@ -const fs = require('fs'); -const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin'); - -global.fetch = jest.fn().mockImplementationOnce(() => { - return new Promise((resolve) => { - resolve({ - ok: true, - json: () => Promise.resolve({ key: 'value' }), - }); - }); -}); -jest.mock('fs', () => ({ - promises: { - readFile: jest.fn(), - }, - existsSync: jest.fn(), -})); - -describe('readSpecFile', () => { - it('reads JSON file correctly', async () => { - fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' })); - const result = await readSpecFile('test.json'); - expect(result).toEqual({ test: 'value' }); - }); - - it('reads YAML file correctly', async () => { - fs.promises.readFile.mockResolvedValue('test: value'); - const result = await readSpecFile('test.yaml'); - expect(result).toEqual({ test: 'value' }); - }); - - it('handles error correctly', async () => { - fs.promises.readFile.mockRejectedValue(new Error('test error')); - const result = await readSpecFile('test.json'); - expect(result).toBe(false); - }); -}); - -describe('getSpec', () => { - it('fetches spec from url correctly', async () => { - const parsedJson = await getSpec('https://www.instacart.com/.well-known/ai-plugin.json'); - const isObject = typeof parsedJson === 'object'; - expect(isObject).toEqual(true); - }); - - it('reads spec from file correctly', async () => { - fs.existsSync.mockReturnValue(true); - fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' })); - const result = 
await getSpec('test.json'); - expect(result).toEqual({ test: 'value' }); - }); - - it('returns false when file does not exist', async () => { - fs.existsSync.mockReturnValue(false); - const result = await getSpec('test.json'); - expect(result).toBe(false); - }); -}); - -describe('createOpenAPIPlugin', () => { - it('returns null when getSpec throws an error', async () => { - const result = await createOpenAPIPlugin({ data: { api: { url: 'invalid' } } }); - expect(result).toBe(null); - }); - - it('returns null when no spec is found', async () => { - const result = await createOpenAPIPlugin({}); - expect(result).toBe(null); - }); - - // Add more tests here for different scenarios -}); diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index fc0f1851f6..7c2a56fe71 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -8,10 +8,10 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); const extractBaseURL = require('~/utils/extractBaseURL'); -const { logger } = require('~/config'); +const logger = require('~/config/winston'); const displayMessage = - 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + "DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. 
The user may download the images by clicking on them, but do not mention anything about downloading to the user."; class DALLE3 extends Tool { constructor(fields = {}) { super(); diff --git a/api/app/clients/tools/structured/OpenAIImageTools.js b/api/app/clients/tools/structured/OpenAIImageTools.js index afea9dfd55..08e15a7fad 100644 --- a/api/app/clients/tools/structured/OpenAIImageTools.js +++ b/api/app/clients/tools/structured/OpenAIImageTools.js @@ -4,12 +4,13 @@ const { v4 } = require('uuid'); const OpenAI = require('openai'); const FormData = require('form-data'); const { tool } = require('@langchain/core/tools'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { ContentTypes, EImageOutputType } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { logAxiosError, extractBaseURL } = require('~/utils'); +const { extractBaseURL } = require('~/utils'); const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); /** Default descriptions for image generation tool */ const DEFAULT_IMAGE_GEN_DESCRIPTION = ` @@ -64,7 +65,7 @@ const DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION = `Describe the changes, enhancement Always base this prompt on the most recently uploaded reference images.`; const displayMessage = - 'The tool displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + "The tool displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. 
The user may download the images by clicking on them, but do not mention anything about downloading to the user."; /** * Replaces unwanted characters from the input string diff --git a/api/app/clients/tools/structured/specs/DALLE3.spec.js b/api/app/clients/tools/structured/specs/DALLE3.spec.js index 1b28de2faf..2def575fb3 100644 --- a/api/app/clients/tools/structured/specs/DALLE3.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3.spec.js @@ -1,10 +1,29 @@ const OpenAI = require('openai'); const DALLE3 = require('../DALLE3'); - -const { logger } = require('~/config'); +const logger = require('~/config/winston'); jest.mock('openai'); +jest.mock('@librechat/data-schemas', () => { + return { + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + }; +}); + +jest.mock('tiktoken', () => { + return { + encoding_for_model: jest.fn().mockReturnValue({ + encode: jest.fn(), + decode: jest.fn(), + }), + }; +}); + const processFileURL = jest.fn(); jest.mock('~/server/services/Files/images', () => ({ @@ -37,6 +56,11 @@ jest.mock('fs', () => { return { existsSync: jest.fn(), mkdirSync: jest.fn(), + promises: { + writeFile: jest.fn(), + readFile: jest.fn(), + unlink: jest.fn(), + }, }; }); diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 54da483362..19d3a79edb 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -135,7 +135,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => { query: z .string() .describe( - 'A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you\'re looking for. The query will be used for semantic similarity matching against the file contents.', + "A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. 
The query will be used for semantic similarity matching against the file contents.", ), }), }, diff --git a/api/app/clients/tools/util/handleTools.test.js b/api/app/clients/tools/util/handleTools.test.js index 6538ce9aa4..1cacda8159 100644 --- a/api/app/clients/tools/util/handleTools.test.js +++ b/api/app/clients/tools/util/handleTools.test.js @@ -1,8 +1,5 @@ -const mockUser = { - _id: 'fakeId', - save: jest.fn(), - findByIdAndDelete: jest.fn(), -}; +const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const mockPluginService = { updateUserPluginAuth: jest.fn(), @@ -10,23 +7,18 @@ const mockPluginService = { getUserPluginAuthValue: jest.fn(), }; -jest.mock('~/models/User', () => { - return function () { - return mockUser; - }; -}); - jest.mock('~/server/services/PluginService', () => mockPluginService); const { BaseLLM } = require('@langchain/openai'); const { Calculator } = require('@langchain/community/tools/calculator'); -const User = require('~/models/User'); +const { User } = require('~/db/models'); const PluginService = require('~/server/services/PluginService'); const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools'); const { StructuredSD, availableTools, DALLE3 } = require('../'); describe('Tool Handlers', () => { + let mongoServer; let fakeUser; const pluginKey = 'dalle'; const pluginKey2 = 'wolfram'; @@ -37,7 +29,9 @@ describe('Tool Handlers', () => { const authConfigs = mainPlugin.authConfig; beforeAll(async () => { - mockUser.save.mockResolvedValue(undefined); + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); const userAuthValues = {}; mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => { @@ -78,9 +72,36 @@ describe('Tool Handlers', () => { }); afterAll(async () => { - await mockUser.findByIdAndDelete(fakeUser._id); + await mongoose.disconnect(); + await mongoServer.stop(); 
+ }); + + beforeEach(async () => { + // Clear mocks but not the database since we need the user to persist + jest.clearAllMocks(); + + // Reset the mock implementations + const userAuthValues = {}; + mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => { + return userAuthValues[`${userId}-${authField}`]; + }); + mockPluginService.updateUserPluginAuth.mockImplementation( + (userId, authField, _pluginKey, credential) => { + const fields = authField.split('||'); + fields.forEach((field) => { + userAuthValues[`${userId}-${field}`] = credential; + }); + }, + ); + + // Re-add the auth configs for the user for (const authConfig of authConfigs) { - await PluginService.deleteUserPluginAuth(fakeUser._id, authConfig.authField); + await PluginService.updateUserPluginAuth( + fakeUser._id, + authConfig.authField, + pluginKey, + mockCredential, + ); } }); @@ -218,7 +239,6 @@ describe('Tool Handlers', () => { try { await loadTool2(); } catch (error) { - // eslint-disable-next-line jest/no-conditional-expect expect(error).toBeDefined(); } }); diff --git a/api/cache/banViolation.js b/api/cache/banViolation.js index cdbff85c54..17b23f1c12 100644 --- a/api/cache/banViolation.js +++ b/api/cache/banViolation.js @@ -1,8 +1,8 @@ +const { logger } = require('@librechat/data-schemas'); const { ViolationTypes } = require('librechat-data-provider'); const { isEnabled, math, removePorts } = require('~/server/utils'); const { deleteAllUserSessions } = require('~/models'); const getLogStores = require('./getLogStores'); -const { logger } = require('~/config'); const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? 
{}; const interval = math(BAN_INTERVAL, 20); @@ -32,7 +32,6 @@ const banViolation = async (req, res, errorMessage) => { if (!isEnabled(BAN_VIOLATIONS)) { return; } - if (!errorMessage) { return; } @@ -51,7 +50,6 @@ const banViolation = async (req, res, errorMessage) => { const banLogs = getLogStores(ViolationTypes.BAN); const duration = errorMessage.duration || banLogs.opts.ttl; - if (duration <= 0) { return; } diff --git a/api/cache/banViolation.spec.js b/api/cache/banViolation.spec.js index 8fef16920f..df98753498 100644 --- a/api/cache/banViolation.spec.js +++ b/api/cache/banViolation.spec.js @@ -1,48 +1,28 @@ +const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const banViolation = require('./banViolation'); -jest.mock('keyv'); -jest.mock('../models/Session'); -// Mocking the getLogStores function -jest.mock('./getLogStores', () => { - return jest.fn().mockImplementation(() => { - const EventEmitter = require('events'); - const { CacheKeys } = require('librechat-data-provider'); - const math = require('../server/utils/math'); - const mockGet = jest.fn(); - const mockSet = jest.fn(); - class KeyvMongo extends EventEmitter { - constructor(url = 'mongodb://127.0.0.1:27017', options) { - super(); - this.ttlSupport = false; - url = url ?? 
{}; - if (typeof url === 'string') { - url = { url }; - } - if (url.uri) { - url = { url: url.uri, ...url }; - } - this.opts = { - url, - collection: 'keyv', - ...url, - ...options, - }; - } - - get = mockGet; - set = mockSet; - } - - return new KeyvMongo('', { - namespace: CacheKeys.BANS, - ttl: math(process.env.BAN_DURATION, 7200000), - }); - }); -}); +// Mock deleteAllUserSessions since we're testing ban logic, not session deletion +jest.mock('~/models', () => ({ + ...jest.requireActual('~/models'), + deleteAllUserSessions: jest.fn().mockResolvedValue(true), +})); describe('banViolation', () => { + let mongoServer; let req, res, errorMessage; + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + beforeEach(() => { req = { ip: '127.0.0.1', @@ -55,7 +35,7 @@ describe('banViolation', () => { }; errorMessage = { type: 'someViolation', - user_id: '12345', + user_id: new mongoose.Types.ObjectId().toString(), // Use valid ObjectId prev_count: 0, violation_count: 0, }; diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index b0a6a822ac..2478bf40d9 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -1,7 +1,7 @@ const { Keyv } = require('keyv'); const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider'); const { logFile, violationFile } = require('./keyvFiles'); -const { math, isEnabled } = require('~/server/utils'); +const { isEnabled, math } = require('~/server/utils'); const keyvRedis = require('./keyvRedis'); const keyvMongo = require('./keyvMongo'); diff --git a/api/config/index.js b/api/config/index.js index e238f700be..a02c75887e 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,7 +1,6 @@ -const axios = require('axios'); const { EventSource } = require('eventsource'); -const { Time, CacheKeys } = 
require('librechat-data-provider'); -const { MCPManager, FlowStateManager } = require('librechat-mcp'); +const { Time } = require('librechat-data-provider'); +const { MCPManager, FlowStateManager } = require('@librechat/api'); const logger = require('./winston'); global.EventSource = EventSource; @@ -37,60 +36,8 @@ function getFlowStateManager(flowsCache) { return flowManager; } -/** - * Sends message data in Server Sent Events format. - * @param {ServerResponse} res - The server response. - * @param {{ data: string | Record, event?: string }} event - The message event. - * @param {string} event.event - The type of event. - * @param {string} event.data - The message to be sent. - */ -const sendEvent = (res, event) => { - if (typeof event.data === 'string' && event.data.length === 0) { - return; - } - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); -}; - -/** - * Creates and configures an Axios instance with optional proxy settings. - * - * @typedef {import('axios').AxiosInstance} AxiosInstance - * @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig - * - * @returns {AxiosInstance} A configured Axios instance - * @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL - */ -function createAxiosInstance() { - const instance = axios.create(); - - if (process.env.proxy) { - try { - const url = new URL(process.env.proxy); - - /** @type {AxiosProxyConfig} */ - const proxyConfig = { - host: url.hostname.replace(/^\[|\]$/g, ''), - protocol: url.protocol.replace(':', ''), - }; - - if (url.port) { - proxyConfig.port = parseInt(url.port, 10); - } - - instance.defaults.proxy = proxyConfig; - } catch (error) { - console.error('Error parsing proxy URL:', error); - throw new Error(`Invalid proxy URL: ${process.env.proxy}`); - } - } - - return instance; -} - module.exports = { logger, - sendEvent, getMCPManager, - createAxiosInstance, getFlowStateManager, }; diff --git a/api/lib/db/connectDb.js b/api/db/connect.js 
similarity index 96% rename from api/lib/db/connectDb.js rename to api/db/connect.js index b8cbeb2adb..e88ffa51ed 100644 --- a/api/lib/db/connectDb.js +++ b/api/db/connect.js @@ -39,7 +39,10 @@ async function connectDb() { }); } cached.conn = await cached.promise; + return cached.conn; } -module.exports = connectDb; +module.exports = { + connectDb, +}; diff --git a/api/db/index.js b/api/db/index.js new file mode 100644 index 0000000000..5c29902f69 --- /dev/null +++ b/api/db/index.js @@ -0,0 +1,8 @@ +const mongoose = require('mongoose'); +const { createModels } = require('@librechat/data-schemas'); +const { connectDb } = require('./connect'); +const indexSync = require('./indexSync'); + +createModels(mongoose); + +module.exports = { connectDb, indexSync }; diff --git a/api/lib/db/indexSync.js b/api/db/indexSync.js similarity index 93% rename from api/lib/db/indexSync.js rename to api/db/indexSync.js index 75acd9d231..e8bcd55e37 100644 --- a/api/lib/db/indexSync.js +++ b/api/db/indexSync.js @@ -1,8 +1,11 @@ +const mongoose = require('mongoose'); const { MeiliSearch } = require('meilisearch'); -const { Conversation } = require('~/models/Conversation'); -const { Message } = require('~/models/Message'); +const { logger } = require('@librechat/data-schemas'); + const { isEnabled } = require('~/server/utils'); -const { logger } = require('~/config'); + +const Conversation = mongoose.models.Conversation; +const Message = mongoose.models.Message; const searchEnabled = isEnabled(process.env.SEARCH); const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); @@ -29,7 +32,6 @@ async function indexSync() { if (!searchEnabled) { return; } - try { const client = MeiliSearchClient.getInstance(); diff --git a/api/db/models.js b/api/db/models.js new file mode 100644 index 0000000000..fca1327446 --- /dev/null +++ b/api/db/models.js @@ -0,0 +1,5 @@ +const mongoose = require('mongoose'); +const { createModels } = require('@librechat/data-schemas'); +const models = 
createModels(mongoose); + +module.exports = { ...models }; diff --git a/api/lib/db/index.js b/api/lib/db/index.js deleted file mode 100644 index fa7a460d05..0000000000 --- a/api/lib/db/index.js +++ /dev/null @@ -1,4 +0,0 @@ -const connectDb = require('./connectDb'); -const indexSync = require('./indexSync'); - -module.exports = { connectDb, indexSync }; diff --git a/api/models/Action.js b/api/models/Action.js index 677b4d78df..20aa20a7e4 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -1,7 +1,4 @@ -const mongoose = require('mongoose'); -const { actionSchema } = require('@librechat/data-schemas'); - -const Action = mongoose.model('action', actionSchema); +const { Action } = require('~/db/models'); /** * Update an action with new data without overwriting existing properties, diff --git a/api/models/Agent.js b/api/models/Agent.js index 11fd6dabb2..297604c444 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -1,6 +1,7 @@ const mongoose = require('mongoose'); -const { agentSchema } = require('@librechat/data-schemas'); -const { SystemRoles, Tools } = require('librechat-data-provider'); +const crypto = require('node:crypto'); +const { logger } = require('@librechat/data-schemas'); +const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider'); const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } = require('librechat-data-provider').Constants; const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys; @@ -11,8 +12,8 @@ const { removeAgentFromAllProjects, } = require('./Project'); const getLogStores = require('~/cache/getLogStores'); - -const Agent = mongoose.model('agent', agentSchema); +const { getActions } = require('./Action'); +const { Agent } = require('~/db/models'); /** * Create an agent with the provided data. 
@@ -149,10 +150,12 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => { /** * Check if a version already exists in the versions array, excluding timestamp and author fields * @param {Object} updateData - The update data to compare + * @param {Object} currentData - The current agent data * @param {Array} versions - The existing versions array + * @param {string} [actionsHash] - Hash of current action metadata * @returns {Object|null} - The matching version if found, null otherwise */ -const isDuplicateVersion = (updateData, currentData, versions) => { +const isDuplicateVersion = (updateData, currentData, versions, actionsHash = null) => { if (!versions || versions.length === 0) { return null; } @@ -167,19 +170,23 @@ const isDuplicateVersion = (updateData, currentData, versions) => { 'created_at', 'updated_at', '__v', - 'agent_ids', 'versions', + 'actionsHash', // Exclude actionsHash from direct comparison ]; const { $push, $pull, $addToSet, ...directUpdates } = updateData; - if (Object.keys(directUpdates).length === 0) { + if (Object.keys(directUpdates).length === 0 && !actionsHash) { return null; } const wouldBeVersion = { ...currentData, ...directUpdates }; const lastVersion = versions[versions.length - 1]; + if (actionsHash && lastVersion.actionsHash !== actionsHash) { + return null; + } + const allFields = new Set([...Object.keys(wouldBeVersion), ...Object.keys(lastVersion)]); const importantFields = Array.from(allFields).filter((field) => !excludeFields.includes(field)); @@ -249,21 +256,57 @@ const isDuplicateVersion = (updateData, currentData, versions) => { * @param {string} searchParameter.id - The ID of the agent to update. * @param {string} [searchParameter.author] - The user ID of the agent's author. * @param {Object} updateData - An object containing the properties to update. - * @param {string} [updatingUserId] - The ID of the user performing the update (used for tracking non-author updates). 
+ * @param {Object} [options] - Optional configuration object. + * @param {string} [options.updatingUserId] - The ID of the user performing the update (used for tracking non-author updates). + * @param {boolean} [options.forceVersion] - Force creation of a new version even if no fields changed. + * @param {boolean} [options.skipVersioning] - Skip version creation entirely (useful for isolated operations like sharing). * @returns {Promise} The updated or newly created agent document as a plain object. * @throws {Error} If the update would create a duplicate version */ -const updateAgent = async (searchParameter, updateData, updatingUserId = null) => { - const options = { new: true, upsert: false }; +const updateAgent = async (searchParameter, updateData, options = {}) => { + const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options; + const mongoOptions = { new: true, upsert: false }; const currentAgent = await Agent.findOne(searchParameter); if (currentAgent) { const { __v, _id, id, versions, author, ...versionData } = currentAgent.toObject(); const { $push, $pull, $addToSet, ...directUpdates } = updateData; - if (Object.keys(directUpdates).length > 0 && versions && versions.length > 0) { - const duplicateVersion = isDuplicateVersion(updateData, versionData, versions); - if (duplicateVersion) { + let actionsHash = null; + + // Generate actions hash if agent has actions + if (currentAgent.actions && currentAgent.actions.length > 0) { + // Extract action IDs from the format "domain_action_id" + const actionIds = currentAgent.actions + .map((action) => { + const parts = action.split(actionDelimiter); + return parts[1]; // Get just the action ID part + }) + .filter(Boolean); + + if (actionIds.length > 0) { + try { + const actions = await getActions( + { + action_id: { $in: actionIds }, + }, + true, + ); // Include sensitive data for hash + + actionsHash = await generateActionMetadataHash(currentAgent.actions, actions); + } catch (error) { 
+ logger.error('Error fetching actions for hash generation:', error); + } + } + } + + const shouldCreateVersion = + !skipVersioning && + (forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet); + + if (shouldCreateVersion) { + const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash); + if (duplicateVersion && !forceVersion) { const error = new Error( 'Duplicate version: This would create a version identical to an existing one', ); @@ -284,18 +327,25 @@ const updateAgent = async (searchParameter, updateData, updatingUserId = null) = updatedAt: new Date(), }; + // Include actions hash in version if available + if (actionsHash) { + versionEntry.actionsHash = actionsHash; + } + // Always store updatedBy field to track who made the change if (updatingUserId) { versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId); } - updateData.$push = { - ...($push || {}), - versions: versionEntry, - }; + if (shouldCreateVersion) { + updateData.$push = { + ...($push || {}), + versions: versionEntry, + }; + } } - return Agent.findOneAndUpdate(searchParameter, updateData, options).lean(); + return Agent.findOneAndUpdate(searchParameter, updateData, mongoOptions).lean(); }; /** @@ -333,7 +383,9 @@ const addAgentResourceFile = async ({ req, agent_id, tool_resource, file_id }) = }, }; - const updatedAgent = await updateAgent(searchParameter, updateData, req?.user?.id); + const updatedAgent = await updateAgent(searchParameter, updateData, { + updatingUserId: req?.user?.id, + }); if (updatedAgent) { return updatedAgent; } else { @@ -425,7 +477,6 @@ const getListAgents = async (searchParameter) => { delete globalQuery.author; query = { $or: [globalQuery, query] }; } - const agents = ( await Agent.find(query, { id: 1, @@ -497,7 +548,10 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds delete updateQuery.author; } - const updatedAgent = await updateAgent(updateQuery, updateOps, 
user.id); + const updatedAgent = await updateAgent(updateQuery, updateOps, { + updatingUserId: user.id, + skipVersioning: true, + }); if (updatedAgent) { return updatedAgent; } @@ -548,16 +602,73 @@ const revertAgentVersion = async (searchParameter, versionIndex) => { return Agent.findOneAndUpdate(searchParameter, updateData, { new: true }).lean(); }; +/** + * Generates a hash of action metadata for version comparison + * @param {string[]} actionIds - Array of action IDs in format "domain_action_id" + * @param {Action[]} actions - Array of action documents + * @returns {Promise} - SHA256 hash of the action metadata + */ +const generateActionMetadataHash = async (actionIds, actions) => { + if (!actionIds || actionIds.length === 0) { + return ''; + } + + // Create a map of action_id to metadata for quick lookup + const actionMap = new Map(); + actions.forEach((action) => { + actionMap.set(action.action_id, action.metadata); + }); + + // Sort action IDs for consistent hashing + const sortedActionIds = [...actionIds].sort(); + + // Build a deterministic string representation of all action metadata + const metadataString = sortedActionIds + .map((actionFullId) => { + // Extract just the action_id part (after the delimiter) + const parts = actionFullId.split(actionDelimiter); + const actionId = parts[1]; + + const metadata = actionMap.get(actionId); + if (!metadata) { + return `${actionId}:null`; + } + + // Sort metadata keys for deterministic output + const sortedKeys = Object.keys(metadata).sort(); + const metadataStr = sortedKeys + .map((key) => `${key}:${JSON.stringify(metadata[key])}`) + .join(','); + return `${actionId}:{${metadataStr}}`; + }) + .join(';'); + + // Use Web Crypto API to generate hash + const encoder = new TextEncoder(); + const data = encoder.encode(metadataString); + const hashBuffer = await crypto.webcrypto.subtle.digest('SHA-256', data); + const hashArray = Array.from(new Uint8Array(hashBuffer)); + const hashHex = hashArray.map((b) => 
b.toString(16).padStart(2, '0')).join(''); + + return hashHex; +}; + +/** + * Load a default agent based on the endpoint + * @param {string} endpoint + * @returns {Agent | null} + */ + module.exports = { - Agent, getAgent, loadAgent, createAgent, updateAgent, deleteAgent, getListAgents, + revertAgentVersion, updateAgentProjects, addAgentResourceFile, removeAgentResourceFiles, - revertAgentVersion, + generateActionMetadataHash, }; diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js index 57d54171c4..1e18168147 100644 --- a/api/models/Agent.spec.js +++ b/api/models/Agent.spec.js @@ -8,1018 +8,2660 @@ process.env.CREDS_IV = '0123456789abcdef'; const mongoose = require('mongoose'); const { v4: uuidv4 } = require('uuid'); +const { agentSchema } = require('@librechat/data-schemas'); const { MongoMemoryServer } = require('mongodb-memory-server'); const { - Agent, - addAgentResourceFile, - removeAgentResourceFiles, + getAgent, + loadAgent, createAgent, updateAgent, - getAgent, deleteAgent, getListAgents, updateAgentProjects, + addAgentResourceFile, + removeAgentResourceFiles, + generateActionMetadataHash, + revertAgentVersion, } = require('./Agent'); -describe('Agent Resource File Operations', () => { - let mongoServer; +/** + * @type {import('mongoose').Model} + */ +let Agent; - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); +describe('models/Agent', () => { + describe('Agent Resource File Operations', () => { + let mongoServer; - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - process.env.CREDS_KEY = originalEnv.CREDS_KEY; - process.env.CREDS_IV = originalEnv.CREDS_IV; - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - const createBasicAgent = async () => { - const agentId = `agent_${uuidv4()}`; - const agent = await Agent.create({ - id: agentId, - name: 'Test Agent', - provider: 'test', - 
model: 'test-model', - author: new mongoose.Types.ObjectId(), - }); - return agent; - }; - - test('should add tool_resource to tools if missing', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - const toolResource = 'file_search'; - - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId, + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); }); - expect(updatedAgent.tools).toContain(toolResource); - expect(Array.isArray(updatedAgent.tools)).toBe(true); - // Should not duplicate - const count = updatedAgent.tools.filter((t) => t === toolResource).length; - expect(count).toBe(1); - }); - - test('should not duplicate tool_resource in tools if already present', async () => { - const agent = await createBasicAgent(); - const fileId1 = uuidv4(); - const fileId2 = uuidv4(); - const toolResource = 'file_search'; - - // First add - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId1, + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + process.env.CREDS_KEY = originalEnv.CREDS_KEY; + process.env.CREDS_IV = originalEnv.CREDS_IV; }); - // Second add (should not duplicate) - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId2, + beforeEach(async () => { + await Agent.deleteMany({}); }); - expect(updatedAgent.tools).toContain(toolResource); - expect(Array.isArray(updatedAgent.tools)).toBe(true); - const count = updatedAgent.tools.filter((t) => t === toolResource).length; - expect(count).toBe(1); - }); + test('should add tool_resource to tools if missing', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + const toolResource 
= 'file_search'; - test('should handle concurrent file additions', async () => { - const agent = await createBasicAgent(); - const fileIds = Array.from({ length: 10 }, () => uuidv4()); + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId, + }); - // Concurrent additions - const additionPromises = fileIds.map((fileId) => - addAgentResourceFile({ + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + // Should not duplicate + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + expect(count).toBe(1); + }); + + test('should not duplicate tool_resource in tools if already present', async () => { + const agent = await createBasicAgent(); + const fileId1 = uuidv4(); + const fileId2 = uuidv4(); + const toolResource = 'file_search'; + + // First add + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId1, + }); + + // Second add (should not duplicate) + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId2, + }); + + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + expect(count).toBe(1); + }); + + test('should handle concurrent file additions', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); + + // Concurrent additions + const additionPromises = createFileOperations(agent.id, fileIds, 'add'); + + await Promise.all(additionPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); + expect(new 
Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); + }); + + test('should handle concurrent additions and removals', async () => { + const agent = await createBasicAgent(); + const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); + + await Promise.all(createFileOperations(agent.id, initialFileIds, 'add')); + + const newFileIds = Array.from({ length: 5 }, () => uuidv4()); + const operations = [ + ...newFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ...initialFileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ), + ]; + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); + }); + + test('should initialize array when adding to non-existent tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'new_tool', + file_id: fileId, + }); + + expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); + expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); + }); + + test('should handle rapid sequential modifications to same tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + for (let i = 0; i < 10; i++) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: `${fileId}_${i}`, + }); + + if (i % 2 === 0) { + await removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], + }); + } + } + + const updatedAgent = 
await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); + }); + + test('should handle multiple tool resources concurrently', async () => { + const agent = await createBasicAgent(); + const toolResources = ['tool1', 'tool2', 'tool3']; + const operations = []; + + toolResources.forEach((tool) => { + const fileIds = Array.from({ length: 5 }, () => uuidv4()); + fileIds.forEach((fileId) => { + operations.push( + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: tool, + file_id: fileId, + }), + ); + }); + }); + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + toolResources.forEach((tool) => { + expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); + expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); + }); + }); + + test.each([ + { + name: 'duplicate additions', + operation: 'add', + duplicateCount: 5, + expectedLength: 1, + expectedContains: true, + }, + { + name: 'duplicate removals', + operation: 'remove', + duplicateCount: 5, + expectedLength: 0, + expectedContains: false, + setupFile: true, + }, + ])( + 'should handle concurrent $name', + async ({ operation, duplicateCount, expectedLength, expectedContains, setupFile }) => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + if (setupFile) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }); + } + + const promises = Array.from({ length: duplicateCount }).map(() => + operation === 'add' + ? 
addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }) + : removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(promises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; + + expect(fileIds).toHaveLength(expectedLength); + if (expectedContains) { + expect(fileIds[0]).toBe(fileId); + } else { + expect(fileIds).not.toContain(fileId); + } + }, + ); + + test('should handle concurrent add and remove of the same file', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + await addAgentResourceFile({ agent_id: agent.id, tool_resource: 'test_tool', file_id: fileId, - }), - ); + }); - await Promise.all(additionPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); - expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); - }); - - test('should handle concurrent additions and removals', async () => { - const agent = await createBasicAgent(); - const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); - - await Promise.all( - initialFileIds.map((fileId) => + const operations = [ addAgentResourceFile({ agent_id: agent.id, tool_resource: 'test_tool', file_id: fileId, }), - ), - ); - - const newFileIds = Array.from({ length: 5 }, () => uuidv4()); - const operations = [ - ...newFileIds.map((fileId) => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - ), - ...initialFileIds.map((fileId) => removeAgentResourceFiles({ agent_id: agent.id, files: [{ tool_resource: 'test_tool', file_id: fileId }], }), - ), - ]; + ]; - await Promise.all(operations); + await Promise.all(operations); - 
const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); - }); + const updatedAgent = await Agent.findOne({ id: agent.id }); + const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids; + const count = finalFileIds.filter((id) => id === fileId).length; - test('should initialize array when adding to non-existent tool resource', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'new_tool', - file_id: fileId, + expect(count).toBeLessThanOrEqual(1); + if (count === 0) { + expect(finalFileIds).toHaveLength(0); + } else { + expect(finalFileIds).toHaveLength(1); + expect(finalFileIds[0]).toBe(fileId); + } }); - expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); - expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); - }); + test('should handle concurrent removals of different files', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); - test('should handle rapid sequential modifications to same tool resource', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - for (let i = 0; i < 10; i++) { - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: `${fileId}_${i}`, - }); - - if (i % 2 === 0) { - await removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], - }); - } - } - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); - 
}); - - test('should handle multiple tool resources concurrently', async () => { - const agent = await createBasicAgent(); - const toolResources = ['tool1', 'tool2', 'tool3']; - const operations = []; - - toolResources.forEach((tool) => { - const fileIds = Array.from({ length: 5 }, () => uuidv4()); - fileIds.forEach((fileId) => { - operations.push( + // Add all files first + await Promise.all( + fileIds.map((fileId) => addAgentResourceFile({ agent_id: agent.id, - tool_resource: tool, + tool_resource: 'test_tool', file_id: fileId, }), - ); + ), + ); + + // Concurrently remove all files + const removalPromises = fileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(removalPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + // Check if the array is empty or the tool resource itself is removed + const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; + expect(finalFileIds).toHaveLength(0); + }); + + describe('Edge Cases', () => { + describe.each([ + { + operation: 'add', + name: 'empty file_id', + needsAgent: true, + params: { tool_resource: 'file_search', file_id: '' }, + shouldResolve: true, + }, + { + operation: 'add', + name: 'non-existent agent', + needsAgent: false, + params: { tool_resource: 'file_search', file_id: 'file123' }, + shouldResolve: false, + error: 'Agent not found for adding resource file', + }, + ])('addAgentResourceFile with $name', ({ needsAgent, params, shouldResolve, error }) => { + test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => { + const agent = needsAgent ? await createBasicAgent() : null; + const agent_id = needsAgent ? 
agent.id : `agent_${uuidv4()}`; + + if (shouldResolve) { + await expect(addAgentResourceFile({ agent_id, ...params })).resolves.toBeDefined(); + } else { + await expect(addAgentResourceFile({ agent_id, ...params })).rejects.toThrow(error); + } + }); + }); + + describe.each([ + { + name: 'empty files array', + files: [], + needsAgent: true, + shouldResolve: true, + }, + { + name: 'non-existent tool_resource', + files: [{ tool_resource: 'non_existent_tool', file_id: 'file123' }], + needsAgent: true, + shouldResolve: true, + }, + { + name: 'non-existent agent', + files: [{ tool_resource: 'file_search', file_id: 'file123' }], + needsAgent: false, + shouldResolve: false, + error: 'Agent not found for removing resource files', + }, + ])('removeAgentResourceFiles with $name', ({ files, needsAgent, shouldResolve, error }) => { + test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => { + const agent = needsAgent ? await createBasicAgent() : null; + const agent_id = needsAgent ? agent.id : `agent_${uuidv4()}`; + + if (shouldResolve) { + const result = await removeAgentResourceFiles({ agent_id, files }); + expect(result).toBeDefined(); + if (agent) { + expect(result.id).toBe(agent.id); + } + } else { + await expect(removeAgentResourceFiles({ agent_id, files })).rejects.toThrow(error); + } + }); }); }); - - await Promise.all(operations); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - toolResources.forEach((tool) => { - expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); - expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); - }); }); - test('should handle concurrent duplicate additions', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); + describe('Agent CRUD Operations', () => { + let mongoServer; - // Concurrent additions of the same file - const additionPromises = Array.from({ length: 5 }).map(() => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - 
file_id: fileId, - }), - ); - - await Promise.all(additionPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - // Should only contain one instance of the fileId - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(1); - expect(updatedAgent.tool_resources.test_tool.file_ids[0]).toBe(fileId); - }); - - test('should handle concurrent add and remove of the same file', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - // First, ensure the file exists (or test might be trivial if remove runs first) - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); }); - // Concurrent add (which should be ignored) and remove - const operations = [ - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ]; - - await Promise.all(operations); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // The final state should ideally be that the file is removed, - // but the key point is consistency (not duplicated or error state). - // Depending on execution order, the file might remain if the add operation's - // findOneAndUpdate runs after the remove operation completes. - // A more robust check might be that the length is <= 1. - // Given the remove uses an update pipeline, it might be more likely to win. - // The final state depends on race condition timing (add or remove might "win"). - // The critical part is that the state is consistent (no duplicates, no errors). - // Assert that the fileId is either present exactly once or not present at all. 
- expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids; - const count = finalFileIds.filter((id) => id === fileId).length; - expect(count).toBeLessThanOrEqual(1); // Should be 0 or 1, never more - // Optional: Check overall length is consistent with the count - if (count === 0) { - expect(finalFileIds).toHaveLength(0); - } else { - expect(finalFileIds).toHaveLength(1); - expect(finalFileIds[0]).toBe(fileId); - } - }); - - test('should handle concurrent duplicate removals', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - // Add the file first - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); }); - // Concurrent removals of the same file - const removalPromises = Array.from({ length: 5 }).map(() => - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ); - - await Promise.all(removalPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // Check if the array is empty or the tool resource itself is removed - const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? 
[]; - expect(fileIds).toHaveLength(0); - expect(fileIds).not.toContain(fileId); - }); - - test('should handle concurrent removals of different files', async () => { - const agent = await createBasicAgent(); - const fileIds = Array.from({ length: 10 }, () => uuidv4()); - - // Add all files first - await Promise.all( - fileIds.map((fileId) => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - ), - ); - - // Concurrently remove all files - const removalPromises = fileIds.map((fileId) => - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ); - - await Promise.all(removalPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // Check if the array is empty or the tool resource itself is removed - const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; - expect(finalFileIds).toHaveLength(0); - }); -}); - -describe('Agent CRUD Operations', () => { - let mongoServer; - - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); - - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - test('should create and get an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - - const newAgent = await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: authorId, - description: 'Test description', + beforeEach(async () => { + await Agent.deleteMany({}); }); - expect(newAgent).toBeDefined(); - expect(newAgent.id).toBe(agentId); - expect(newAgent.name).toBe('Test Agent'); + test('should create and get an agent', async () => { + const { agentId, authorId } = createTestIds(); - const retrievedAgent = await getAgent({ id: 
agentId }); - expect(retrievedAgent).toBeDefined(); - expect(retrievedAgent.id).toBe(agentId); - expect(retrievedAgent.name).toBe('Test Agent'); - expect(retrievedAgent.description).toBe('Test description'); - }); + const newAgent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Test description', + }); - test('should delete an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); + expect(newAgent).toBeDefined(); + expect(newAgent.id).toBe(agentId); + expect(newAgent.name).toBe('Test Agent'); - await createAgent({ - id: agentId, - name: 'Agent To Delete', - provider: 'test', - model: 'test-model', - author: authorId, + const retrievedAgent = await getAgent({ id: agentId }); + expect(retrievedAgent).toBeDefined(); + expect(retrievedAgent.id).toBe(agentId); + expect(retrievedAgent.name).toBe('Test Agent'); + expect(retrievedAgent.description).toBe('Test description'); }); - const agentBeforeDelete = await getAgent({ id: agentId }); - expect(agentBeforeDelete).toBeDefined(); + test('should delete an agent', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - await deleteAgent({ id: agentId }); - - const agentAfterDelete = await getAgent({ id: agentId }); - expect(agentAfterDelete).toBeNull(); - }); - - test('should list agents by author', async () => { - const authorId = new mongoose.Types.ObjectId(); - const otherAuthorId = new mongoose.Types.ObjectId(); - - const agentIds = []; - for (let i = 0; i < 5; i++) { - const id = `agent_${uuidv4()}`; - agentIds.push(id); await createAgent({ - id, - name: `Agent ${i}`, + id: agentId, + name: 'Agent To Delete', provider: 'test', model: 'test-model', author: authorId, }); - } - for (let i = 0; i < 3; i++) { + const agentBeforeDelete = await getAgent({ id: agentId }); + expect(agentBeforeDelete).toBeDefined(); + + await 
deleteAgent({ id: agentId }); + + const agentAfterDelete = await getAgent({ id: agentId }); + expect(agentAfterDelete).toBeNull(); + }); + + test('should list agents by author', async () => { + const authorId = new mongoose.Types.ObjectId(); + const otherAuthorId = new mongoose.Types.ObjectId(); + + const agentIds = []; + for (let i = 0; i < 5; i++) { + const id = `agent_${uuidv4()}`; + agentIds.push(id); + await createAgent({ + id, + name: `Agent ${i}`, + provider: 'test', + model: 'test-model', + author: authorId, + }); + } + + for (let i = 0; i < 3; i++) { + await createAgent({ + id: `other_agent_${uuidv4()}`, + name: `Other Agent ${i}`, + provider: 'test', + model: 'test-model', + author: otherAuthorId, + }); + } + + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toBeDefined(); + expect(result.data).toHaveLength(5); + expect(result.has_more).toBe(true); + + for (const agent of result.data) { + expect(agent.author).toBe(authorId.toString()); + } + }); + + test('should update agent projects', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + const projectId3 = new mongoose.Types.ObjectId(); + await createAgent({ - id: `other_agent_${uuidv4()}`, - name: `Other Agent ${i}`, + id: agentId, + name: 'Project Test Agent', provider: 'test', model: 'test-model', - author: otherAuthorId, + author: authorId, + projectIds: [projectId1], }); - } - const result = await getListAgents({ author: authorId.toString() }); + await updateAgent( + { id: agentId }, + { $addToSet: { projectIds: { $each: [projectId2, projectId3] } } }, + ); - expect(result).toBeDefined(); - expect(result.data).toBeDefined(); - expect(result.data).toHaveLength(5); - expect(result.has_more).toBe(true); + await updateAgent({ id: agentId }, { $pull: { projectIds: projectId1 } 
}); - for (const agent of result.data) { - expect(agent.author).toBe(authorId.toString()); - } - }); + await updateAgent({ id: agentId }, { projectIds: [projectId2, projectId3] }); - test('should update agent projects', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const projectId1 = new mongoose.Types.ObjectId(); - const projectId2 = new mongoose.Types.ObjectId(); - const projectId3 = new mongoose.Types.ObjectId(); + const updatedAgent = await getAgent({ id: agentId }); + expect(updatedAgent.projectIds).toHaveLength(2); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId3.toString()); + expect(updatedAgent.projectIds.map((id) => id.toString())).not.toContain( + projectId1.toString(), + ); - await createAgent({ - id: agentId, - name: 'Project Test Agent', - provider: 'test', - model: 'test-model', - author: authorId, - projectIds: [projectId1], + await updateAgent({ id: agentId }, { projectIds: [] }); + + const emptyProjectsAgent = await getAgent({ id: agentId }); + expect(emptyProjectsAgent.projectIds).toHaveLength(0); + + const nonExistentId = `agent_${uuidv4()}`; + await expect( + updateAgentProjects({ + id: nonExistentId, + projectIds: [projectId1], + }), + ).rejects.toThrow(); }); - await updateAgent( - { id: agentId }, - { $addToSet: { projectIds: { $each: [projectId2, projectId3] } } }, - ); + test('should handle ephemeral agent loading', async () => { + const agentId = 'ephemeral_test'; + const endpoint = 'openai'; - await updateAgent({ id: agentId }, { $pull: { projectIds: projectId1 } }); + const originalModule = jest.requireActual('librechat-data-provider'); - await updateAgent({ id: agentId }, { projectIds: [projectId2, projectId3] }); + const mockDataProvider = { + ...originalModule, + Constants: { + ...originalModule.Constants, + EPHEMERAL_AGENT_ID: 'ephemeral_test', + }, 
+ }; - const updatedAgent = await getAgent({ id: agentId }); - expect(updatedAgent.projectIds).toHaveLength(2); - expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); - expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId3.toString()); - expect(updatedAgent.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + jest.doMock('librechat-data-provider', () => mockDataProvider); - await updateAgent({ id: agentId }, { projectIds: [] }); + expect(agentId).toBeDefined(); + expect(endpoint).toBeDefined(); - const emptyProjectsAgent = await getAgent({ id: agentId }); - expect(emptyProjectsAgent.projectIds).toHaveLength(0); + jest.dontMock('librechat-data-provider'); + }); - const nonExistentId = `agent_${uuidv4()}`; - await expect( - updateAgentProjects({ - id: nonExistentId, - projectIds: [projectId1], - }), - ).rejects.toThrow(); + test('should handle loadAgent functionality and errors', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Load Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1', 'tool2'], + }); + + const agent = await getAgent({ id: agentId }); + + expect(agent).toBeDefined(); + expect(agent.id).toBe(agentId); + expect(agent.name).toBe('Test Load Agent'); + expect(agent.tools).toEqual(expect.arrayContaining(['tool1', 'tool2'])); + + const mockLoadAgent = jest.fn().mockResolvedValue(agent); + const loadedAgent = await mockLoadAgent(); + expect(loadedAgent).toBeDefined(); + expect(loadedAgent.id).toBe(agentId); + + const nonExistentId = `agent_${uuidv4()}`; + const nonExistentAgent = await getAgent({ id: nonExistentId }); + expect(nonExistentAgent).toBeNull(); + + const mockLoadAgentError = jest.fn().mockRejectedValue(new Error('No agent found with ID')); + await expect(mockLoadAgentError()).rejects.toThrow('No agent found with 
ID'); + }); + + describe('Edge Cases', () => { + test.each([ + { + name: 'getAgent with undefined search parameters', + fn: () => getAgent(undefined), + expected: null, + }, + { + name: 'deleteAgent with non-existent agent', + fn: () => deleteAgent({ id: 'non-existent' }), + expected: null, + }, + ])('$name should return null', async ({ fn, expected }) => { + const result = await fn(); + expect(result).toBe(expected); + }); + + test('should handle getListAgents with invalid author format', async () => { + try { + const result = await getListAgents({ author: 'invalid-object-id' }); + expect(result.data).toEqual([]); + } catch (error) { + expect(error).toBeDefined(); + } + }); + + test('should handle getListAgents with no agents', async () => { + const authorId = new mongoose.Types.ObjectId(); + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toEqual([]); + expect(result.has_more).toBe(false); + expect(result.first_id).toBeNull(); + expect(result.last_id).toBeNull(); + }); + + test('should handle updateAgentProjects with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const userId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + const result = await updateAgentProjects({ + user: { id: userId.toString() }, + agentId: nonExistentId, + projectIds: [projectId.toString()], + }); + + expect(result).toBeNull(); + }); + }); }); - test('should handle ephemeral agent loading', async () => { - const agentId = 'ephemeral_test'; - const endpoint = 'openai'; + describe('Agent Version History', () => { + let mongoServer; - const originalModule = jest.requireActual('librechat-data-provider'); + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); - const 
mockDataProvider = { - ...originalModule, - Constants: { - ...originalModule.Constants, - EPHEMERAL_AGENT_ID: 'ephemeral_test', - }, - }; + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); - jest.doMock('librechat-data-provider', () => mockDataProvider); + beforeEach(async () => { + await Agent.deleteMany({}); + }); - const mockReq = { - user: { id: 'user123' }, - body: { - promptPrefix: 'This is a test instruction', - ephemeralAgent: { - execute_code: true, - mcp: ['server1', 'server2'], + test('should create an agent with a single entry in versions array', async () => { + const agent = await createBasicAgent(); + + expect(agent.versions).toBeDefined(); + expect(Array.isArray(agent.versions)).toBe(true); + expect(agent.versions).toHaveLength(1); + expect(agent.versions[0].name).toBe('Test Agent'); + expect(agent.versions[0].provider).toBe('test'); + expect(agent.versions[0].model).toBe('test-model'); + }); + + test('should accumulate version history across multiple updates', async () => { + const agentId = `agent_${uuidv4()}`; + const author = new mongoose.Types.ObjectId(); + await createAgent({ + id: agentId, + name: 'First Name', + provider: 'test', + model: 'test-model', + author, + description: 'First description', + }); + + await updateAgent( + { id: agentId }, + { name: 'Second Name', description: 'Second description' }, + ); + await updateAgent({ id: agentId }, { name: 'Third Name', model: 'new-model' }); + const finalAgent = await updateAgent({ id: agentId }, { description: 'Final description' }); + + expect(finalAgent.versions).toBeDefined(); + expect(Array.isArray(finalAgent.versions)).toBe(true); + expect(finalAgent.versions).toHaveLength(4); + + expect(finalAgent.versions[0].name).toBe('First Name'); + expect(finalAgent.versions[0].description).toBe('First description'); + expect(finalAgent.versions[0].model).toBe('test-model'); + + expect(finalAgent.versions[1].name).toBe('Second Name'); + 
expect(finalAgent.versions[1].description).toBe('Second description'); + expect(finalAgent.versions[1].model).toBe('test-model'); + + expect(finalAgent.versions[2].name).toBe('Third Name'); + expect(finalAgent.versions[2].description).toBe('Second description'); + expect(finalAgent.versions[2].model).toBe('new-model'); + + expect(finalAgent.versions[3].name).toBe('Third Name'); + expect(finalAgent.versions[3].description).toBe('Final description'); + expect(finalAgent.versions[3].model).toBe('new-model'); + + expect(finalAgent.name).toBe('Third Name'); + expect(finalAgent.description).toBe('Final description'); + expect(finalAgent.model).toBe('new-model'); + }); + + test('should not include metadata fields in version history', async () => { + const agentId = `agent_${uuidv4()}`; + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + + const updatedAgent = await updateAgent({ id: agentId }, { description: 'New description' }); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[0]._id).toBeUndefined(); + expect(updatedAgent.versions[0].__v).toBeUndefined(); + expect(updatedAgent.versions[0].name).toBe('Test Agent'); + expect(updatedAgent.versions[0].author).toBeUndefined(); + + expect(updatedAgent.versions[1]._id).toBeUndefined(); + expect(updatedAgent.versions[1].__v).toBeUndefined(); + }); + + test('should not recursively include previous versions', async () => { + const agentId = `agent_${uuidv4()}`; + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + + await updateAgent({ id: agentId }, { name: 'Updated Name 1' }); + await updateAgent({ id: agentId }, { name: 'Updated Name 2' }); + const finalAgent = await updateAgent({ id: agentId }, { name: 'Updated Name 3' }); + + expect(finalAgent.versions).toHaveLength(4); + + 
finalAgent.versions.forEach((version) => { + expect(version.versions).toBeUndefined(); + }); + }); + + test('should handle MongoDB operators and field updates correctly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'MongoDB Operator Test', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + }); + + await updateAgent( + { id: agentId }, + { + description: 'Updated description', + $push: { tools: 'tool2' }, + $addToSet: { projectIds: projectId }, }, - }, - app: { - locals: { - availableTools: { - tool__server1: {}, - tool__server2: {}, - another_tool: {}, + ); + + const firstUpdate = await getAgent({ id: agentId }); + expect(firstUpdate.description).toBe('Updated description'); + expect(firstUpdate.tools).toContain('tool1'); + expect(firstUpdate.tools).toContain('tool2'); + expect(firstUpdate.projectIds.map((id) => id.toString())).toContain(projectId.toString()); + expect(firstUpdate.versions).toHaveLength(2); + + await updateAgent( + { id: agentId }, + { + tools: ['tool2', 'tool3'], + }, + ); + + const secondUpdate = await getAgent({ id: agentId }); + expect(secondUpdate.tools).toHaveLength(2); + expect(secondUpdate.tools).toContain('tool2'); + expect(secondUpdate.tools).toContain('tool3'); + expect(secondUpdate.tools).not.toContain('tool1'); + expect(secondUpdate.versions).toHaveLength(3); + + await updateAgent( + { id: agentId }, + { + $push: { tools: 'tool3' }, + }, + ); + + const thirdUpdate = await getAgent({ id: agentId }); + const toolCount = thirdUpdate.tools.filter((t) => t === 'tool3').length; + expect(toolCount).toBe(2); + expect(thirdUpdate.versions).toHaveLength(4); + }); + + test('should handle parameter objects correctly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + 
name: 'Parameters Test', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: { temperature: 0.7 }, + }); + + const updatedAgent = await updateAgent( + { id: agentId }, + { model_parameters: { temperature: 0.8 } }, + ); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.model_parameters.temperature).toBe(0.8); + + await updateAgent( + { id: agentId }, + { + model_parameters: { + temperature: 0.8, + max_tokens: 1000, }, }, - }, - }; + ); - const params = { - req: mockReq, - agent_id: agentId, - endpoint, - model_parameters: { - model: 'gpt-4', - temperature: 0.7, - }, - }; + const complexAgent = await getAgent({ id: agentId }); + expect(complexAgent.versions).toHaveLength(3); + expect(complexAgent.model_parameters.temperature).toBe(0.8); + expect(complexAgent.model_parameters.max_tokens).toBe(1000); - expect(agentId).toBeDefined(); - expect(endpoint).toBeDefined(); + await updateAgent({ id: agentId }, { model_parameters: {} }); - jest.dontMock('librechat-data-provider'); - }); - - test('should handle loadAgent functionality and errors', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Test Load Agent', - provider: 'test', - model: 'test-model', - author: authorId, - tools: ['tool1', 'tool2'], + const emptyParamsAgent = await getAgent({ id: agentId }); + expect(emptyParamsAgent.versions).toHaveLength(4); + expect(emptyParamsAgent.model_parameters).toEqual({}); }); - const agent = await getAgent({ id: agentId }); + test('should detect duplicate versions and reject updates', async () => { + const originalConsoleError = console.error; + console.error = jest.fn(); - expect(agent).toBeDefined(); - expect(agent.id).toBe(agentId); - expect(agent.name).toBe('Test Load Agent'); - expect(agent.tools).toEqual(expect.arrayContaining(['tool1', 'tool2'])); + try { + const authorId = new mongoose.Types.ObjectId(); + const 
testCases = generateVersionTestCases(); - const mockLoadAgent = jest.fn().mockResolvedValue(agent); - const loadedAgent = await mockLoadAgent(); - expect(loadedAgent).toBeDefined(); - expect(loadedAgent.id).toBe(agentId); + for (const testCase of testCases) { + const testAgentId = `agent_${uuidv4()}`; - const nonExistentId = `agent_${uuidv4()}`; - const nonExistentAgent = await getAgent({ id: nonExistentId }); - expect(nonExistentAgent).toBeNull(); + await createAgent({ + id: testAgentId, + provider: 'test', + model: 'test-model', + author: authorId, + ...testCase.initial, + }); - const mockLoadAgentError = jest.fn().mockRejectedValue(new Error('No agent found with ID')); - await expect(mockLoadAgentError()).rejects.toThrow('No agent found with ID'); - }); -}); + await updateAgent({ id: testAgentId }, testCase.update); -describe('Agent Version History', () => { - let mongoServer; + let error; + try { + await updateAgent({ id: testAgentId }, testCase.duplicate); + } catch (e) { + error = e; + } - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); + expect(error).toBeDefined(); + expect(error.message).toContain('Duplicate version'); + expect(error.statusCode).toBe(409); + expect(error.details).toBeDefined(); + expect(error.details.duplicateVersion).toBeDefined(); - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - test('should create an agent with a single entry in versions array', async () => { - const agentId = `agent_${uuidv4()}`; - const agent = await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + const agent = await getAgent({ id: testAgentId }); + expect(agent.versions).toHaveLength(2); + } + } finally { + console.error = originalConsoleError; + } }); - 
expect(agent.versions).toBeDefined(); - expect(Array.isArray(agent.versions)).toBe(true); - expect(agent.versions).toHaveLength(1); - expect(agent.versions[0].name).toBe('Test Agent'); - expect(agent.versions[0].provider).toBe('test'); - expect(agent.versions[0].model).toBe('test-model'); - }); + test('should track updatedBy when a different user updates an agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const updatingUser = new mongoose.Types.ObjectId(); - test('should accumulate version history across multiple updates', async () => { - const agentId = `agent_${uuidv4()}`; - const author = new mongoose.Types.ObjectId(); - await createAgent({ - id: agentId, - name: 'First Name', - provider: 'test', - model: 'test-model', - author, - description: 'First description', + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); + + const updatedAgent = await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: updatingUser.toString() }, + ); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[1].updatedBy.toString()).toBe(updatingUser.toString()); + expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); }); - await updateAgent({ id: agentId }, { name: 'Second Name', description: 'Second description' }); - await updateAgent({ id: agentId }, { name: 'Third Name', model: 'new-model' }); - const finalAgent = await updateAgent({ id: agentId }, { description: 'Final description' }); + test('should include updatedBy even when the original author updates the agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); - expect(finalAgent.versions).toBeDefined(); - expect(Array.isArray(finalAgent.versions)).toBe(true); - 
expect(finalAgent.versions).toHaveLength(4); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - expect(finalAgent.versions[0].name).toBe('First Name'); - expect(finalAgent.versions[0].description).toBe('First description'); - expect(finalAgent.versions[0].model).toBe('test-model'); + const updatedAgent = await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: originalAuthor.toString() }, + ); - expect(finalAgent.versions[1].name).toBe('Second Name'); - expect(finalAgent.versions[1].description).toBe('Second description'); - expect(finalAgent.versions[1].model).toBe('test-model'); - - expect(finalAgent.versions[2].name).toBe('Third Name'); - expect(finalAgent.versions[2].description).toBe('Second description'); - expect(finalAgent.versions[2].model).toBe('new-model'); - - expect(finalAgent.versions[3].name).toBe('Third Name'); - expect(finalAgent.versions[3].description).toBe('Final description'); - expect(finalAgent.versions[3].model).toBe('new-model'); - - expect(finalAgent.name).toBe('Third Name'); - expect(finalAgent.description).toBe('Final description'); - expect(finalAgent.model).toBe('new-model'); - }); - - test('should not include metadata fields in version history', async () => { - const agentId = `agent_${uuidv4()}`; - await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[1].updatedBy.toString()).toBe(originalAuthor.toString()); + expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); }); - const updatedAgent = await updateAgent({ id: agentId }, { description: 'New description' }); + test('should track multiple different users updating the same agent', async () => { + const agentId = 
`agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const user1 = new mongoose.Types.ObjectId(); + const user2 = new mongoose.Types.ObjectId(); + const user3 = new mongoose.Types.ObjectId(); - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[0]._id).toBeUndefined(); - expect(updatedAgent.versions[0].__v).toBeUndefined(); - expect(updatedAgent.versions[0].name).toBe('Test Agent'); - expect(updatedAgent.versions[0].author).toBeUndefined(); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - expect(updatedAgent.versions[1]._id).toBeUndefined(); - expect(updatedAgent.versions[1].__v).toBeUndefined(); - }); + // User 1 makes an update + await updateAgent( + { id: agentId }, + { name: 'Updated by User 1', description: 'First update' }, + { updatingUserId: user1.toString() }, + ); - test('should not recursively include previous versions', async () => { - const agentId = `agent_${uuidv4()}`; - await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + // Original author makes an update + await updateAgent( + { id: agentId }, + { description: 'Updated by original author' }, + { updatingUserId: originalAuthor.toString() }, + ); + + // User 2 makes an update + await updateAgent( + { id: agentId }, + { name: 'Updated by User 2', model: 'new-model' }, + { updatingUserId: user2.toString() }, + ); + + // User 3 makes an update + const finalAgent = await updateAgent( + { id: agentId }, + { description: 'Final update by User 3' }, + { updatingUserId: user3.toString() }, + ); + + expect(finalAgent.versions).toHaveLength(5); + expect(finalAgent.author.toString()).toBe(originalAuthor.toString()); + + // Check that each version has the correct updatedBy + expect(finalAgent.versions[0].updatedBy).toBeUndefined(); // Initial 
creation has no updatedBy + expect(finalAgent.versions[1].updatedBy.toString()).toBe(user1.toString()); + expect(finalAgent.versions[2].updatedBy.toString()).toBe(originalAuthor.toString()); + expect(finalAgent.versions[3].updatedBy.toString()).toBe(user2.toString()); + expect(finalAgent.versions[4].updatedBy.toString()).toBe(user3.toString()); + + // Verify the final state + expect(finalAgent.name).toBe('Updated by User 2'); + expect(finalAgent.description).toBe('Final update by User 3'); + expect(finalAgent.model).toBe('new-model'); }); - await updateAgent({ id: agentId }, { name: 'Updated Name 1' }); - await updateAgent({ id: agentId }, { name: 'Updated Name 2' }); - const finalAgent = await updateAgent({ id: agentId }, { name: 'Updated Name 3' }); + test('should preserve original author during agent restoration', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const updatingUser = new mongoose.Types.ObjectId(); - expect(finalAgent.versions).toHaveLength(4); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - finalAgent.versions.forEach((version) => { - expect(version.versions).toBeUndefined(); - }); - }); + await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: updatingUser.toString() }, + ); - test('should handle MongoDB operators and field updates correctly', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const projectId = new mongoose.Types.ObjectId(); + const { revertAgentVersion } = require('./Agent'); + const revertedAgent = await revertAgentVersion({ id: agentId }, 0); - await createAgent({ - id: agentId, - name: 'MongoDB Operator Test', - provider: 'test', - model: 'test-model', - author: authorId, - tools: ['tool1'], + 
expect(revertedAgent.author.toString()).toBe(originalAuthor.toString()); + expect(revertedAgent.name).toBe('Original Agent'); + expect(revertedAgent.description).toBe('Original description'); }); - await updateAgent( - { id: agentId }, - { - description: 'Updated description', - $push: { tools: 'tool2' }, - $addToSet: { projectIds: projectId }, - }, - ); + test('should detect action metadata changes and force version update', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const actionId = 'testActionId123'; - const firstUpdate = await getAgent({ id: agentId }); - expect(firstUpdate.description).toBe('Updated description'); - expect(firstUpdate.tools).toContain('tool1'); - expect(firstUpdate.tools).toContain('tool2'); - expect(firstUpdate.projectIds.map((id) => id.toString())).toContain(projectId.toString()); - expect(firstUpdate.versions).toHaveLength(2); + // Create agent with actions + await createAgent({ + id: agentId, + name: 'Agent with Actions', + provider: 'test', + model: 'test-model', + author: authorId, + actions: [`test.com_action_${actionId}`], + tools: ['listEvents_action_test.com', 'createEvent_action_test.com'], + }); - await updateAgent( - { id: agentId }, - { - tools: ['tool2', 'tool3'], - }, - ); + // First update with forceVersion should create a version + const firstUpdate = await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: true }, + ); - const secondUpdate = await getAgent({ id: agentId }); - expect(secondUpdate.tools).toHaveLength(2); - expect(secondUpdate.tools).toContain('tool2'); - expect(secondUpdate.tools).toContain('tool3'); - expect(secondUpdate.tools).not.toContain('tool1'); - expect(secondUpdate.versions).toHaveLength(3); + expect(firstUpdate.versions).toHaveLength(2); - await updateAgent( - { id: agentId }, - { - $push: { tools: 'tool3' }, - }, - ); + // 
Second update with same data but forceVersion should still create a version + const secondUpdate = await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: true }, + ); - const thirdUpdate = await getAgent({ id: agentId }); - const toolCount = thirdUpdate.tools.filter((t) => t === 'tool3').length; - expect(toolCount).toBe(2); - expect(thirdUpdate.versions).toHaveLength(4); - }); + expect(secondUpdate.versions).toHaveLength(3); - test('should handle parameter objects correctly', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); + // Update without forceVersion and no changes should not create a version + let error; + try { + await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: false }, + ); + } catch (e) { + error = e; + } - await createAgent({ - id: agentId, - name: 'Parameters Test', - provider: 'test', - model: 'test-model', - author: authorId, - model_parameters: { temperature: 0.7 }, + expect(error).toBeDefined(); + expect(error.message).toContain('Duplicate version'); + expect(error.statusCode).toBe(409); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { model_parameters: { temperature: 0.8 } }, - ); + test('should handle isDuplicateVersion with arrays containing null/undefined values', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.model_parameters.temperature).toBe(0.8); + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1', null, 'tool2', undefined], + }); - await updateAgent( - { id: agentId }, - { - model_parameters: { - temperature: 0.8, - max_tokens: 
1000, + // Update with same array but different null/undefined arrangement + const updatedAgent = await updateAgent({ id: agentId }, { tools: ['tool1', 'tool2'] }); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.tools).toEqual(['tool1', 'tool2']); + }); + + test('should handle isDuplicateVersion with empty objects in tool_kwargs', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tool_kwargs: [ + { tool: 'tool1', config: { setting: 'value' } }, + {}, + { tool: 'tool2', config: {} }, + ], + }); + + // Try to update with reordered but equivalent tool_kwargs + const updatedAgent = await updateAgent( + { id: agentId }, + { + tool_kwargs: [ + { tool: 'tool2', config: {} }, + { tool: 'tool1', config: { setting: 'value' } }, + {}, + ], }, - }, - ); + ); - const complexAgent = await getAgent({ id: agentId }); - expect(complexAgent.versions).toHaveLength(3); - expect(complexAgent.model_parameters.temperature).toBe(0.8); - expect(complexAgent.model_parameters.max_tokens).toBe(1000); + // Should create new version as order matters for arrays + expect(updatedAgent.versions).toHaveLength(2); + }); - await updateAgent({ id: agentId }, { model_parameters: {} }); + test('should handle isDuplicateVersion with mixed primitive and object arrays', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - const emptyParamsAgent = await getAgent({ id: agentId }); - expect(emptyParamsAgent.versions).toHaveLength(4); - expect(emptyParamsAgent.model_parameters).toEqual({}); + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + mixed_array: [1, 'string', { key: 'value' }, true, null], + }); + + // Update with same values but different types + const updatedAgent = await 
updateAgent( + { id: agentId }, + { mixed_array: ['1', 'string', { key: 'value' }, 'true', null] }, + ); + + // Should create new version as types differ + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle isDuplicateVersion with deeply nested objects', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const deepObject = { + level1: { + level2: { + level3: { + level4: { + value: 'deep', + array: [1, 2, { nested: true }], + }, + }, + }, + }, + }; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: deepObject, + }); + + // First create a version with changes + await updateAgent({ id: agentId }, { description: 'Updated' }); + + // Then try to create duplicate of the original version + await updateAgent( + { id: agentId }, + { + model_parameters: deepObject, + description: undefined, + }, + ); + + // Since we're updating back to the same model_parameters but with a different description, + // it should create a new version + const agent = await getAgent({ id: agentId }); + expect(agent.versions).toHaveLength(3); + }); + + test('should handle version comparison with special field types', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + projectIds: [projectId], + model_parameters: { temperature: 0.7 }, + }); + + // Update with a real field change first + const firstUpdate = await updateAgent({ id: agentId }, { description: 'New description' }); + + expect(firstUpdate.versions).toHaveLength(2); + + // Update with model parameters change + const secondUpdate = await updateAgent( + { id: agentId }, + { model_parameters: { temperature: 0.8 } }, + ); + + 
expect(secondUpdate.versions).toHaveLength(3); + }); + + describe('Edge Cases', () => { + test('should handle extremely large version history', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Version Test', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + for (let i = 0; i < 20; i++) { + await updateAgent({ id: agentId }, { description: `Version ${i}` }); + } + + const agent = await getAgent({ id: agentId }); + expect(agent.versions).toHaveLength(21); + expect(agent.description).toBe('Version 19'); + }); + + test('should handle revertAgentVersion with invalid version index', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await expect(revertAgentVersion({ id: agentId }, 5)).rejects.toThrow('Version 5 not found'); + }); + + test('should handle revertAgentVersion with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect(revertAgentVersion({ id: nonExistentId }, 0)).rejects.toThrow( + 'Agent not found', + ); + }); + + test('should handle updateAgent with empty update object', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + const updatedAgent = await updateAgent({ id: agentId }, {}); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Test Agent'); + expect(updatedAgent.versions).toHaveLength(1); + }); + }); }); - test('should detect duplicate versions and reject updates', async () => { - const originalConsoleError = console.error; - console.error = jest.fn(); + describe('Action Metadata and Hash Generation', 
() => { + let mongoServer; - try { + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should generate consistent hash for same action metadata', async () => { + const actionIds = ['test.com_action_123', 'example.com_action_456']; + const actions = [ + { + action_id: '123', + metadata: { version: '1.0', endpoints: ['GET /api/test'], schema: { type: 'object' } }, + }, + { + action_id: '456', + metadata: { + version: '2.0', + endpoints: ['POST /api/example'], + schema: { type: 'string' }, + }, + }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds, actions); + const hash2 = await generateActionMetadataHash(actionIds, actions); + + expect(hash1).toBe(hash2); + expect(typeof hash1).toBe('string'); + expect(hash1.length).toBe(64); // SHA-256 produces 64 character hex string + }); + + test('should generate different hashes for different action metadata', async () => { + const actionIds = ['test.com_action_123']; + const actions1 = [ + { action_id: '123', metadata: { version: '1.0', endpoints: ['GET /api/test'] } }, + ]; + const actions2 = [ + { action_id: '123', metadata: { version: '2.0', endpoints: ['GET /api/test'] } }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds, actions1); + const hash2 = await generateActionMetadataHash(actionIds, actions2); + + expect(hash1).not.toBe(hash2); + }); + + test('should handle empty action arrays', async () => { + const hash = await generateActionMetadataHash([], []); + expect(hash).toBe(''); + }); + + test('should handle null or undefined action arrays', async () => { + const hash1 = await generateActionMetadataHash(null, []); + const hash2 = await 
generateActionMetadataHash(undefined, []); + + expect(hash1).toBe(''); + expect(hash2).toBe(''); + }); + + test('should handle missing action metadata gracefully', async () => { + const actionIds = ['test.com_action_123', 'missing.com_action_999']; + const actions = [ + { action_id: '123', metadata: { version: '1.0' } }, + // missing action with id '999' + ]; + + const hash = await generateActionMetadataHash(actionIds, actions); + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + test('should sort action IDs for consistent hashing', async () => { + const actionIds1 = ['b.com_action_2', 'a.com_action_1']; + const actionIds2 = ['a.com_action_1', 'b.com_action_2']; + const actions = [ + { action_id: '1', metadata: { version: '1.0' } }, + { action_id: '2', metadata: { version: '2.0' } }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds1, actions); + const hash2 = await generateActionMetadataHash(actionIds2, actions); + + expect(hash1).toBe(hash2); + }); + + test('should handle complex nested metadata objects', async () => { + const actionIds = ['complex.com_action_1']; + const actions = [ + { + action_id: '1', + metadata: { + version: '1.0', + schema: { + type: 'object', + properties: { + name: { type: 'string' }, + nested: { + type: 'object', + properties: { + id: { type: 'number' }, + tags: { type: 'array', items: { type: 'string' } }, + }, + }, + }, + }, + endpoints: [ + { path: '/api/test', method: 'GET', params: ['id'] }, + { path: '/api/create', method: 'POST', body: true }, + ], + }, + }, + ]; + + const hash = await generateActionMetadataHash(actionIds, actions); + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + describe('Edge Cases', () => { + test('should handle generateActionMetadataHash with null metadata', async () => { + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: null }], + ); + expect(typeof hash).toBe('string'); + }); + + 
test('should handle generateActionMetadataHash with deeply nested metadata', async () => { + const deepMetadata = { + level1: { + level2: { + level3: { + level4: { + level5: 'deep value', + array: [1, 2, { nested: true }], + }, + }, + }, + }, + }; + + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: deepMetadata }], + ); + + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + test('should handle generateActionMetadataHash with special characters', async () => { + const specialMetadata = { + unicode: '🚀🎉👍', + symbols: '!@#$%^&*()_+-=[]{}|;:,.<>?', + quotes: 'single\'s and "doubles"', + newlines: 'line1\nline2\r\nline3', + }; + + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: specialMetadata }], + ); + + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + }); + }); + + describe('Load Agent Functionality', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should return null when agent_id is not provided', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: null, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should return null when agent_id is empty string', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: '', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should 
test ephemeral agent loading logic', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test instructions', + ephemeralAgent: { + execute_code: true, + web_search: true, + mcp: ['server1', 'server2'], + }, + }, + app: { + locals: { + availableTools: { + tool1_mcp_server1: {}, + tool2_mcp_server2: {}, + another_tool: {}, + }, + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4', temperature: 0.7 }, + }); + + if (result) { + expect(result.id).toBe(EPHEMERAL_AGENT_ID); + expect(result.instructions).toBe('Test instructions'); + expect(result.provider).toBe('openai'); + expect(result.model).toBe('gpt-4'); + expect(result.model_parameters.temperature).toBe(0.7); + expect(result.tools).toContain('execute_code'); + expect(result.tools).toContain('web_search'); + expect(result.tools).toContain('tool1_mcp_server1'); + expect(result.tools).toContain('tool2_mcp_server2'); + } else { + expect(result).toBeNull(); + } + }); + + test('should return null for non-existent agent', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: 'non_existent_agent', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should load agent when user is the author', async () => { + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: userId, + description: 'Test description', + tools: ['web_search'], + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + 
expect(result).toBeDefined(); + expect(result.id).toBe(agentId); + expect(result.name).toBe('Test Agent'); + expect(result.author.toString()).toBe(userId.toString()); + expect(result.version).toBe(1); + }); + + test('should return null when user is not author and agent has no projectIds', async () => { + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeFalsy(); + }); + + test('should handle ephemeral agent with no MCP servers', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Simple instructions', + ephemeralAgent: { + execute_code: false, + web_search: false, + mcp: [], + }, + }, + app: { + locals: { + availableTools: {}, + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-3.5-turbo' }, + }); + + if (result) { + expect(result.tools).toEqual([]); + expect(result.instructions).toBe('Simple instructions'); + } else { + expect(result).toBeFalsy(); + } + }); + + test('should handle ephemeral agent with undefined ephemeralAgent in body', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Basic instructions', + }, + app: { + locals: { + availableTools: {}, + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 
'gpt-4' }, + }); + + if (result) { + expect(result.tools).toEqual([]); + } else { + expect(result).toBeFalsy(); + } + }); + + describe('Edge Cases', () => { + test('should handle loadAgent with malformed req object', async () => { + const result = await loadAgent({ + req: null, + agent_id: 'test', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should handle ephemeral agent with extremely large tool list', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + const largeToolList = Array.from({ length: 100 }, (_, i) => `tool_${i}_mcp_server1`); + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test', + ephemeralAgent: { + execute_code: true, + web_search: true, + mcp: ['server1'], + }, + }, + app: { + locals: { + availableTools: largeToolList.reduce((acc, tool) => { + acc[tool] = {}; + return acc; + }, {}), + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools.length).toBeGreaterThan(100); + } + }); + + test('should handle loadAgent with agent from different project', async () => { + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Project Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + projectIds: [projectId], + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeFalsy(); + }); + }); + }); + + describe('Agent Edge Cases and Error Handling', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = 
await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should handle agent creation with minimal required fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + provider: 'test', + model: 'test-model', + author: authorId, + }); + + expect(agent).toBeDefined(); + expect(agent.id).toBe(agentId); + expect(agent.versions).toHaveLength(1); + expect(agent.versions[0].provider).toBe('test'); + expect(agent.versions[0].model).toBe('test-model'); + }); + + test('should handle agent creation with all optional fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + name: 'Complex Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Complex description', + instructions: 'Complex instructions', + tools: ['tool1', 'tool2'], + actions: ['action1', 'action2'], + model_parameters: { temperature: 0.8, max_tokens: 1000 }, + projectIds: [projectId], + avatar: 'https://example.com/avatar.png', + isCollaborative: true, + tool_resources: { + file_search: { file_ids: ['file1', 'file2'] }, + }, + }); + + expect(agent).toBeDefined(); + expect(agent.name).toBe('Complex Agent'); + expect(agent.description).toBe('Complex description'); + expect(agent.instructions).toBe('Complex instructions'); + expect(agent.tools).toEqual(['tool1', 'tool2']); + expect(agent.actions).toEqual(['action1', 'action2']); + expect(agent.model_parameters.temperature).toBe(0.8); + 
expect(agent.model_parameters.max_tokens).toBe(1000); + expect(agent.projectIds.map((id) => id.toString())).toContain(projectId.toString()); + expect(agent.avatar).toBe('https://example.com/avatar.png'); + expect(agent.isCollaborative).toBe(true); + expect(agent.tool_resources.file_search.file_ids).toEqual(['file1', 'file2']); + }); + + test('should handle updateAgent with empty update object', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + const updatedAgent = await updateAgent({ id: agentId }, {}); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Test Agent'); + expect(updatedAgent.versions).toHaveLength(1); // No new version should be created + }); + + test('should handle concurrent updates to different agents', async () => { + const agent1Id = `agent_${uuidv4()}`; + const agent2Id = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agent1Id, + name: 'Agent 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agent2Id, + name: 'Agent 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Concurrent updates to different agents + const [updated1, updated2] = await Promise.all([ + updateAgent({ id: agent1Id }, { description: 'Updated Agent 1' }), + updateAgent({ id: agent2Id }, { description: 'Updated Agent 2' }), + ]); + + expect(updated1.description).toBe('Updated Agent 1'); + expect(updated2.description).toBe('Updated Agent 2'); + expect(updated1.versions).toHaveLength(2); + expect(updated2.versions).toHaveLength(2); + }); + + test('should handle agent deletion with non-existent ID', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const result = await deleteAgent({ id: nonExistentId }); + + 
expect(result).toBeNull(); + }); + + test('should handle getListAgents with no agents', async () => { + const authorId = new mongoose.Types.ObjectId(); + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toEqual([]); + expect(result.has_more).toBe(false); + expect(result.first_id).toBeNull(); + expect(result.last_id).toBeNull(); + }); + + test('should handle updateAgent with MongoDB operators mixed with direct updates', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + }); + + // Test with $push and direct field update + const updatedAgent = await updateAgent( + { id: agentId }, + { + name: 'Updated Name', + $push: { tools: 'tool2' }, + }, + ); + + expect(updatedAgent.name).toBe('Updated Name'); + expect(updatedAgent.tools).toContain('tool1'); + expect(updatedAgent.tools).toContain('tool2'); + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle revertAgentVersion with invalid version index', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Try to revert to non-existent version + await expect(revertAgentVersion({ id: agentId }, 5)).rejects.toThrow('Version 5 not found'); + }); + + test('should handle revertAgentVersion with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect(revertAgentVersion({ id: nonExistentId }, 0)).rejects.toThrow('Agent not found'); + }); + + test('should handle addAgentResourceFile with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const mockReq = { user: { id: 'user123' } }; + + 
await expect( + addAgentResourceFile({ + req: mockReq, + agent_id: nonExistentId, + tool_resource: 'file_search', + file_id: 'file123', + }), + ).rejects.toThrow('Agent not found for adding resource file'); + }); + + test('should handle removeAgentResourceFiles with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect( + removeAgentResourceFiles({ + agent_id: nonExistentId, + files: [{ tool_resource: 'file_search', file_id: 'file123' }], + }), + ).rejects.toThrow('Agent not found for removing resource files'); + }); + + test('should handle updateAgent with complex nested updates', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: { temperature: 0.5 }, + tools: ['tool1'], + }); + + // First update with $push operation + const firstUpdate = await updateAgent( + { id: agentId }, + { + $push: { tools: 'tool2' }, + }, + ); + + expect(firstUpdate.tools).toContain('tool1'); + expect(firstUpdate.tools).toContain('tool2'); + + // Second update with direct field update and $addToSet + const secondUpdate = await updateAgent( + { id: agentId }, + { + name: 'Updated Agent', + model_parameters: { temperature: 0.8, max_tokens: 500 }, + $addToSet: { tools: 'tool3' }, + }, + ); + + expect(secondUpdate.name).toBe('Updated Agent'); + expect(secondUpdate.model_parameters.temperature).toBe(0.8); + expect(secondUpdate.model_parameters.max_tokens).toBe(500); + expect(secondUpdate.tools).toContain('tool1'); + expect(secondUpdate.tools).toContain('tool2'); + expect(secondUpdate.tools).toContain('tool3'); + expect(secondUpdate.versions).toHaveLength(3); + }); + + test('should preserve version order in versions array', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: 
agentId, + name: 'Version 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await updateAgent({ id: agentId }, { name: 'Version 2' }); + await updateAgent({ id: agentId }, { name: 'Version 3' }); + const finalAgent = await updateAgent({ id: agentId }, { name: 'Version 4' }); + + expect(finalAgent.versions).toHaveLength(4); + expect(finalAgent.versions[0].name).toBe('Version 1'); + expect(finalAgent.versions[1].name).toBe('Version 2'); + expect(finalAgent.versions[2].name).toBe('Version 3'); + expect(finalAgent.versions[3].name).toBe('Version 4'); + expect(finalAgent.name).toBe('Version 4'); + }); + + test('should handle updateAgentProjects error scenarios', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const userId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + // Test with non-existent agent + const result = await updateAgentProjects({ + user: { id: userId.toString() }, + agentId: nonExistentId, + projectIds: [projectId.toString()], + }); + + expect(result).toBeNull(); + }); + + test('should handle revertAgentVersion properly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Original Name', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Original description', + }); + + await updateAgent( + { id: agentId }, + { name: 'Updated Name', description: 'Updated description' }, + ); + + const revertedAgent = await revertAgentVersion({ id: agentId }, 0); + + expect(revertedAgent.name).toBe('Original Name'); + expect(revertedAgent.description).toBe('Original description'); + expect(revertedAgent.author.toString()).toBe(authorId.toString()); + }); + + test('should handle action-related updates with getActions error', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + // Create agent with actions that might 
cause getActions to fail + await createAgent({ + id: agentId, + name: 'Agent with Actions', + provider: 'test', + model: 'test-model', + author: authorId, + actions: ['test.com_action_invalid_id'], + }); + + // Update should still work even if getActions fails + const updatedAgent = await updateAgent( + { id: agentId }, + { description: 'Updated description' }, + ); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.description).toBe('Updated description'); + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle updateAgent with combined MongoDB operators', async () => { const agentId = `agent_${uuidv4()}`; const authorId = new mongoose.Types.ObjectId(); const projectId1 = new mongoose.Types.ObjectId(); const projectId2 = new mongoose.Types.ObjectId(); - const testCases = [ + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + projectIds: [projectId1], + }); + + // Use multiple operators in single update - but avoid conflicting operations on same field + const updatedAgent = await updateAgent( + { id: agentId }, { - name: 'simple field update', - initial: { - name: 'Test Agent', - description: 'Initial description', - }, - update: { name: 'Updated Name' }, - duplicate: { name: 'Updated Name' }, + name: 'Updated Name', + $push: { tools: 'tool2' }, + $addToSet: { projectIds: projectId2 }, }, + ); + + const finalAgent = await updateAgent( + { id: agentId }, { - name: 'object field update', - initial: { - model_parameters: { temperature: 0.7 }, - }, - update: { model_parameters: { temperature: 0.8 } }, - duplicate: { model_parameters: { temperature: 0.8 } }, - }, - { - name: 'array field update', - initial: { - tools: ['tool1', 'tool2'], - }, - update: { tools: ['tool2', 'tool3'] }, - duplicate: { tools: ['tool2', 'tool3'] }, - }, - { - name: 'projectIds update', - initial: { - projectIds: [projectId1], - }, - update: { projectIds: [projectId1, 
projectId2] }, - duplicate: { projectIds: [projectId2, projectId1] }, + $pull: { projectIds: projectId1 }, }, + ); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Updated Name'); + expect(updatedAgent.tools).toContain('tool1'); + expect(updatedAgent.tools).toContain('tool2'); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + expect(finalAgent).toBeDefined(); + expect(finalAgent.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + expect(finalAgent.versions).toHaveLength(3); + }); + + test('should handle updateAgent when agent does not exist', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + const result = await updateAgent({ id: nonExistentId }, { name: 'New Name' }); + + expect(result).toBeNull(); + }); + + test('should handle concurrent updates with database errors', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Mock findOneAndUpdate to simulate database error + const cleanup = mockFindOneAndUpdateError(2); + + // Concurrent updates where one fails + const promises = [ + updateAgent({ id: agentId }, { name: 'Update 1' }), + updateAgent({ id: agentId }, { name: 'Update 2' }), + updateAgent({ id: agentId }, { name: 'Update 3' }), ]; - for (const testCase of testCases) { - const testAgentId = `agent_${uuidv4()}`; + const results = await Promise.allSettled(promises); - await createAgent({ - id: testAgentId, - provider: 'test', - model: 'test-model', - author: authorId, - ...testCase.initial, + cleanup(); + + const succeeded = results.filter((r) => r.status === 'fulfilled').length; + const failed = results.filter((r) => r.status === 'rejected').length; + + expect(succeeded).toBe(2); + expect(failed).toBe(1); + }); + + test('should handle removeAgentResourceFiles 
when agent is deleted during operation', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tool_resources: { + file_search: { + file_ids: ['file1', 'file2', 'file3'], + }, + }, + }); + + // Mock findOneAndUpdate to return null (simulating deletion) + const originalFindOneAndUpdate = Agent.findOneAndUpdate; + Agent.findOneAndUpdate = jest.fn().mockImplementation(() => ({ + lean: jest.fn().mockResolvedValue(null), + })); + + // Try to remove files from deleted agent + await expect( + removeAgentResourceFiles({ + agent_id: agentId, + files: [ + { tool_resource: 'file_search', file_id: 'file1' }, + { tool_resource: 'file_search', file_id: 'file2' }, + ], + }), + ).rejects.toThrow('Failed to update agent during file removal (pull step)'); + + Agent.findOneAndUpdate = originalFindOneAndUpdate; + }); + + test('should handle loadEphemeralAgent with malformed MCP tool names', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test instructions', + ephemeralAgent: { + execute_code: false, + web_search: false, + mcp: ['server1'], + }, + }, + app: { + locals: { + availableTools: { + malformed_tool_name: {}, // No mcp delimiter + tool__server1: {}, // Wrong delimiter + tool_mcp_server1: {}, // Correct format + tool_mcp_server2: {}, // Different server + }, + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools).toEqual(['tool_mcp_server1']); + expect(result.tools).not.toContain('malformed_tool_name'); + expect(result.tools).not.toContain('tool__server1'); + expect(result.tools).not.toContain('tool_mcp_server2'); + } + }); + + 
test('should handle addAgentResourceFile when array initialization fails', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Mock the updateOne operation to fail but let updateAgent succeed + const originalUpdateOne = Agent.updateOne; + let updateOneCalled = false; + Agent.updateOne = jest.fn().mockImplementation((...args) => { + if (!updateOneCalled) { + updateOneCalled = true; + return Promise.reject(new Error('Database error')); + } + return originalUpdateOne.apply(Agent, args); + }); + + try { + const result = await addAgentResourceFile({ + agent_id: agentId, + tool_resource: 'new_tool', + file_id: 'file123', }); - await updateAgent({ id: testAgentId }, testCase.update); - - let error; - try { - await updateAgent({ id: testAgentId }, testCase.duplicate); - } catch (e) { - error = e; - } - - expect(error).toBeDefined(); - expect(error.message).toContain('Duplicate version'); - expect(error.statusCode).toBe(409); - expect(error.details).toBeDefined(); - expect(error.details.duplicateVersion).toBeDefined(); - - const agent = await getAgent({ id: testAgentId }); - expect(agent.versions).toHaveLength(2); + expect(result).toBeDefined(); + expect(result.tools).toContain('new_tool'); + } catch (error) { + expect(error.message).toBe('Database error'); } - } finally { - console.error = originalConsoleError; - } + + Agent.updateOne = originalUpdateOne; + }); }); - test('should track updatedBy when a different user updates an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const updatingUser = new mongoose.Types.ObjectId(); + describe('Agent IDs Field in Version Detection', () => { + let mongoServer; - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: 
originalAuthor, - description: 'Original description', + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - updatingUser.toString(), - ); - - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[1].updatedBy.toString()).toBe(updatingUser.toString()); - expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); - }); - - test('should include updatedBy even when the original author updates the agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - originalAuthor.toString(), - ); - - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[1].updatedBy.toString()).toBe(originalAuthor.toString()); - expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); - }); - - test('should track multiple different users updating the same agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const user1 = new mongoose.Types.ObjectId(); - const user2 = new mongoose.Types.ObjectId(); - const user3 = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + 
beforeEach(async () => { + await Agent.deleteMany({}); }); - // User 1 makes an update - await updateAgent( - { id: agentId }, - { name: 'Updated by User 1', description: 'First update' }, - user1.toString(), - ); + test('should now create new version when agent_ids field changes', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - // Original author makes an update - await updateAgent( - { id: agentId }, - { description: 'Updated by original author' }, - originalAuthor.toString(), - ); + const agent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); - // User 2 makes an update - await updateAgent( - { id: agentId }, - { name: 'Updated by User 2', model: 'new-model' }, - user2.toString(), - ); + expect(agent).toBeDefined(); + expect(agent.versions).toHaveLength(1); - // User 3 makes an update - const finalAgent = await updateAgent( - { id: agentId }, - { description: 'Final update by User 3' }, - user3.toString(), - ); + const updated = await updateAgent( + { id: agentId }, + { agent_ids: ['agent1', 'agent2', 'agent3'] }, + ); - expect(finalAgent.versions).toHaveLength(5); - expect(finalAgent.author.toString()).toBe(originalAuthor.toString()); - - // Check that each version has the correct updatedBy - expect(finalAgent.versions[0].updatedBy).toBeUndefined(); // Initial creation has no updatedBy - expect(finalAgent.versions[1].updatedBy.toString()).toBe(user1.toString()); - expect(finalAgent.versions[2].updatedBy.toString()).toBe(originalAuthor.toString()); - expect(finalAgent.versions[3].updatedBy.toString()).toBe(user2.toString()); - expect(finalAgent.versions[4].updatedBy.toString()).toBe(user3.toString()); - - // Verify the final state - expect(finalAgent.name).toBe('Updated by User 2'); - expect(finalAgent.description).toBe('Final update by User 3'); - expect(finalAgent.model).toBe('new-model'); - 
}); - - test('should preserve original author during agent restoration', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const updatingUser = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + // Since agent_ids is no longer excluded, this should create a new version + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1', 'agent2', 'agent3']); }); - await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - updatingUser.toString(), - ); + test('should detect duplicate version if agent_ids is updated to same value', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - const { revertAgentVersion } = require('./Agent'); - const revertedAgent = await revertAgentVersion({ id: agentId }, 0); + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); - expect(revertedAgent.author.toString()).toBe(originalAuthor.toString()); - expect(revertedAgent.name).toBe('Original Agent'); - expect(revertedAgent.description).toBe('Original description'); + await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] }); + + await expect( + updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] }), + ).rejects.toThrow('Duplicate version'); + }); + + test('should handle agent_ids field alongside other fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Initial description', + agent_ids: ['agent1'], + }); + + 
const updated = await updateAgent( + { id: agentId }, + { + agent_ids: ['agent1', 'agent2'], + description: 'Updated description', + }, + ); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1', 'agent2']); + expect(updated.description).toBe('Updated description'); + + const updated2 = await updateAgent({ id: agentId }, { description: 'Another description' }); + + expect(updated2.versions).toHaveLength(3); + expect(updated2.agent_ids).toEqual(['agent1', 'agent2']); + expect(updated2.description).toBe('Another description'); + }); + + test('should skip version creation when skipVersioning option is used', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + + // Create agent with initial projectIds + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + projectIds: [projectId1], + }); + + // Share agent using updateAgentProjects (which uses skipVersioning) + const shared = await updateAgentProjects({ + user: { id: authorId.toString() }, // Use the same author ID + agentId: agentId, + projectIds: [projectId2.toString()], + }); + + // Should NOT create a new version due to skipVersioning + expect(shared.versions).toHaveLength(1); + expect(shared.projectIds.map((id) => id.toString())).toContain(projectId1.toString()); + expect(shared.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + // Unshare agent using updateAgentProjects + const unshared = await updateAgentProjects({ + user: { id: authorId.toString() }, + agentId: agentId, + removeProjectIds: [projectId1.toString()], + }); + + // Still should NOT create a new version + expect(unshared.versions).toHaveLength(1); + expect(unshared.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + 
expect(unshared.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + // Regular update without skipVersioning should create a version + const regularUpdate = await updateAgent( + { id: agentId }, + { description: 'Updated description' }, + ); + + expect(regularUpdate.versions).toHaveLength(2); + expect(regularUpdate.description).toBe('Updated description'); + + // Direct updateAgent with MongoDB operators should still create versions + const directUpdate = await updateAgent( + { id: agentId }, + { $addToSet: { projectIds: { $each: [projectId1] } } }, + ); + + expect(directUpdate.versions).toHaveLength(3); + expect(directUpdate.projectIds.length).toBe(2); + }); + + test('should preserve agent_ids in version history', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1'], + }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2'] }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent3'] }); + + const finalAgent = await getAgent({ id: agentId }); + + expect(finalAgent.versions).toHaveLength(3); + expect(finalAgent.versions[0].agent_ids).toEqual(['agent1']); + expect(finalAgent.versions[1].agent_ids).toEqual(['agent1', 'agent2']); + expect(finalAgent.versions[2].agent_ids).toEqual(['agent3']); + expect(finalAgent.agent_ids).toEqual(['agent3']); + }); + + test('should handle empty agent_ids arrays', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); + + const updated = await updateAgent({ id: agentId }, { agent_ids: [] }); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual([]); + + 
await expect(updateAgent({ id: agentId }, { agent_ids: [] })).rejects.toThrow( + 'Duplicate version', + ); + }); + + test('should handle agent without agent_ids field', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + expect(agent.agent_ids).toEqual([]); + + const updated = await updateAgent({ id: agentId }, { agent_ids: ['agent1'] }); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1']); + }); }); }); + +function createBasicAgent(overrides = {}) { + const defaults = { + id: `agent_${uuidv4()}`, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }; + return createAgent({ ...defaults, ...overrides }); +} + +function createTestIds() { + return { + agentId: `agent_${uuidv4()}`, + authorId: new mongoose.Types.ObjectId(), + projectId: new mongoose.Types.ObjectId(), + fileId: uuidv4(), + }; +} + +function createFileOperations(agentId, fileIds, operation = 'add') { + return fileIds.map((fileId) => + operation === 'add' + ? 
addAgentResourceFile({ agent_id: agentId, tool_resource: 'test_tool', file_id: fileId }) + : removeAgentResourceFiles({ + agent_id: agentId, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); +} + +function mockFindOneAndUpdateError(errorOnCall = 1) { + const original = Agent.findOneAndUpdate; + let callCount = 0; + + Agent.findOneAndUpdate = jest.fn().mockImplementation((...args) => { + callCount++; + if (callCount === errorOnCall) { + throw new Error('Database connection lost'); + } + return original.apply(Agent, args); + }); + + return () => { + Agent.findOneAndUpdate = original; + }; +} + +function generateVersionTestCases() { + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + + return [ + { + name: 'simple field update', + initial: { + name: 'Test Agent', + description: 'Initial description', + }, + update: { name: 'Updated Name' }, + duplicate: { name: 'Updated Name' }, + }, + { + name: 'object field update', + initial: { + model_parameters: { temperature: 0.7 }, + }, + update: { model_parameters: { temperature: 0.8 } }, + duplicate: { model_parameters: { temperature: 0.8 } }, + }, + { + name: 'array field update', + initial: { + tools: ['tool1', 'tool2'], + }, + update: { tools: ['tool2', 'tool3'] }, + duplicate: { tools: ['tool2', 'tool3'] }, + }, + { + name: 'projectIds update', + initial: { + projectIds: [projectId1], + }, + update: { projectIds: [projectId1, projectId2] }, + duplicate: { projectIds: [projectId2, projectId1] }, + }, + ]; +} diff --git a/api/models/Assistant.js b/api/models/Assistant.js index a8a5b98157..be94d35d7d 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -1,7 +1,4 @@ -const mongoose = require('mongoose'); -const { assistantSchema } = require('@librechat/data-schemas'); - -const Assistant = mongoose.model('assistant', assistantSchema); +const { Assistant } = require('~/db/models'); /** * Update an assistant with new data without 
overwriting existing properties, diff --git a/api/models/Balance.js b/api/models/Balance.js deleted file mode 100644 index 226f6ef508..0000000000 --- a/api/models/Balance.js +++ /dev/null @@ -1,4 +0,0 @@ -const mongoose = require('mongoose'); -const { balanceSchema } = require('@librechat/data-schemas'); - -module.exports = mongoose.model('Balance', balanceSchema); diff --git a/api/models/Banner.js b/api/models/Banner.js index 399a8e72ee..42ad1599ed 100644 --- a/api/models/Banner.js +++ b/api/models/Banner.js @@ -1,8 +1,5 @@ -const mongoose = require('mongoose'); -const logger = require('~/config/winston'); -const { bannerSchema } = require('@librechat/data-schemas'); - -const Banner = mongoose.model('Banner', bannerSchema); +const { logger } = require('@librechat/data-schemas'); +const { Banner } = require('~/db/models'); /** * Retrieves the current active banner. @@ -28,4 +25,4 @@ const getBanner = async (user) => { } }; -module.exports = { Banner, getBanner }; +module.exports = { getBanner }; diff --git a/api/models/Config.js b/api/models/Config.js deleted file mode 100644 index fefb84b8f9..0000000000 --- a/api/models/Config.js +++ /dev/null @@ -1,86 +0,0 @@ -const mongoose = require('mongoose'); -const { logger } = require('~/config'); - -const major = [0, 0]; -const minor = [0, 0]; -const patch = [0, 5]; - -const configSchema = mongoose.Schema( - { - tag: { - type: String, - required: true, - validate: { - validator: function (tag) { - const [part1, part2, part3] = tag.replace('v', '').split('.').map(Number); - - // Check if all parts are numbers - if (isNaN(part1) || isNaN(part2) || isNaN(part3)) { - return false; - } - - // Check if all parts are within their respective ranges - if (part1 < major[0] || part1 > major[1]) { - return false; - } - if (part2 < minor[0] || part2 > minor[1]) { - return false; - } - if (part3 < patch[0] || part3 > patch[1]) { - return false; - } - return true; - }, - message: 'Invalid tag value', - }, - }, - searchEnabled: { - type: 
Boolean, - default: false, - }, - usersEnabled: { - type: Boolean, - default: false, - }, - startupCounts: { - type: Number, - default: 0, - }, - }, - { timestamps: true }, -); - -// Instance method -configSchema.methods.incrementCount = function () { - this.startupCounts += 1; -}; - -// Static methods -configSchema.statics.findByTag = async function (tag) { - return await this.findOne({ tag }).lean(); -}; - -configSchema.statics.updateByTag = async function (tag, update) { - return await this.findOneAndUpdate({ tag }, update, { new: true }); -}; - -const Config = mongoose.models.Config || mongoose.model('Config', configSchema); - -module.exports = { - getConfigs: async (filter) => { - try { - return await Config.find(filter).lean(); - } catch (error) { - logger.error('Error getting configs', error); - return { config: 'Error getting configs' }; - } - }, - deleteConfigs: async (filter) => { - try { - return await Config.deleteMany(filter); - } catch (error) { - logger.error('Error deleting configs', error); - return { config: 'Error deleting configs' }; - } - }, -}; diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 51081a6491..38e2cbb448 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -1,6 +1,6 @@ -const Conversation = require('./schema/convoSchema'); +const { logger } = require('@librechat/data-schemas'); const { getMessages, deleteMessages } = require('./Message'); -const logger = require('~/config/winston'); +const { Conversation } = require('~/db/models'); /** * Searches for a conversation by conversationId and returns a lean document with only conversationId and user. 
@@ -75,7 +75,6 @@ const getConvoFiles = async (conversationId) => { }; module.exports = { - Conversation, getConvoFiles, searchConversation, deleteNullOrEmptyConversations, @@ -155,7 +154,6 @@ module.exports = { { cursor, limit = 25, isArchived = false, tags, search, order = 'desc' } = {}, ) => { const filters = [{ user }]; - if (isArchived) { filters.push({ isArchived: true }); } else { @@ -288,7 +286,6 @@ module.exports = { deleteConvos: async (user, filter) => { try { const userFilter = { ...filter, user }; - const conversations = await Conversation.find(userFilter).select('conversationId'); const conversationIds = conversations.map((c) => c.conversationId); diff --git a/api/models/ConversationTag.js b/api/models/ConversationTag.js index f0cac8620e..e6dc96be64 100644 --- a/api/models/ConversationTag.js +++ b/api/models/ConversationTag.js @@ -1,10 +1,5 @@ -const mongoose = require('mongoose'); -const Conversation = require('./schema/convoSchema'); -const logger = require('~/config/winston'); - -const { conversationTagSchema } = require('@librechat/data-schemas'); - -const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema); +const { logger } = require('@librechat/data-schemas'); +const { ConversationTag, Conversation } = require('~/db/models'); /** * Retrieves all conversation tags for a user. @@ -140,13 +135,13 @@ const adjustPositions = async (user, oldPosition, newPosition) => { const position = oldPosition < newPosition ? 
{ - $gt: Math.min(oldPosition, newPosition), - $lte: Math.max(oldPosition, newPosition), - } + $gt: Math.min(oldPosition, newPosition), + $lte: Math.max(oldPosition, newPosition), + } : { - $gte: Math.min(oldPosition, newPosition), - $lt: Math.max(oldPosition, newPosition), - }; + $gte: Math.min(oldPosition, newPosition), + $lt: Math.max(oldPosition, newPosition), + }; await ConversationTag.updateMany( { diff --git a/api/models/File.js b/api/models/File.js index 4d94994478..ff509539e3 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -1,9 +1,6 @@ -const mongoose = require('mongoose'); +const { logger } = require('@librechat/data-schemas'); const { EToolResources } = require('librechat-data-provider'); -const { fileSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -const File = mongoose.model('File', fileSchema); +const { File } = require('~/db/models'); /** * Finds a file by its file_id with additional query options. @@ -169,7 +166,6 @@ async function batchUpdateFiles(updates) { } module.exports = { - File, findFileById, getFiles, getToolFilesByIds, diff --git a/api/models/Key.js b/api/models/Key.js deleted file mode 100644 index c69c350a42..0000000000 --- a/api/models/Key.js +++ /dev/null @@ -1,4 +0,0 @@ -const mongoose = require('mongoose'); -const { keySchema } = require('@librechat/data-schemas'); - -module.exports = mongoose.model('Key', keySchema); diff --git a/api/models/Message.js b/api/models/Message.js index 86fd2fd549..abd538084e 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -1,6 +1,6 @@ const { z } = require('zod'); -const Message = require('./schema/messageSchema'); -const { logger } = require('~/config'); +const { logger } = require('@librechat/data-schemas'); +const { Message } = require('~/db/models'); const idSchema = z.string().uuid(); @@ -68,7 +68,6 @@ async function saveMessage(req, params, metadata) { logger.info(`---\`saveMessage\` context: ${metadata?.context}`); 
update.tokenCount = 0; } - const message = await Message.findOneAndUpdate( { messageId: params.messageId, user: req.user.id }, update, @@ -140,7 +139,6 @@ async function bulkSaveMessages(messages, overrideTimestamp = false) { upsert: true, }, })); - const result = await Message.bulkWrite(bulkOps); return result; } catch (err) { @@ -255,6 +253,7 @@ async function updateMessage(req, message, metadata) { text: updatedMessage.text, isCreatedByUser: updatedMessage.isCreatedByUser, tokenCount: updatedMessage.tokenCount, + feedback: updatedMessage.feedback, }; } catch (err) { logger.error('Error updating message:', err); @@ -355,7 +354,6 @@ async function deleteMessages(filter) { } module.exports = { - Message, saveMessage, bulkSaveMessages, recordMessage, diff --git a/api/models/Message.spec.js b/api/models/Message.spec.js index a542130b59..aebaebb442 100644 --- a/api/models/Message.spec.js +++ b/api/models/Message.spec.js @@ -1,32 +1,7 @@ const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const { v4: uuidv4 } = require('uuid'); - -jest.mock('mongoose'); - -const mockFindQuery = { - select: jest.fn().mockReturnThis(), - sort: jest.fn().mockReturnThis(), - lean: jest.fn().mockReturnThis(), - deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }), -}; - -const mockSchema = { - findOneAndUpdate: jest.fn(), - updateOne: jest.fn(), - findOne: jest.fn(() => ({ - lean: jest.fn(), - })), - find: jest.fn(() => mockFindQuery), - deleteMany: jest.fn(), -}; - -mongoose.model.mockReturnValue(mockSchema); - -jest.mock('~/models/schema/messageSchema', () => mockSchema); - -jest.mock('~/config/winston', () => ({ - error: jest.fn(), -})); +const { messageSchema } = require('@librechat/data-schemas'); const { saveMessage, @@ -35,77 +10,102 @@ const { deleteMessages, updateMessageText, deleteMessagesSince, -} = require('~/models/Message'); +} = require('./Message'); + +/** + * @type {import('mongoose').Model} + */ +let Message; 
describe('Message Operations', () => { + let mongoServer; let mockReq; - let mockMessage; + let mockMessageData; - beforeEach(() => { - jest.clearAllMocks(); + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Message = mongoose.models.Message || mongoose.model('Message', messageSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + // Clear database + await Message.deleteMany({}); mockReq = { user: { id: 'user123' }, }; - mockMessage = { + mockMessageData = { messageId: 'msg123', conversationId: uuidv4(), text: 'Hello, world!', user: 'user123', }; - - mockSchema.findOneAndUpdate.mockResolvedValue({ - toObject: () => mockMessage, - }); }); describe('saveMessage', () => { it('should save a message for an authenticated user', async () => { - const result = await saveMessage(mockReq, mockMessage); - expect(result).toEqual(mockMessage); - expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( - { messageId: 'msg123', user: 'user123' }, - expect.objectContaining({ user: 'user123' }), - expect.any(Object), - ); + const result = await saveMessage(mockReq, mockMessageData); + + expect(result.messageId).toBe('msg123'); + expect(result.user).toBe('user123'); + expect(result.text).toBe('Hello, world!'); + + // Verify the message was actually saved to the database + const savedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' }); + expect(savedMessage).toBeTruthy(); + expect(savedMessage.text).toBe('Hello, world!'); }); it('should throw an error for unauthenticated user', async () => { mockReq.user = null; - await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated'); + await expect(saveMessage(mockReq, mockMessageData)).rejects.toThrow('User not authenticated'); }); - it('should throw an error for invalid conversation ID', async () => { - 
mockMessage.conversationId = 'invalid-id'; - await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined(); + it('should handle invalid conversation ID gracefully', async () => { + mockMessageData.conversationId = 'invalid-id'; + const result = await saveMessage(mockReq, mockMessageData); + expect(result).toBeUndefined(); }); }); describe('updateMessageText', () => { it('should update message text for the authenticated user', async () => { + // First save a message + await saveMessage(mockReq, mockMessageData); + + // Then update it await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' }); - expect(mockSchema.updateOne).toHaveBeenCalledWith( - { messageId: 'msg123', user: 'user123' }, - { text: 'Updated text' }, - ); + + // Verify the update + const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' }); + expect(updatedMessage.text).toBe('Updated text'); }); }); describe('updateMessage', () => { it('should update a message for the authenticated user', async () => { - mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage); + // First save a message + await saveMessage(mockReq, mockMessageData); + const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' }); - expect(result).toEqual( - expect.objectContaining({ - messageId: 'msg123', - text: 'Hello, world!', - }), - ); + + expect(result.messageId).toBe('msg123'); + expect(result.text).toBe('Updated text'); + + // Verify in database + const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' }); + expect(updatedMessage.text).toBe('Updated text'); }); it('should throw an error if message is not found', async () => { - mockSchema.findOneAndUpdate.mockResolvedValue(null); await expect( updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }), ).rejects.toThrow('Message not found or user not authorized.'); @@ -114,19 +114,45 @@ describe('Message Operations', () => { 
describe('deleteMessagesSince', () => { it('should delete messages only for the authenticated user', async () => { - mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() }); - mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 }); - const result = await deleteMessagesSince(mockReq, { - messageId: 'msg123', - conversationId: 'convo123', + const conversationId = uuidv4(); + + // Create multiple messages in the same conversation + const message1 = await saveMessage(mockReq, { + messageId: 'msg1', + conversationId, + text: 'First message', + user: 'user123', }); - expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' }); - expect(mockSchema.find).not.toHaveBeenCalled(); - expect(result).toBeUndefined(); + + const message2 = await saveMessage(mockReq, { + messageId: 'msg2', + conversationId, + text: 'Second message', + user: 'user123', + }); + + const message3 = await saveMessage(mockReq, { + messageId: 'msg3', + conversationId, + text: 'Third message', + user: 'user123', + }); + + // Delete messages since message2 (this should only delete messages created AFTER msg2) + await deleteMessagesSince(mockReq, { + messageId: 'msg2', + conversationId, + }); + + // Verify msg1 and msg2 remain, msg3 is deleted + const remainingMessages = await Message.find({ conversationId, user: 'user123' }); + expect(remainingMessages).toHaveLength(2); + expect(remainingMessages.map((m) => m.messageId)).toContain('msg1'); + expect(remainingMessages.map((m) => m.messageId)).toContain('msg2'); + expect(remainingMessages.map((m) => m.messageId)).not.toContain('msg3'); }); it('should return undefined if no message is found', async () => { - mockSchema.findOne().lean.mockResolvedValueOnce(null); const result = await deleteMessagesSince(mockReq, { messageId: 'nonexistent', conversationId: 'convo123', @@ -137,29 +163,71 @@ describe('Message Operations', () => { describe('getMessages', () => { it('should retrieve messages with the 
correct filter', async () => { - const filter = { conversationId: 'convo123' }; - await getMessages(filter); - expect(mockSchema.find).toHaveBeenCalledWith(filter); - expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 }); - expect(mockFindQuery.lean).toHaveBeenCalled(); + const conversationId = uuidv4(); + + // Save some messages + await saveMessage(mockReq, { + messageId: 'msg1', + conversationId, + text: 'First message', + user: 'user123', + }); + + await saveMessage(mockReq, { + messageId: 'msg2', + conversationId, + text: 'Second message', + user: 'user123', + }); + + const messages = await getMessages({ conversationId }); + expect(messages).toHaveLength(2); + expect(messages[0].text).toBe('First message'); + expect(messages[1].text).toBe('Second message'); }); }); describe('deleteMessages', () => { it('should delete messages with the correct filter', async () => { + // Save some messages for different users + await saveMessage(mockReq, mockMessageData); + await saveMessage( + { user: { id: 'user456' } }, + { + messageId: 'msg456', + conversationId: uuidv4(), + text: 'Other user message', + user: 'user456', + }, + ); + await deleteMessages({ user: 'user123' }); - expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' }); + + // Verify only user123's messages were deleted + const user123Messages = await Message.find({ user: 'user123' }); + const user456Messages = await Message.find({ user: 'user456' }); + + expect(user123Messages).toHaveLength(0); + expect(user456Messages).toHaveLength(1); }); }); describe('Conversation Hijacking Prevention', () => { - it('should not allow editing a message in another user\'s conversation', async () => { + it("should not allow editing a message in another user's conversation", async () => { const attackerReq = { user: { id: 'attacker123' } }; - const victimConversationId = 'victim-convo-123'; + const victimConversationId = uuidv4(); const victimMessageId = 'victim-msg-123'; - 
mockSchema.findOneAndUpdate.mockResolvedValue(null); + // First, save a message as the victim (but we'll try to edit as attacker) + const victimReq = { user: { id: 'victim123' } }; + await saveMessage(victimReq, { + messageId: victimMessageId, + conversationId: victimConversationId, + text: 'Victim message', + user: 'victim123', + }); + // Attacker tries to edit the victim's message await expect( updateMessage(attackerReq, { messageId: victimMessageId, @@ -168,71 +236,82 @@ describe('Message Operations', () => { }), ).rejects.toThrow('Message not found or user not authorized.'); - expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( - { messageId: victimMessageId, user: 'attacker123' }, - expect.anything(), - expect.anything(), - ); + // Verify the original message is unchanged + const originalMessage = await Message.findOne({ + messageId: victimMessageId, + user: 'victim123', + }); + expect(originalMessage.text).toBe('Victim message'); }); - it('should not allow deleting messages from another user\'s conversation', async () => { + it("should not allow deleting messages from another user's conversation", async () => { const attackerReq = { user: { id: 'attacker123' } }; - const victimConversationId = 'victim-convo-123'; + const victimConversationId = uuidv4(); const victimMessageId = 'victim-msg-123'; - mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user + // Save a message as the victim + const victimReq = { user: { id: 'victim123' } }; + await saveMessage(victimReq, { + messageId: victimMessageId, + conversationId: victimConversationId, + text: 'Victim message', + user: 'victim123', + }); + + // Attacker tries to delete from victim's conversation const result = await deleteMessagesSince(attackerReq, { messageId: victimMessageId, conversationId: victimConversationId, }); expect(result).toBeUndefined(); - expect(mockSchema.findOne).toHaveBeenCalledWith({ + + // Verify the victim's message still exists + const 
victimMessage = await Message.findOne({ messageId: victimMessageId, - user: 'attacker123', + user: 'victim123', }); + expect(victimMessage).toBeTruthy(); + expect(victimMessage.text).toBe('Victim message'); }); - it('should not allow inserting a new message into another user\'s conversation', async () => { + it("should not allow inserting a new message into another user's conversation", async () => { const attackerReq = { user: { id: 'attacker123' } }; - const victimConversationId = uuidv4(); // Use a valid UUID + const victimConversationId = uuidv4(); - await expect( - saveMessage(attackerReq, { - conversationId: victimConversationId, - text: 'Inserted malicious message', - messageId: 'new-msg-123', - }), - ).resolves.not.toThrow(); // It should not throw an error + // Attacker tries to save a message - this should succeed but with attacker's user ID + const result = await saveMessage(attackerReq, { + conversationId: victimConversationId, + text: 'Inserted malicious message', + messageId: 'new-msg-123', + user: 'attacker123', + }); - // Check that the message was saved with the attacker's user ID - expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( - { messageId: 'new-msg-123', user: 'attacker123' }, - expect.objectContaining({ - user: 'attacker123', - conversationId: victimConversationId, - }), - expect.anything(), - ); + expect(result).toBeTruthy(); + expect(result.user).toBe('attacker123'); + + // Verify the message was saved with the attacker's user ID, not as an anonymous message + const savedMessage = await Message.findOne({ messageId: 'new-msg-123' }); + expect(savedMessage.user).toBe('attacker123'); + expect(savedMessage.conversationId).toBe(victimConversationId); }); it('should allow retrieving messages from any conversation', async () => { - const victimConversationId = 'victim-convo-123'; + const victimConversationId = uuidv4(); - await getMessages({ conversationId: victimConversationId }); - - expect(mockSchema.find).toHaveBeenCalledWith({ + // 
Save a message in the victim's conversation + const victimReq = { user: { id: 'victim123' } }; + await saveMessage(victimReq, { + messageId: 'victim-msg', conversationId: victimConversationId, + text: 'Victim message', + user: 'victim123', }); - mockSchema.find.mockReturnValueOnce({ - select: jest.fn().mockReturnThis(), - sort: jest.fn().mockReturnThis(), - lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]), - }); - - const result = await getMessages({ conversationId: victimConversationId }); - expect(result).toEqual([{ text: 'Test message' }]); + // Anyone should be able to retrieve messages by conversation ID + const messages = await getMessages({ conversationId: victimConversationId }); + expect(messages).toHaveLength(1); + expect(messages[0].text).toBe('Victim message'); }); }); }); diff --git a/api/models/Preset.js b/api/models/Preset.js index 970b2958fb..4db3d59066 100644 --- a/api/models/Preset.js +++ b/api/models/Preset.js @@ -1,5 +1,5 @@ -const Preset = require('./schema/presetSchema'); -const { logger } = require('~/config'); +const { logger } = require('@librechat/data-schemas'); +const { Preset } = require('~/db/models'); const getPreset = async (user, presetId) => { try { @@ -11,7 +11,6 @@ const getPreset = async (user, presetId) => { }; module.exports = { - Preset, getPreset, getPresets: async (user, filter) => { try { diff --git a/api/models/Project.js b/api/models/Project.js index 43d7263723..8fd1e556f9 100644 --- a/api/models/Project.js +++ b/api/models/Project.js @@ -1,8 +1,5 @@ -const { model } = require('mongoose'); const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; -const { projectSchema } = require('@librechat/data-schemas'); - -const Project = model('Project', projectSchema); +const { Project } = require('~/db/models'); /** * Retrieve a project by ID and convert the found project document to a plain object. 
diff --git a/api/models/Prompt.js b/api/models/Prompt.js index 43dc3ec22b..9499e19c8e 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,5 +1,5 @@ -const mongoose = require('mongoose'); const { ObjectId } = require('mongodb'); +const { logger } = require('@librechat/data-schemas'); const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider'); const { getProjectByName, @@ -7,12 +7,8 @@ const { removeGroupIdsFromProject, removeGroupFromAllProjects, } = require('./Project'); -const { promptGroupSchema, promptSchema } = require('@librechat/data-schemas'); +const { PromptGroup, Prompt } = require('~/db/models'); const { escapeRegExp } = require('~/server/utils'); -const { logger } = require('~/config'); - -const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema); -const Prompt = mongoose.model('Prompt', promptSchema); /** * Create a pipeline for the aggregation to get prompt groups diff --git a/api/models/Role.js b/api/models/Role.js index 07bf5a2ccb..d7f1c0f9cf 100644 --- a/api/models/Role.js +++ b/api/models/Role.js @@ -1,4 +1,3 @@ -const mongoose = require('mongoose'); const { CacheKeys, SystemRoles, @@ -7,11 +6,9 @@ const { permissionsSchema, removeNullishValues, } = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); const getLogStores = require('~/cache/getLogStores'); -const { roleSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -const Role = mongoose.model('Role', roleSchema); +const { Role } = require('~/db/models'); /** * Retrieve a role by name and convert the found role document to a plain object. @@ -173,35 +170,6 @@ async function updateAccessPermissions(roleName, permissionsUpdate) { } } -/** - * Initialize default roles in the system. - * Creates the default roles (ADMIN, USER) if they don't exist in the database. - * Updates existing roles with new permission types if they're missing. 
- * - * @returns {Promise} - */ -const initializeRoles = async function () { - for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) { - let role = await Role.findOne({ name: roleName }); - const defaultPerms = roleDefaults[roleName].permissions; - - if (!role) { - // Create new role if it doesn't exist. - role = new Role(roleDefaults[roleName]); - } else { - // Ensure role.permissions is defined. - role.permissions = role.permissions || {}; - // For each permission type in defaults, add it if missing. - for (const permType of Object.keys(defaultPerms)) { - if (role.permissions[permType] == null) { - role.permissions[permType] = defaultPerms[permType]; - } - } - } - await role.save(); - } -}; - /** * Migrates roles from old schema to new schema structure. * This can be called directly to fix existing roles. @@ -282,10 +250,8 @@ const migrateRoleSchema = async function (roleName) { }; module.exports = { - Role, getRoleByName, - initializeRoles, updateRoleByName, - updateAccessPermissions, migrateRoleSchema, + updateAccessPermissions, }; diff --git a/api/models/Role.spec.js b/api/models/Role.spec.js index a8b60801ca..c344f719dd 100644 --- a/api/models/Role.spec.js +++ b/api/models/Role.spec.js @@ -6,8 +6,10 @@ const { roleDefaults, PermissionTypes, } = require('librechat-data-provider'); -const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role'); +const { getRoleByName, updateAccessPermissions } = require('~/models/Role'); const getLogStores = require('~/cache/getLogStores'); +const { initializeRoles } = require('~/models'); +const { Role } = require('~/db/models'); // Mock the cache jest.mock('~/cache/getLogStores', () => diff --git a/api/models/Session.js b/api/models/Session.js deleted file mode 100644 index 38821b77dd..0000000000 --- a/api/models/Session.js +++ /dev/null @@ -1,275 +0,0 @@ -const mongoose = require('mongoose'); -const signPayload = require('~/server/services/signPayload'); -const { hashToken } = 
require('~/server/utils/crypto'); -const { sessionSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -const Session = mongoose.model('Session', sessionSchema); - -const { REFRESH_TOKEN_EXPIRY } = process.env ?? {}; -const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default - -/** - * Error class for Session-related errors - */ -class SessionError extends Error { - constructor(message, code = 'SESSION_ERROR') { - super(message); - this.name = 'SessionError'; - this.code = code; - } -} - -/** - * Creates a new session for a user - * @param {string} userId - The ID of the user - * @param {Object} options - Additional options for session creation - * @param {Date} options.expiration - Custom expiration date - * @returns {Promise<{session: Session, refreshToken: string}>} - * @throws {SessionError} - */ -const createSession = async (userId, options = {}) => { - if (!userId) { - throw new SessionError('User ID is required', 'INVALID_USER_ID'); - } - - try { - const session = new Session({ - user: userId, - expiration: options.expiration || new Date(Date.now() + expires), - }); - const refreshToken = await generateRefreshToken(session); - return { session, refreshToken }; - } catch (error) { - logger.error('[createSession] Error creating session:', error); - throw new SessionError('Failed to create session', 'CREATE_SESSION_FAILED'); - } -}; - -/** - * Finds a session by various parameters - * @param {Object} params - Search parameters - * @param {string} [params.refreshToken] - The refresh token to search by - * @param {string} [params.userId] - The user ID to search by - * @param {string} [params.sessionId] - The session ID to search by - * @param {Object} [options] - Additional options - * @param {boolean} [options.lean=true] - Whether to return plain objects instead of documents - * @returns {Promise} - * @throws {SessionError} - */ -const findSession = async (params, options = { lean: true }) => { 
- try { - const query = {}; - - if (!params.refreshToken && !params.userId && !params.sessionId) { - throw new SessionError('At least one search parameter is required', 'INVALID_SEARCH_PARAMS'); - } - - if (params.refreshToken) { - const tokenHash = await hashToken(params.refreshToken); - query.refreshTokenHash = tokenHash; - } - - if (params.userId) { - query.user = params.userId; - } - - if (params.sessionId) { - const sessionId = params.sessionId.sessionId || params.sessionId; - if (!mongoose.Types.ObjectId.isValid(sessionId)) { - throw new SessionError('Invalid session ID format', 'INVALID_SESSION_ID'); - } - query._id = sessionId; - } - - // Add expiration check to only return valid sessions - query.expiration = { $gt: new Date() }; - - const sessionQuery = Session.findOne(query); - - if (options.lean) { - return await sessionQuery.lean(); - } - - return await sessionQuery.exec(); - } catch (error) { - logger.error('[findSession] Error finding session:', error); - throw new SessionError('Failed to find session', 'FIND_SESSION_FAILED'); - } -}; - -/** - * Updates session expiration - * @param {Session|string} session - The session or session ID to update - * @param {Date} [newExpiration] - Optional new expiration date - * @returns {Promise} - * @throws {SessionError} - */ -const updateExpiration = async (session, newExpiration) => { - try { - const sessionDoc = typeof session === 'string' ? 
await Session.findById(session) : session; - - if (!sessionDoc) { - throw new SessionError('Session not found', 'SESSION_NOT_FOUND'); - } - - sessionDoc.expiration = newExpiration || new Date(Date.now() + expires); - return await sessionDoc.save(); - } catch (error) { - logger.error('[updateExpiration] Error updating session:', error); - throw new SessionError('Failed to update session expiration', 'UPDATE_EXPIRATION_FAILED'); - } -}; - -/** - * Deletes a session by refresh token or session ID - * @param {Object} params - Delete parameters - * @param {string} [params.refreshToken] - The refresh token of the session to delete - * @param {string} [params.sessionId] - The ID of the session to delete - * @returns {Promise} - * @throws {SessionError} - */ -const deleteSession = async (params) => { - try { - if (!params.refreshToken && !params.sessionId) { - throw new SessionError( - 'Either refreshToken or sessionId is required', - 'INVALID_DELETE_PARAMS', - ); - } - - const query = {}; - - if (params.refreshToken) { - query.refreshTokenHash = await hashToken(params.refreshToken); - } - - if (params.sessionId) { - query._id = params.sessionId; - } - - const result = await Session.deleteOne(query); - - if (result.deletedCount === 0) { - logger.warn('[deleteSession] No session found to delete'); - } - - return result; - } catch (error) { - logger.error('[deleteSession] Error deleting session:', error); - throw new SessionError('Failed to delete session', 'DELETE_SESSION_FAILED'); - } -}; - -/** - * Deletes all sessions for a user - * @param {string} userId - The ID of the user - * @param {Object} [options] - Additional options - * @param {boolean} [options.excludeCurrentSession] - Whether to exclude the current session - * @param {string} [options.currentSessionId] - The ID of the current session to exclude - * @returns {Promise} - * @throws {SessionError} - */ -const deleteAllUserSessions = async (userId, options = {}) => { - try { - if (!userId) { - throw new 
SessionError('User ID is required', 'INVALID_USER_ID'); - } - - // Extract userId if it's passed as an object - const userIdString = userId.userId || userId; - - if (!mongoose.Types.ObjectId.isValid(userIdString)) { - throw new SessionError('Invalid user ID format', 'INVALID_USER_ID_FORMAT'); - } - - const query = { user: userIdString }; - - if (options.excludeCurrentSession && options.currentSessionId) { - query._id = { $ne: options.currentSessionId }; - } - - const result = await Session.deleteMany(query); - - if (result.deletedCount > 0) { - logger.debug( - `[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userIdString}.`, - ); - } - - return result; - } catch (error) { - logger.error('[deleteAllUserSessions] Error deleting user sessions:', error); - throw new SessionError('Failed to delete user sessions', 'DELETE_ALL_SESSIONS_FAILED'); - } -}; - -/** - * Generates a refresh token for a session - * @param {Session} session - The session to generate a token for - * @returns {Promise} - * @throws {SessionError} - */ -const generateRefreshToken = async (session) => { - if (!session || !session.user) { - throw new SessionError('Invalid session object', 'INVALID_SESSION'); - } - - try { - const expiresIn = session.expiration ? 
session.expiration.getTime() : Date.now() + expires; - - if (!session.expiration) { - session.expiration = new Date(expiresIn); - } - - const refreshToken = await signPayload({ - payload: { - id: session.user, - sessionId: session._id, - }, - secret: process.env.JWT_REFRESH_SECRET, - expirationTime: Math.floor((expiresIn - Date.now()) / 1000), - }); - - session.refreshTokenHash = await hashToken(refreshToken); - await session.save(); - - return refreshToken; - } catch (error) { - logger.error('[generateRefreshToken] Error generating refresh token:', error); - throw new SessionError('Failed to generate refresh token', 'GENERATE_TOKEN_FAILED'); - } -}; - -/** - * Counts active sessions for a user - * @param {string} userId - The ID of the user - * @returns {Promise} - * @throws {SessionError} - */ -const countActiveSessions = async (userId) => { - try { - if (!userId) { - throw new SessionError('User ID is required', 'INVALID_USER_ID'); - } - - return await Session.countDocuments({ - user: userId, - expiration: { $gt: new Date() }, - }); - } catch (error) { - logger.error('[countActiveSessions] Error counting active sessions:', error); - throw new SessionError('Failed to count active sessions', 'COUNT_SESSIONS_FAILED'); - } -}; - -module.exports = { - createSession, - findSession, - updateExpiration, - deleteSession, - deleteAllUserSessions, - generateRefreshToken, - countActiveSessions, - SessionError, -}; diff --git a/api/models/Share.js b/api/models/Share.js deleted file mode 100644 index 8611d01bc0..0000000000 --- a/api/models/Share.js +++ /dev/null @@ -1,351 +0,0 @@ -const mongoose = require('mongoose'); -const { nanoid } = require('nanoid'); -const { Constants } = require('librechat-data-provider'); -const { Conversation } = require('~/models/Conversation'); -const { shareSchema } = require('@librechat/data-schemas'); -const SharedLink = mongoose.model('SharedLink', shareSchema); -const { getMessages } = require('./Message'); -const logger = 
require('~/config/winston'); - -class ShareServiceError extends Error { - constructor(message, code) { - super(message); - this.name = 'ShareServiceError'; - this.code = code; - } -} - -const memoizedAnonymizeId = (prefix) => { - const memo = new Map(); - return (id) => { - if (!memo.has(id)) { - memo.set(id, `${prefix}_${nanoid()}`); - } - return memo.get(id); - }; -}; - -const anonymizeConvoId = memoizedAnonymizeId('convo'); -const anonymizeAssistantId = memoizedAnonymizeId('a'); -const anonymizeMessageId = (id) => - id === Constants.NO_PARENT ? id : memoizedAnonymizeId('msg')(id); - -function anonymizeConvo(conversation) { - if (!conversation) { - return null; - } - - const newConvo = { ...conversation }; - if (newConvo.assistant_id) { - newConvo.assistant_id = anonymizeAssistantId(newConvo.assistant_id); - } - return newConvo; -} - -function anonymizeMessages(messages, newConvoId) { - if (!Array.isArray(messages)) { - return []; - } - - const idMap = new Map(); - return messages.map((message) => { - const newMessageId = anonymizeMessageId(message.messageId); - idMap.set(message.messageId, newMessageId); - - const anonymizedAttachments = message.attachments?.map((attachment) => { - return { - ...attachment, - messageId: newMessageId, - conversationId: newConvoId, - }; - }); - - return { - ...message, - messageId: newMessageId, - parentMessageId: - idMap.get(message.parentMessageId) || anonymizeMessageId(message.parentMessageId), - conversationId: newConvoId, - model: message.model?.startsWith('asst_') - ? 
anonymizeAssistantId(message.model) - : message.model, - attachments: anonymizedAttachments, - }; - }); -} - -async function getSharedMessages(shareId) { - try { - const share = await SharedLink.findOne({ shareId, isPublic: true }) - .populate({ - path: 'messages', - select: '-_id -__v -user', - }) - .select('-_id -__v -user') - .lean(); - - if (!share?.conversationId || !share.isPublic) { - return null; - } - - const newConvoId = anonymizeConvoId(share.conversationId); - const result = { - ...share, - conversationId: newConvoId, - messages: anonymizeMessages(share.messages, newConvoId), - }; - - return result; - } catch (error) { - logger.error('[getShare] Error getting share link', { - error: error.message, - shareId, - }); - throw new ShareServiceError('Error getting share link', 'SHARE_FETCH_ERROR'); - } -} - -async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortDirection, search) { - try { - const query = { user, isPublic }; - - if (pageParam) { - if (sortDirection === 'desc') { - query[sortBy] = { $lt: pageParam }; - } else { - query[sortBy] = { $gt: pageParam }; - } - } - - if (search && search.trim()) { - try { - const searchResults = await Conversation.meiliSearch(search); - - if (!searchResults?.hits?.length) { - return { - links: [], - nextCursor: undefined, - hasNextPage: false, - }; - } - - const conversationIds = searchResults.hits.map((hit) => hit.conversationId); - query['conversationId'] = { $in: conversationIds }; - } catch (searchError) { - logger.error('[getSharedLinks] Meilisearch error', { - error: searchError.message, - user, - }); - return { - links: [], - nextCursor: undefined, - hasNextPage: false, - }; - } - } - - const sort = {}; - sort[sortBy] = sortDirection === 'desc' ? 
-1 : 1; - - if (Array.isArray(query.conversationId)) { - query.conversationId = { $in: query.conversationId }; - } - - const sharedLinks = await SharedLink.find(query) - .sort(sort) - .limit(pageSize + 1) - .select('-__v -user') - .lean(); - - const hasNextPage = sharedLinks.length > pageSize; - const links = sharedLinks.slice(0, pageSize); - - const nextCursor = hasNextPage ? links[links.length - 1][sortBy] : undefined; - - return { - links: links.map((link) => ({ - shareId: link.shareId, - title: link?.title || 'Untitled', - isPublic: link.isPublic, - createdAt: link.createdAt, - conversationId: link.conversationId, - })), - nextCursor, - hasNextPage, - }; - } catch (error) { - logger.error('[getSharedLinks] Error getting shares', { - error: error.message, - user, - }); - throw new ShareServiceError('Error getting shares', 'SHARES_FETCH_ERROR'); - } -} - -async function deleteAllSharedLinks(user) { - try { - const result = await SharedLink.deleteMany({ user }); - return { - message: 'All shared links deleted successfully', - deletedCount: result.deletedCount, - }; - } catch (error) { - logger.error('[deleteAllSharedLinks] Error deleting shared links', { - error: error.message, - user, - }); - throw new ShareServiceError('Error deleting shared links', 'BULK_DELETE_ERROR'); - } -} - -async function createSharedLink(user, conversationId) { - if (!user || !conversationId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const [existingShare, conversationMessages] = await Promise.all([ - SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(), - getMessages({ conversationId }), - ]); - - if (existingShare && existingShare.isPublic) { - throw new ShareServiceError('Share already exists', 'SHARE_EXISTS'); - } else if (existingShare) { - await SharedLink.deleteOne({ conversationId }); - } - - const conversation = await Conversation.findOne({ conversationId }).lean(); - const title = 
conversation?.title || 'Untitled'; - - const shareId = nanoid(); - await SharedLink.create({ - shareId, - conversationId, - messages: conversationMessages, - title, - user, - }); - - return { shareId, conversationId }; - } catch (error) { - logger.error('[createSharedLink] Error creating shared link', { - error: error.message, - user, - conversationId, - }); - throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR'); - } -} - -async function getSharedLink(user, conversationId) { - if (!user || !conversationId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const share = await SharedLink.findOne({ conversationId, user, isPublic: true }) - .select('shareId -_id') - .lean(); - - if (!share) { - return { shareId: null, success: false }; - } - - return { shareId: share.shareId, success: true }; - } catch (error) { - logger.error('[getSharedLink] Error getting shared link', { - error: error.message, - user, - conversationId, - }); - throw new ShareServiceError('Error getting shared link', 'SHARE_FETCH_ERROR'); - } -} - -async function updateSharedLink(user, shareId) { - if (!user || !shareId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean(); - - if (!share) { - throw new ShareServiceError('Share not found', 'SHARE_NOT_FOUND'); - } - - const [updatedMessages] = await Promise.all([ - getMessages({ conversationId: share.conversationId }), - ]); - - const newShareId = nanoid(); - const update = { - messages: updatedMessages, - user, - shareId: newShareId, - }; - - const updatedShare = await SharedLink.findOneAndUpdate({ shareId, user }, update, { - new: true, - upsert: false, - runValidators: true, - }).lean(); - - if (!updatedShare) { - throw new ShareServiceError('Share update failed', 'SHARE_UPDATE_ERROR'); - } - - anonymizeConvo(updatedShare); - - return { shareId: 
newShareId, conversationId: updatedShare.conversationId }; - } catch (error) { - logger.error('[updateSharedLink] Error updating shared link', { - error: error.message, - user, - shareId, - }); - throw new ShareServiceError( - error.code === 'SHARE_UPDATE_ERROR' ? error.message : 'Error updating shared link', - error.code || 'SHARE_UPDATE_ERROR', - ); - } -} - -async function deleteSharedLink(user, shareId) { - if (!user || !shareId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const result = await SharedLink.findOneAndDelete({ shareId, user }).lean(); - - if (!result) { - return null; - } - - return { - success: true, - shareId, - message: 'Share deleted successfully', - }; - } catch (error) { - logger.error('[deleteSharedLink] Error deleting shared link', { - error: error.message, - user, - shareId, - }); - throw new ShareServiceError('Error deleting shared link', 'SHARE_DELETE_ERROR'); - } -} - -module.exports = { - SharedLink, - getSharedLink, - getSharedLinks, - createSharedLink, - updateSharedLink, - deleteSharedLink, - getSharedMessages, - deleteAllSharedLinks, -}; diff --git a/api/models/Token.js b/api/models/Token.js index c89abb8c84..6f130eb2c4 100644 --- a/api/models/Token.js +++ b/api/models/Token.js @@ -1,158 +1,5 @@ -const mongoose = require('mongoose'); +const { findToken, updateToken, createToken } = require('~/models'); const { encryptV2 } = require('~/server/utils/crypto'); -const { tokenSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -/** - * Token model. - * @type {mongoose.Model} - */ -const Token = mongoose.model('Token', tokenSchema); -/** - * Fixes the indexes for the Token collection from legacy TTL indexes to the new expiresAt index. 
- */ -async function fixIndexes() { - try { - if ( - process.env.NODE_ENV === 'CI' || - process.env.NODE_ENV === 'development' || - process.env.NODE_ENV === 'test' - ) { - return; - } - const indexes = await Token.collection.indexes(); - logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2)); - const unwantedTTLIndexes = indexes.filter( - (index) => index.key.createdAt === 1 && index.expireAfterSeconds !== undefined, - ); - if (unwantedTTLIndexes.length === 0) { - logger.debug('No unwanted Token indexes found.'); - return; - } - for (const index of unwantedTTLIndexes) { - logger.debug(`Dropping unwanted Token index: ${index.name}`); - await Token.collection.dropIndex(index.name); - logger.debug(`Dropped Token index: ${index.name}`); - } - logger.debug('Token index cleanup completed successfully.'); - } catch (error) { - logger.error('An error occurred while fixing Token indexes:', error); - } -} - -fixIndexes(); - -/** - * Creates a new Token instance. - * @param {Object} tokenData - The data for the new Token. - * @param {mongoose.Types.ObjectId} tokenData.userId - The user's ID. It is required. - * @param {String} tokenData.email - The user's email. - * @param {String} tokenData.token - The token. It is required. - * @param {Number} tokenData.expiresIn - The number of seconds until the token expires. - * @returns {Promise} The new Token instance. - * @throws Will throw an error if token creation fails. - */ -async function createToken(tokenData) { - try { - const currentTime = new Date(); - const expiresAt = new Date(currentTime.getTime() + tokenData.expiresIn * 1000); - - const newTokenData = { - ...tokenData, - createdAt: currentTime, - expiresAt, - }; - - return await Token.create(newTokenData); - } catch (error) { - logger.debug('An error occurred while creating token:', error); - throw error; - } -} - -/** - * Finds a Token document that matches the provided query. - * @param {Object} query - The query to match against. 
- * @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user. - * @param {String} query.token - The token value. - * @param {String} [query.email] - The email of the user. - * @param {String} [query.identifier] - Unique, alternative identifier for the token. - * @returns {Promise} The matched Token document, or null if not found. - * @throws Will throw an error if the find operation fails. - */ -async function findToken(query) { - try { - const conditions = []; - - if (query.userId) { - conditions.push({ userId: query.userId }); - } - if (query.token) { - conditions.push({ token: query.token }); - } - if (query.email) { - conditions.push({ email: query.email }); - } - if (query.identifier) { - conditions.push({ identifier: query.identifier }); - } - - const token = await Token.findOne({ - $and: conditions, - }).lean(); - - return token; - } catch (error) { - logger.debug('An error occurred while finding token:', error); - throw error; - } -} - -/** - * Updates a Token document that matches the provided query. - * @param {Object} query - The query to match against. - * @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user. - * @param {String} query.token - The token value. - * @param {String} [query.email] - The email of the user. - * @param {String} [query.identifier] - Unique, alternative identifier for the token. - * @param {Object} updateData - The data to update the Token with. - * @returns {Promise} The updated Token document, or null if not found. - * @throws Will throw an error if the update operation fails. - */ -async function updateToken(query, updateData) { - try { - return await Token.findOneAndUpdate(query, updateData, { new: true }); - } catch (error) { - logger.debug('An error occurred while updating token:', error); - throw error; - } -} - -/** - * Deletes all Token documents that match the provided token, user ID, or email. - * @param {Object} query - The query to match against. 
- * @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user. - * @param {String} query.token - The token value. - * @param {String} [query.email] - The email of the user. - * @param {String} [query.identifier] - Unique, alternative identifier for the token. - * @returns {Promise} The result of the delete operation. - * @throws Will throw an error if the delete operation fails. - */ -async function deleteTokens(query) { - try { - return await Token.deleteMany({ - $or: [ - { userId: query.userId }, - { token: query.token }, - { email: query.email }, - { identifier: query.identifier }, - ], - }); - } catch (error) { - logger.debug('An error occurred while deleting tokens:', error); - throw error; - } -} /** * Handles the OAuth token by creating or updating the token. @@ -191,9 +38,5 @@ async function handleOAuthToken({ } module.exports = { - findToken, - createToken, - updateToken, - deleteTokens, handleOAuthToken, }; diff --git a/api/models/ToolCall.js b/api/models/ToolCall.js index 7bc0f157dc..689386114b 100644 --- a/api/models/ToolCall.js +++ b/api/models/ToolCall.js @@ -1,6 +1,4 @@ -const mongoose = require('mongoose'); -const { toolCallSchema } = require('@librechat/data-schemas'); -const ToolCall = mongoose.model('ToolCall', toolCallSchema); +const { ToolCall } = require('~/db/models'); /** * Create a new tool call diff --git a/api/models/Transaction.js b/api/models/Transaction.js index e171241b61..0e0e327857 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,9 +1,7 @@ -const mongoose = require('mongoose'); -const { transactionSchema } = require('@librechat/data-schemas'); +const { logger } = require('@librechat/data-schemas'); const { getBalanceConfig } = require('~/server/services/Config'); const { getMultiplier, getCacheMultiplier } = require('./tx'); -const { logger } = require('~/config'); -const Balance = require('./Balance'); +const { Transaction, Balance } = require('~/db/models'); const cancelRate = 1.15; @@ 
-140,19 +138,19 @@ const updateBalance = async ({ user, incrementValue, setValues }) => { }; /** Method to calculate and set the tokenValue for a transaction */ -transactionSchema.methods.calculateTokenValue = function () { - if (!this.valueKey || !this.tokenType) { - this.tokenValue = this.rawAmount; +function calculateTokenValue(txn) { + if (!txn.valueKey || !txn.tokenType) { + txn.tokenValue = txn.rawAmount; } - const { valueKey, tokenType, model, endpointTokenConfig } = this; + const { valueKey, tokenType, model, endpointTokenConfig } = txn; const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig })); - this.rate = multiplier; - this.tokenValue = this.rawAmount * multiplier; - if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') { - this.tokenValue = Math.ceil(this.tokenValue * cancelRate); - this.rate *= cancelRate; + txn.rate = multiplier; + txn.tokenValue = txn.rawAmount * multiplier; + if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { + txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); + txn.rate *= cancelRate; } -}; +} /** * New static method to create an auto-refill transaction that does NOT trigger a balance update. @@ -163,13 +161,13 @@ transactionSchema.methods.calculateTokenValue = function () { * @param {number} txData.rawAmount - The raw amount of tokens. * @returns {Promise} - The created transaction. 
*/ -transactionSchema.statics.createAutoRefillTransaction = async function (txData) { +async function createAutoRefillTransaction(txData) { if (txData.rawAmount != null && isNaN(txData.rawAmount)) { return; } - const transaction = new this(txData); + const transaction = new Transaction(txData); transaction.endpointTokenConfig = txData.endpointTokenConfig; - transaction.calculateTokenValue(); + calculateTokenValue(transaction); await transaction.save(); const balanceResponse = await updateBalance({ @@ -185,21 +183,20 @@ transactionSchema.statics.createAutoRefillTransaction = async function (txData) logger.debug('[Balance.check] Auto-refill performed', result); result.transaction = transaction; return result; -}; +} /** * Static method to create a transaction and update the balance * @param {txData} txData - Transaction data. */ -transactionSchema.statics.create = async function (txData) { - const Transaction = this; +async function createTransaction(txData) { if (txData.rawAmount != null && isNaN(txData.rawAmount)) { return; } const transaction = new Transaction(txData); transaction.endpointTokenConfig = txData.endpointTokenConfig; - transaction.calculateTokenValue(); + calculateTokenValue(transaction); await transaction.save(); @@ -209,7 +206,6 @@ transactionSchema.statics.create = async function (txData) { } let incrementValue = transaction.tokenValue; - const balanceResponse = await updateBalance({ user: transaction.user, incrementValue, @@ -221,21 +217,19 @@ transactionSchema.statics.create = async function (txData) { balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; -}; +} /** * Static method to create a structured transaction and update the balance * @param {txData} txData - Transaction data. 
*/ -transactionSchema.statics.createStructured = async function (txData) { - const Transaction = this; - +async function createStructuredTransaction(txData) { const transaction = new Transaction({ ...txData, endpointTokenConfig: txData.endpointTokenConfig, }); - transaction.calculateStructuredTokenValue(); + calculateStructuredTokenValue(transaction); await transaction.save(); @@ -257,71 +251,69 @@ transactionSchema.statics.createStructured = async function (txData) { balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; -}; +} /** Method to calculate token value for structured tokens */ -transactionSchema.methods.calculateStructuredTokenValue = function () { - if (!this.tokenType) { - this.tokenValue = this.rawAmount; +function calculateStructuredTokenValue(txn) { + if (!txn.tokenType) { + txn.tokenValue = txn.rawAmount; return; } - const { model, endpointTokenConfig } = this; + const { model, endpointTokenConfig } = txn; - if (this.tokenType === 'prompt') { + if (txn.tokenType === 'prompt') { const inputMultiplier = getMultiplier({ tokenType: 'prompt', model, endpointTokenConfig }); const writeMultiplier = getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? inputMultiplier; const readMultiplier = getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? 
inputMultiplier; - this.rateDetail = { + txn.rateDetail = { input: inputMultiplier, write: writeMultiplier, read: readMultiplier, }; const totalPromptTokens = - Math.abs(this.inputTokens || 0) + - Math.abs(this.writeTokens || 0) + - Math.abs(this.readTokens || 0); + Math.abs(txn.inputTokens || 0) + + Math.abs(txn.writeTokens || 0) + + Math.abs(txn.readTokens || 0); if (totalPromptTokens > 0) { - this.rate = - (Math.abs(inputMultiplier * (this.inputTokens || 0)) + - Math.abs(writeMultiplier * (this.writeTokens || 0)) + - Math.abs(readMultiplier * (this.readTokens || 0))) / + txn.rate = + (Math.abs(inputMultiplier * (txn.inputTokens || 0)) + + Math.abs(writeMultiplier * (txn.writeTokens || 0)) + + Math.abs(readMultiplier * (txn.readTokens || 0))) / totalPromptTokens; } else { - this.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens + txn.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens } - this.tokenValue = -( - Math.abs(this.inputTokens || 0) * inputMultiplier + - Math.abs(this.writeTokens || 0) * writeMultiplier + - Math.abs(this.readTokens || 0) * readMultiplier + txn.tokenValue = -( + Math.abs(txn.inputTokens || 0) * inputMultiplier + + Math.abs(txn.writeTokens || 0) * writeMultiplier + + Math.abs(txn.readTokens || 0) * readMultiplier ); - this.rawAmount = -totalPromptTokens; - } else if (this.tokenType === 'completion') { - const multiplier = getMultiplier({ tokenType: this.tokenType, model, endpointTokenConfig }); - this.rate = Math.abs(multiplier); - this.tokenValue = -Math.abs(this.rawAmount) * multiplier; - this.rawAmount = -Math.abs(this.rawAmount); + txn.rawAmount = -totalPromptTokens; + } else if (txn.tokenType === 'completion') { + const multiplier = getMultiplier({ tokenType: txn.tokenType, model, endpointTokenConfig }); + txn.rate = Math.abs(multiplier); + txn.tokenValue = -Math.abs(txn.rawAmount) * multiplier; + txn.rawAmount = -Math.abs(txn.rawAmount); } - if (this.context && this.tokenType === 
'completion' && this.context === 'incomplete') { - this.tokenValue = Math.ceil(this.tokenValue * cancelRate); - this.rate *= cancelRate; - if (this.rateDetail) { - this.rateDetail = Object.fromEntries( - Object.entries(this.rateDetail).map(([k, v]) => [k, v * cancelRate]), + if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { + txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); + txn.rate *= cancelRate; + if (txn.rateDetail) { + txn.rateDetail = Object.fromEntries( + Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]), ); } } -}; - -const Transaction = mongoose.model('Transaction', transactionSchema); +} /** * Queries and retrieves transactions based on a given filter. @@ -340,4 +332,9 @@ async function getTransactions(filter) { } } -module.exports = { Transaction, getTransactions }; +module.exports = { + getTransactions, + createTransaction, + createAutoRefillTransaction, + createStructuredTransaction, +}; diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index 43f3c004b2..3a1303edec 100644 --- a/api/models/Transaction.spec.js +++ b/api/models/Transaction.spec.js @@ -3,14 +3,13 @@ const { MongoMemoryServer } = require('mongodb-memory-server'); const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { getBalanceConfig } = require('~/server/services/Config'); const { getMultiplier, getCacheMultiplier } = require('./tx'); -const { Transaction } = require('./Transaction'); -const Balance = require('./Balance'); +const { createTransaction } = require('./Transaction'); +const { Balance } = require('~/db/models'); // Mock the custom config module so we can control the balance flag. 
jest.mock('~/server/services/Config'); let mongoServer; - beforeAll(async () => { mongoServer = await MongoMemoryServer.create(); const mongoUri = mongoServer.getUri(); @@ -368,7 +367,7 @@ describe('NaN Handling Tests', () => { }; // Act - const result = await Transaction.create(txData); + const result = await createTransaction(txData); // Assert: No transaction should be created and balance remains unchanged. expect(result).toBeUndefined(); diff --git a/api/models/User.js b/api/models/User.js deleted file mode 100644 index f4e8b0ec5b..0000000000 --- a/api/models/User.js +++ /dev/null @@ -1,6 +0,0 @@ -const mongoose = require('mongoose'); -const { userSchema } = require('@librechat/data-schemas'); - -const User = mongoose.model('User', userSchema); - -module.exports = User; diff --git a/api/models/balanceMethods.js b/api/models/balanceMethods.js index 4b788160aa..7e6321ab2c 100644 --- a/api/models/balanceMethods.js +++ b/api/models/balanceMethods.js @@ -1,9 +1,9 @@ +const { logger } = require('@librechat/data-schemas'); const { ViolationTypes } = require('librechat-data-provider'); -const { Transaction } = require('./Transaction'); +const { createAutoRefillTransaction } = require('./Transaction'); const { logViolation } = require('~/cache'); const { getMultiplier } = require('./tx'); -const { logger } = require('~/config'); -const Balance = require('./Balance'); +const { Balance } = require('~/db/models'); function isInvalidDate(date) { return isNaN(date); @@ -60,7 +60,7 @@ const checkBalanceRecord = async function ({ ) { try { /** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */ - const result = await Transaction.createAutoRefillTransaction({ + const result = await createAutoRefillTransaction({ user: user, tokenType: 'credits', context: 'autoRefill', diff --git a/api/models/convoStructure.spec.js b/api/models/convoStructure.spec.js index e672e0fa1c..33bf0c9b2b 100644 --- 
a/api/models/convoStructure.spec.js +++ b/api/models/convoStructure.spec.js @@ -1,6 +1,7 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { Message, getMessages, bulkSaveMessages } = require('./Message'); +const { getMessages, bulkSaveMessages } = require('./Message'); +const { Message } = require('~/db/models'); // Original version of buildTree function function buildTree({ messages, fileMap }) { @@ -42,7 +43,6 @@ function buildTree({ messages, fileMap }) { } let mongod; - beforeAll(async () => { mongod = await MongoMemoryServer.create(); const uri = mongod.getUri(); diff --git a/api/models/index.js b/api/models/index.js index 73cfa1c96c..7ecb9adcbb 100644 --- a/api/models/index.js +++ b/api/models/index.js @@ -1,13 +1,7 @@ -const { - comparePassword, - deleteUserById, - generateToken, - getUserById, - updateUser, - createUser, - countUsers, - findUser, -} = require('./userMethods'); +const mongoose = require('mongoose'); +const { createMethods } = require('@librechat/data-schemas'); +const methods = createMethods(mongoose); +const { comparePassword } = require('./userMethods'); const { findFileById, createFile, @@ -26,32 +20,12 @@ const { deleteMessagesSince, deleteMessages, } = require('./Message'); -const { - createSession, - findSession, - updateExpiration, - deleteSession, - deleteAllUserSessions, - generateRefreshToken, - countActiveSessions, -} = require('./Session'); const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation'); const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset'); -const { createToken, findToken, updateToken, deleteTokens } = require('./Token'); -const Balance = require('./Balance'); -const User = require('./User'); -const Key = require('./Key'); module.exports = { + ...methods, comparePassword, - deleteUserById, - generateToken, - getUserById, - updateUser, - createUser, - countUsers, - findUser, - findFileById, createFile, 
updateFile, @@ -77,21 +51,4 @@ module.exports = { getPresets, savePreset, deletePresets, - - createToken, - findToken, - updateToken, - deleteTokens, - - createSession, - findSession, - updateExpiration, - deleteSession, - deleteAllUserSessions, - generateRefreshToken, - countActiveSessions, - - User, - Key, - Balance, }; diff --git a/api/models/inviteUser.js b/api/models/inviteUser.js index 6cd699fd66..9f35b3f02b 100644 --- a/api/models/inviteUser.js +++ b/api/models/inviteUser.js @@ -1,7 +1,7 @@ const mongoose = require('mongoose'); -const { getRandomValues, hashToken } = require('~/server/utils/crypto'); -const { createToken, findToken } = require('./Token'); -const logger = require('~/config/winston'); +const { logger, hashToken } = require('@librechat/data-schemas'); +const { getRandomValues } = require('~/server/utils/crypto'); +const { createToken, findToken } = require('~/models'); /** * @module inviteUser diff --git a/api/models/plugins/mongoMeili.js b/api/models/plugins/mongoMeili.js deleted file mode 100644 index 75e3738e5d..0000000000 --- a/api/models/plugins/mongoMeili.js +++ /dev/null @@ -1,475 +0,0 @@ -const _ = require('lodash'); -const mongoose = require('mongoose'); -const { MeiliSearch } = require('meilisearch'); -const { parseTextParts, ContentTypes } = require('librechat-data-provider'); -const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); -const logger = require('~/config/meiliLogger'); - -// Environment flags -/** - * Flag to indicate if search is enabled based on environment variables. - * @type {boolean} - */ -const searchEnabled = process.env.SEARCH && process.env.SEARCH.toLowerCase() === 'true'; - -/** - * Flag to indicate if MeiliSearch is enabled based on required environment variables. - * @type {boolean} - */ -const meiliEnabled = process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY && searchEnabled; - -/** - * Validates the required options for configuring the mongoMeili plugin. 
- * - * @param {Object} options - The configuration options. - * @param {string} options.host - The MeiliSearch host. - * @param {string} options.apiKey - The MeiliSearch API key. - * @param {string} options.indexName - The name of the index. - * @throws {Error} Throws an error if any required option is missing. - */ -const validateOptions = function (options) { - const requiredKeys = ['host', 'apiKey', 'indexName']; - requiredKeys.forEach((key) => { - if (!options[key]) { - throw new Error(`Missing mongoMeili Option: ${key}`); - } - }); -}; - -/** - * Factory function to create a MeiliMongooseModel class which extends a Mongoose model. - * This class contains static and instance methods to synchronize and manage the MeiliSearch index - * corresponding to the MongoDB collection. - * - * @param {Object} config - Configuration object. - * @param {Object} config.index - The MeiliSearch index object. - * @param {Array} config.attributesToIndex - List of attributes to index. - * @returns {Function} A class definition that will be loaded into the Mongoose schema. - */ -const createMeiliMongooseModel = function ({ index, attributesToIndex }) { - // The primary key is assumed to be the first attribute in the attributesToIndex array. - const primaryKey = attributesToIndex[0]; - - class MeiliMongooseModel { - /** - * Synchronizes the data between the MongoDB collection and the MeiliSearch index. - * - * The synchronization process involves: - * 1. Fetching all documents from the MongoDB collection and MeiliSearch index. - * 2. Comparing documents from both sources. - * 3. Deleting documents from MeiliSearch that no longer exist in MongoDB. - * 4. Adding documents to MeiliSearch that exist in MongoDB but not in the index. - * 5. Updating documents in MeiliSearch if key fields (such as `text` or `title`) differ. - * 6. Updating the `_meiliIndex` field in MongoDB to indicate the indexing status. 
- * - * Note: The function processes documents in batches because MeiliSearch's - * `index.getDocuments` requires an exact limit and `index.addDocuments` does not handle - * partial failures in a batch. - * - * @returns {Promise} Resolves when the synchronization is complete. - */ - static async syncWithMeili() { - try { - let moreDocuments = true; - // Retrieve all MongoDB documents from the collection as plain JavaScript objects. - const mongoDocuments = await this.find().lean(); - - // Helper function to format a document by selecting only the attributes to index - // and omitting keys starting with '$'. - const format = (doc) => - _.omitBy(_.pick(doc, attributesToIndex), (v, k) => k.startsWith('$')); - - // Build a map of MongoDB documents for quick lookup based on the primary key. - const mongoMap = new Map(mongoDocuments.map((doc) => [doc[primaryKey], format(doc)])); - const indexMap = new Map(); - let offset = 0; - const batchSize = 1000; - - // Fetch documents from the MeiliSearch index in batches. - while (moreDocuments) { - const batch = await index.getDocuments({ limit: batchSize, offset }); - if (batch.results.length === 0) { - moreDocuments = false; - } - for (const doc of batch.results) { - indexMap.set(doc[primaryKey], format(doc)); - } - offset += batchSize; - } - - logger.debug('[syncWithMeili]', { indexMap: indexMap.size, mongoMap: mongoMap.size }); - - const updateOps = []; - - // Process documents present in the MeiliSearch index. - for (const [id, doc] of indexMap) { - const update = {}; - update[primaryKey] = id; - if (mongoMap.has(id)) { - // If document exists in MongoDB, check for discrepancies in key fields. - if ( - (doc.text && doc.text !== mongoMap.get(id).text) || - (doc.title && doc.title !== mongoMap.get(id).title) - ) { - logger.debug( - `[syncWithMeili] ${id} had document discrepancy in ${ - doc.text ? 
'text' : 'title' - } field`, - ); - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, - }); - await index.addDocuments([doc]); - } - } else { - // If the document does not exist in MongoDB, delete it from MeiliSearch. - await index.deleteDocument(id); - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: false } } }, - }); - } - } - - // Process documents present in MongoDB. - for (const [id, doc] of mongoMap) { - const update = {}; - update[primaryKey] = id; - // If the document is missing in the Meili index, add it. - if (!indexMap.has(id)) { - await index.addDocuments([doc]); - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, - }); - } else if (doc._meiliIndex === false) { - // If the document exists but is marked as not indexed, update the flag. - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, - }); - } - } - - // Execute bulk update operations in MongoDB to update the _meiliIndex flags. - if (updateOps.length > 0) { - await this.collection.bulkWrite(updateOps); - logger.debug( - `[syncWithMeili] Finished indexing ${ - primaryKey === 'messageId' ? 'messages' : 'conversations' - }`, - ); - } - } catch (error) { - logger.error('[syncWithMeili] Error adding document to Meili', error); - } - } - - /** - * Updates settings for the MeiliSearch index. - * - * @param {Object} settings - The settings to update on the MeiliSearch index. - * @returns {Promise} Promise resolving to the update result. - */ - static async setMeiliIndexSettings(settings) { - return await index.updateSettings(settings); - } - - /** - * Searches the MeiliSearch index and optionally populates the results with data from MongoDB. - * - * @param {string} q - The search query. - * @param {Object} params - Additional search parameters for MeiliSearch. - * @param {boolean} populate - Whether to populate search hits with full MongoDB documents. 
- * @returns {Promise} The search results with populated hits if requested. - */ - static async meiliSearch(q, params, populate) { - const data = await index.search(q, params); - - if (populate) { - // Build a query using the primary key values from the search hits. - const query = {}; - query[primaryKey] = _.map(data.hits, (hit) => cleanUpPrimaryKeyValue(hit[primaryKey])); - - // Build a projection object, including only keys that do not start with '$'. - const projection = Object.keys(this.schema.obj).reduce( - (results, key) => { - if (!key.startsWith('$')) { - results[key] = 1; - } - return results; - }, - { _id: 1, __v: 1 }, - ); - - // Retrieve the full documents from MongoDB. - const hitsFromMongoose = await this.find(query, projection).lean(); - - // Merge the MongoDB documents with the search hits. - const populatedHits = data.hits.map(function (hit) { - const query = {}; - query[primaryKey] = hit[primaryKey]; - const originalHit = _.find(hitsFromMongoose, query); - - return { - ...(originalHit ?? {}), - ...hit, - }; - }); - data.hits = populatedHits; - } - - return data; - } - - /** - * Preprocesses the current document for indexing. - * - * This method: - * - Picks only the defined attributes to index. - * - Omits any keys starting with '$'. - * - Replaces pipe characters ('|') in `conversationId` with '--'. - * - Extracts and concatenates text from an array of content items. - * - * @returns {Object} The preprocessed object ready for indexing. - */ - preprocessObjectForIndex() { - const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => - k.startsWith('$'), - ); - if (object.conversationId && object.conversationId.includes('|')) { - object.conversationId = object.conversationId.replace(/\|/g, '--'); - } - - if (object.content && Array.isArray(object.content)) { - object.text = parseTextParts(object.content); - delete object.content; - } - - return object; - } - - /** - * Adds the current document to the MeiliSearch index. 
- * - * The method preprocesses the document, adds it to MeiliSearch, and then updates - * the MongoDB document's `_meiliIndex` flag to true. - * - * @returns {Promise} - */ - async addObjectToMeili() { - const object = this.preprocessObjectForIndex(); - try { - await index.addDocuments([object]); - } catch (error) { - // Error handling can be enhanced as needed. - logger.error('[addObjectToMeili] Error adding document to Meili', error); - } - - await this.collection.updateMany({ _id: this._id }, { $set: { _meiliIndex: true } }); - } - - /** - * Updates the current document in the MeiliSearch index. - * - * @returns {Promise} - */ - async updateObjectToMeili() { - const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => - k.startsWith('$'), - ); - await index.updateDocuments([object]); - } - - /** - * Deletes the current document from the MeiliSearch index. - * - * @returns {Promise} - */ - async deleteObjectFromMeili() { - await index.deleteDocument(this._id); - } - - /** - * Post-save hook to synchronize the document with MeiliSearch. - * - * If the document is already indexed (i.e. `_meiliIndex` is true), it updates it; - * otherwise, it adds the document to the index. - */ - postSaveHook() { - if (this._meiliIndex) { - this.updateObjectToMeili(); - } else { - this.addObjectToMeili(); - } - } - - /** - * Post-update hook to update the document in MeiliSearch. - * - * This hook is triggered after a document update, ensuring that changes are - * propagated to the MeiliSearch index if the document is indexed. - */ - postUpdateHook() { - if (this._meiliIndex) { - this.updateObjectToMeili(); - } - } - - /** - * Post-remove hook to delete the document from MeiliSearch. - * - * This hook is triggered after a document is removed, ensuring that the document - * is also removed from the MeiliSearch index if it was previously indexed. 
- */ - postRemoveHook() { - if (this._meiliIndex) { - this.deleteObjectFromMeili(); - } - } - } - - return MeiliMongooseModel; -}; - -/** - * Mongoose plugin to synchronize MongoDB collections with a MeiliSearch index. - * - * This plugin: - * - Validates the provided options. - * - Adds a `_meiliIndex` field to the schema to track indexing status. - * - Sets up a MeiliSearch client and creates an index if it doesn't already exist. - * - Loads class methods for syncing, searching, and managing documents in MeiliSearch. - * - Registers Mongoose hooks (post-save, post-update, post-remove, etc.) to maintain index consistency. - * - * @param {mongoose.Schema} schema - The Mongoose schema to which the plugin is applied. - * @param {Object} options - Configuration options. - * @param {string} options.host - The MeiliSearch host. - * @param {string} options.apiKey - The MeiliSearch API key. - * @param {string} options.indexName - The name of the MeiliSearch index. - * @param {string} options.primaryKey - The primary key field for indexing. - */ -module.exports = function mongoMeili(schema, options) { - validateOptions(options); - - // Add _meiliIndex field to the schema to track if a document has been indexed in MeiliSearch. - schema.add({ - _meiliIndex: { - type: Boolean, - required: false, - select: false, - default: false, - }, - }); - - const { host, apiKey, indexName, primaryKey } = options; - - // Setup the MeiliSearch client. - const client = new MeiliSearch({ host, apiKey }); - - // Create the index asynchronously if it doesn't exist. - client.createIndex(indexName, { primaryKey }); - - // Setup the MeiliSearch index for this schema. - const index = client.index(indexName); - - // Collect attributes from the schema that should be indexed. - const attributesToIndex = [ - ..._.reduce( - schema.obj, - function (results, value, key) { - return value.meiliIndex ? [...results, key] : results; - }, - [], - ), - ]; - - // Load the class methods into the schema. 
- schema.loadClass(createMeiliMongooseModel({ index, indexName, client, attributesToIndex })); - - // Register Mongoose hooks to synchronize with MeiliSearch. - - // Post-save: synchronize after a document is saved. - schema.post('save', function (doc) { - doc.postSaveHook(); - }); - - // Post-update: synchronize after a document is updated. - schema.post('update', function (doc) { - doc.postUpdateHook(); - }); - - // Post-remove: synchronize after a document is removed. - schema.post('remove', function (doc) { - doc.postRemoveHook(); - }); - - // Pre-deleteMany hook: remove corresponding documents from MeiliSearch when multiple documents are deleted. - schema.pre('deleteMany', async function (next) { - if (!meiliEnabled) { - return next(); - } - - try { - // Check if the schema has a "messages" field to determine if it's a conversation schema. - if (Object.prototype.hasOwnProperty.call(schema.obj, 'messages')) { - const convoIndex = client.index('convos'); - const deletedConvos = await mongoose.model('Conversation').find(this._conditions).lean(); - const promises = deletedConvos.map((convo) => - convoIndex.deleteDocument(convo.conversationId), - ); - await Promise.all(promises); - } - - // Check if the schema has a "messageId" field to determine if it's a message schema. - if (Object.prototype.hasOwnProperty.call(schema.obj, 'messageId')) { - const messageIndex = client.index('messages'); - const deletedMessages = await mongoose.model('Message').find(this._conditions).lean(); - const promises = deletedMessages.map((message) => - messageIndex.deleteDocument(message.messageId), - ); - await Promise.all(promises); - } - return next(); - } catch (error) { - if (meiliEnabled) { - logger.error( - '[MeiliMongooseModel.deleteMany] There was an issue deleting conversation indexes upon deletion. 
Next startup may be slow due to syncing.', - error, - ); - } - return next(); - } - }); - - // Post-findOneAndUpdate hook: update MeiliSearch index after a document is updated via findOneAndUpdate. - schema.post('findOneAndUpdate', async function (doc) { - if (!meiliEnabled) { - return; - } - - // If the document is unfinished, do not update the index. - if (doc.unfinished) { - return; - } - - let meiliDoc; - // For conversation documents, try to fetch the document from the "convos" index. - if (doc.messages) { - try { - meiliDoc = await client.index('convos').getDocument(doc.conversationId); - } catch (error) { - logger.debug( - '[MeiliMongooseModel.findOneAndUpdate] Convo not found in MeiliSearch and will index ' + - doc.conversationId, - error, - ); - } - } - - // If the MeiliSearch document exists and the title is unchanged, do nothing. - if (meiliDoc && meiliDoc.title === doc.title) { - return; - } - - // Otherwise, trigger a post-save hook to synchronize the document. - doc.postSaveHook(); - }); -}; diff --git a/api/models/schema/convoSchema.js b/api/models/schema/convoSchema.js deleted file mode 100644 index 89cb9c80b5..0000000000 --- a/api/models/schema/convoSchema.js +++ /dev/null @@ -1,18 +0,0 @@ -const mongoose = require('mongoose'); -const mongoMeili = require('../plugins/mongoMeili'); - -const { convoSchema } = require('@librechat/data-schemas'); - -if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { - convoSchema.plugin(mongoMeili, { - host: process.env.MEILI_HOST, - apiKey: process.env.MEILI_MASTER_KEY, - /** Note: Will get created automatically if it doesn't exist already */ - indexName: 'convos', - primaryKey: 'conversationId', - }); -} - -const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema); - -module.exports = Conversation; diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js deleted file mode 100644 index cf97b84eea..0000000000 --- 
a/api/models/schema/messageSchema.js +++ /dev/null @@ -1,16 +0,0 @@ -const mongoose = require('mongoose'); -const mongoMeili = require('~/models/plugins/mongoMeili'); -const { messageSchema } = require('@librechat/data-schemas'); - -if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { - messageSchema.plugin(mongoMeili, { - host: process.env.MEILI_HOST, - apiKey: process.env.MEILI_MASTER_KEY, - indexName: 'messages', - primaryKey: 'messageId', - }); -} - -const Message = mongoose.models.Message || mongoose.model('Message', messageSchema); - -module.exports = Message; diff --git a/api/models/schema/pluginAuthSchema.js b/api/models/schema/pluginAuthSchema.js deleted file mode 100644 index 2066eda4c4..0000000000 --- a/api/models/schema/pluginAuthSchema.js +++ /dev/null @@ -1,6 +0,0 @@ -const mongoose = require('mongoose'); -const { pluginAuthSchema } = require('@librechat/data-schemas'); - -const PluginAuth = mongoose.models.Plugin || mongoose.model('PluginAuth', pluginAuthSchema); - -module.exports = PluginAuth; diff --git a/api/models/schema/presetSchema.js b/api/models/schema/presetSchema.js deleted file mode 100644 index 6d03803ace..0000000000 --- a/api/models/schema/presetSchema.js +++ /dev/null @@ -1,6 +0,0 @@ -const mongoose = require('mongoose'); -const { presetSchema } = require('@librechat/data-schemas'); - -const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema); - -module.exports = Preset; diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index 36b71ca9fc..834f740926 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -1,6 +1,5 @@ -const { Transaction } = require('./Transaction'); const { logger } = require('~/config'); - +const { createTransaction, createStructuredTransaction } = require('./Transaction'); /** * Creates up to two transactions to record the spending of tokens. 
* @@ -33,7 +32,7 @@ const spendTokens = async (txData, tokenUsage) => { let prompt, completion; try { if (promptTokens !== undefined) { - prompt = await Transaction.create({ + prompt = await createTransaction({ ...txData, tokenType: 'prompt', rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0), @@ -41,7 +40,7 @@ const spendTokens = async (txData, tokenUsage) => { } if (completionTokens !== undefined) { - completion = await Transaction.create({ + completion = await createTransaction({ ...txData, tokenType: 'completion', rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0), @@ -101,7 +100,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => { try { if (promptTokens) { const { input = 0, write = 0, read = 0 } = promptTokens; - prompt = await Transaction.createStructured({ + prompt = await createStructuredTransaction({ ...txData, tokenType: 'prompt', inputTokens: -input, @@ -111,7 +110,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => { } if (completionTokens) { - completion = await Transaction.create({ + completion = await createTransaction({ ...txData, tokenType: 'completion', rawAmount: -completionTokens, diff --git a/api/models/spendTokens.spec.js b/api/models/spendTokens.spec.js index eacf420330..7ee067e589 100644 --- a/api/models/spendTokens.spec.js +++ b/api/models/spendTokens.spec.js @@ -1,8 +1,9 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { Transaction } = require('./Transaction'); -const Balance = require('./Balance'); const { spendTokens, spendStructuredTokens } = require('./spendTokens'); +const { createTransaction, createAutoRefillTransaction } = require('./Transaction'); + +require('~/db/models'); // Mock the logger to prevent console output during tests jest.mock('~/config', () => ({ @@ -19,11 +20,15 @@ jest.mock('~/server/services/Config'); describe('spendTokens', () => { let mongoServer; let userId; + let Transaction; + let 
Balance; beforeAll(async () => { mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); + await mongoose.connect(mongoServer.getUri()); + + Transaction = mongoose.model('Transaction'); + Balance = mongoose.model('Balance'); }); afterAll(async () => { @@ -197,7 +202,7 @@ describe('spendTokens', () => { // Check that the transaction records show the adjusted values const transactionResults = await Promise.all( transactions.map((t) => - Transaction.create({ + createTransaction({ ...txData, tokenType: t.tokenType, rawAmount: t.rawAmount, @@ -280,7 +285,7 @@ describe('spendTokens', () => { // Check the return values from Transaction.create directly // This is to verify that the incrementValue is not becoming positive - const directResult = await Transaction.create({ + const directResult = await createTransaction({ user: userId, conversationId: 'test-convo-3', model: 'gpt-4', @@ -607,7 +612,7 @@ describe('spendTokens', () => { const promises = []; for (let i = 0; i < numberOfRefills; i++) { promises.push( - Transaction.createAutoRefillTransaction({ + createAutoRefillTransaction({ user: userId, tokenType: 'credits', context: 'concurrent-refill-test', diff --git a/api/models/userMethods.js b/api/models/userMethods.js index fbcd33aba8..a36409ebcf 100644 --- a/api/models/userMethods.js +++ b/api/models/userMethods.js @@ -1,159 +1,4 @@ const bcrypt = require('bcryptjs'); -const { getBalanceConfig } = require('~/server/services/Config'); -const signPayload = require('~/server/services/signPayload'); -const Balance = require('./Balance'); -const User = require('./User'); - -/** - * Retrieve a user by ID and convert the found user document to a plain object. - * - * @param {string} userId - The ID of the user to find and return as a plain object. - * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. 
- * @returns {Promise} A plain object representing the user document, or `null` if no user is found. - */ -const getUserById = async function (userId, fieldsToSelect = null) { - const query = User.findById(userId); - if (fieldsToSelect) { - query.select(fieldsToSelect); - } - return await query.lean(); -}; - -/** - * Search for a single user based on partial data and return matching user document as plain object. - * @param {Partial} searchCriteria - The partial data to use for searching the user. - * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. - * @returns {Promise} A plain object representing the user document, or `null` if no user is found. - */ -const findUser = async function (searchCriteria, fieldsToSelect = null) { - const query = User.findOne(searchCriteria); - if (fieldsToSelect) { - query.select(fieldsToSelect); - } - return await query.lean(); -}; - -/** - * Update a user with new data without overwriting existing properties. - * - * @param {string} userId - The ID of the user to update. - * @param {Object} updateData - An object containing the properties to update. - * @returns {Promise} The updated user document as a plain object, or `null` if no user is found. - */ -const updateUser = async function (userId, updateData) { - const updateOperation = { - $set: updateData, - $unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL - }; - return await User.findByIdAndUpdate(userId, updateOperation, { - new: true, - runValidators: true, - }).lean(); -}; - -/** - * Creates a new user, optionally with a TTL of 1 week. - * @param {MongoUser} data - The user data to be created, must contain user_id. - * @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`. - * @param {boolean} [returnUser=false] - Whether to return the created user object. - * @returns {Promise} A promise that resolves to the created user document ID or user object. 
- * @throws {Error} If a user with the same user_id already exists. - */ -const createUser = async (data, disableTTL = true, returnUser = false) => { - const balance = await getBalanceConfig(); - const userData = { - ...data, - expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds - }; - - if (disableTTL) { - delete userData.expiresAt; - } - - const user = await User.create(userData); - - // If balance is enabled, create or update a balance record for the user using global.interfaceConfig.balance - if (balance?.enabled && balance?.startBalance) { - const update = { - $inc: { tokenCredits: balance.startBalance }, - }; - - if ( - balance.autoRefillEnabled && - balance.refillIntervalValue != null && - balance.refillIntervalUnit != null && - balance.refillAmount != null - ) { - update.$set = { - autoRefillEnabled: true, - refillIntervalValue: balance.refillIntervalValue, - refillIntervalUnit: balance.refillIntervalUnit, - refillAmount: balance.refillAmount, - }; - } - - await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean(); - } - - if (returnUser) { - return user.toObject(); - } - return user._id; -}; - -/** - * Count the number of user documents in the collection based on the provided filter. - * - * @param {Object} [filter={}] - The filter to apply when counting the documents. - * @returns {Promise} The count of documents that match the filter. - */ -const countUsers = async function (filter = {}) { - return await User.countDocuments(filter); -}; - -/** - * Delete a user by their unique ID. - * - * @param {string} userId - The ID of the user to delete. - * @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents. - */ -const deleteUserById = async function (userId) { - try { - const result = await User.deleteOne({ _id: userId }); - if (result.deletedCount === 0) { - return { deletedCount: 0, message: 'No user found with that ID.' 
}; - } - return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' }; - } catch (error) { - throw new Error('Error deleting user: ' + error.message); - } -}; - -const { SESSION_EXPIRY } = process.env ?? {}; -const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15; - -/** - * Generates a JWT token for a given user. - * - * @param {MongoUser} user - The user for whom the token is being generated. - * @returns {Promise} A promise that resolves to a JWT token. - */ -const generateToken = async (user) => { - if (!user) { - throw new Error('No user provided'); - } - - return await signPayload({ - payload: { - id: user._id, - username: user.username, - provider: user.provider, - email: user.email, - }, - secret: process.env.JWT_SECRET, - expirationTime: expires / 1000, - }); -}; /** * Compares the provided password with the user's password. @@ -167,6 +12,10 @@ const comparePassword = async (user, candidatePassword) => { throw new Error('No user provided'); } + if (!user.password) { + throw new Error('No password, likely an email first registered via Social/OIDC login'); + } + return new Promise((resolve, reject) => { bcrypt.compare(candidatePassword, user.password, (err, isMatch) => { if (err) { @@ -179,11 +28,4 @@ const comparePassword = async (user, candidatePassword) => { module.exports = { comparePassword, - deleteUserById, - generateToken, - getUserById, - countUsers, - createUser, - updateUser, - findUser, }; diff --git a/api/package.json b/api/package.json index 64903553ef..636e8cb8f3 100644 --- a/api/package.json +++ b/api/package.json @@ -48,8 +48,10 @@ "@langchain/google-genai": "^0.2.9", "@langchain/google-vertexai": "^0.2.9", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.4.37", + "@librechat/agents": "^2.4.38", + "@librechat/api": "*", "@librechat/data-schemas": "*", + "@node-saml/passport-saml": "^5.0.0", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "^1.8.2", "bcryptjs": "^2.4.3", @@ -80,15 +82,15 @@ 
"keyv-file": "^5.1.2", "klona": "^2.0.6", "librechat-data-provider": "*", - "librechat-mcp": "*", "lodash": "^4.17.21", "meilisearch": "^0.38.0", "memorystore": "^1.6.7", "mime": "^3.0.0", "module-alias": "^2.2.3", "mongoose": "^8.12.1", - "multer": "^2.0.0", + "multer": "^2.0.1", "nanoid": "^3.3.7", + "node-fetch": "^2.7.0", "nodemailer": "^6.9.15", "ollama": "^0.5.0", "openai": "^4.96.2", @@ -108,8 +110,9 @@ "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", + "undici": "^7.10.0", "winston": "^3.11.0", - "winston-daily-rotate-file": "^4.7.1", + "winston-daily-rotate-file": "^5.0.0", "youtube-transcript": "^1.2.1", "zod": "^3.22.4" }, diff --git a/api/server/cleanup.js b/api/server/cleanup.js index 5bf336eed5..de7450cea0 100644 --- a/api/server/cleanup.js +++ b/api/server/cleanup.js @@ -220,6 +220,9 @@ function disposeClient(client) { if (client.maxResponseTokens) { client.maxResponseTokens = null; } + if (client.processMemory) { + client.processMemory = null; + } if (client.run) { // Break circular references in run if (client.run.Graph) { diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index a71ce7d59a..0f8152de3e 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -1,6 +1,7 @@ -const openIdClient = require('openid-client'); const cookies = require('cookie'); const jwt = require('jsonwebtoken'); +const openIdClient = require('openid-client'); +const { logger } = require('@librechat/data-schemas'); const { registerUser, resetPassword, @@ -8,9 +9,8 @@ const { requestPasswordReset, setOpenIDAuthTokens, } = require('~/server/services/AuthService'); -const { findSession, getUserById, deleteAllUserSessions, findUser } = require('~/models'); +const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models'); const { getOpenIdConfig } = require('~/strategies'); -const { logger } = require('~/config'); const { isEnabled } = 
require('~/server/utils'); const registrationController = async (req, res) => { @@ -96,7 +96,10 @@ const refreshController = async (req, res) => { } // Find the session with the hashed refresh token - const session = await findSession({ userId: userId, refreshToken: refreshToken }); + const session = await findSession({ + userId: userId, + refreshToken: refreshToken, + }); if (session && session.expiration > new Date()) { const token = await setAuthTokens(userId, res, session._id); diff --git a/api/server/controllers/Balance.js b/api/server/controllers/Balance.js index 729afc7684..c892a73b0c 100644 --- a/api/server/controllers/Balance.js +++ b/api/server/controllers/Balance.js @@ -1,9 +1,24 @@ -const Balance = require('~/models/Balance'); +const { Balance } = require('~/db/models'); async function balanceController(req, res) { - const { tokenCredits: balance = '' } = - (await Balance.findOne({ user: req.user.id }, 'tokenCredits').lean()) ?? {}; - res.status(200).send('' + balance); + const balanceData = await Balance.findOne( + { user: req.user.id }, + '-_id tokenCredits autoRefillEnabled refillIntervalValue refillIntervalUnit lastRefill refillAmount', + ).lean(); + + if (!balanceData) { + return res.status(404).json({ error: 'Balance not found' }); + } + + // If auto-refill is not enabled, remove auto-refill related fields from the response + if (!balanceData.autoRefillEnabled) { + delete balanceData.refillIntervalValue; + delete balanceData.refillIntervalUnit; + delete balanceData.lastRefill; + delete balanceData.refillAmount; + } + + res.status(200).json(balanceData); } module.exports = balanceController; diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js index f5783f45ad..6e22db2e5c 100644 --- a/api/server/controllers/TwoFactorController.js +++ b/api/server/controllers/TwoFactorController.js @@ -1,12 +1,12 @@ +const { logger } = require('@librechat/data-schemas'); const { + verifyTOTP, + getTOTPSecret, + 
verifyBackupCode, generateTOTPSecret, generateBackupCodes, - verifyTOTP, - verifyBackupCode, - getTOTPSecret, } = require('~/server/services/twoFactorService'); -const { updateUser, getUserById } = require('~/models'); -const { logger } = require('~/config'); +const { getUserById, updateUser } = require('~/models'); const { encryptV3 } = require('~/server/utils/crypto'); const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, ''); diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 571c454552..bcffb2189c 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -1,12 +1,11 @@ const { Tools, - Constants, FileSources, webSearchKeys, extractWebSearchEnvVars, } = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); const { - Balance, getFiles, updateUser, deleteFiles, @@ -16,16 +15,14 @@ const { deleteUserById, deleteAllUserSessions, } = require('~/models'); -const User = require('~/models/User'); const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService'); const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud'); const { processDeleteRequest } = require('~/server/services/Files/process'); -const { deleteAllSharedLinks } = require('~/models/Share'); +const { Transaction, Balance, User } = require('~/db/models'); const { deleteToolCalls } = require('~/models/ToolCall'); -const { Transaction } = require('~/models/Transaction'); -const { logger } = require('~/config'); +const { deleteAllSharedLinks } = require('~/models'); const getUserController = async (req, res) => { /** @type {MongoUser} */ @@ -166,7 +163,11 @@ const deleteUserController = async (req, res) => { await 
Balance.deleteMany({ user: user._id }); // delete user balances await deletePresets(user.id); // delete user presets /* TODO: Delete Assistant Threads */ - await deleteConvos(user.id); // delete user convos + try { + await deleteConvos(user.id); // delete user convos + } catch (error) { + logger.error('[deleteUserController] Error deleting user convos, likely no convos', error); + } await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth await deleteUserById(user.id); // delete user await deleteAllSharedLinks(user.id); // delete user shared links diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index cedfc6bd62..60e68b5f2d 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,4 +1,6 @@ const { nanoid } = require('nanoid'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Tools, StepTypes, FileContext } = require('librechat-data-provider'); const { EnvVar, @@ -12,7 +14,6 @@ const { const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { saveBase64Image } = require('~/server/services/Files/process'); -const { logger, sendEvent } = require('~/config'); class ModelEndHandler { /** @@ -240,9 +241,7 @@ function createToolEndCallback({ req, res, artifactPromises }) { if (output.artifact[Tools.web_search]) { artifactPromises.push( (async () => { - const name = `${output.name}_${output.tool_call_id}_${nanoid()}`; const attachment = { - name, type: Tools.web_search, messageId: metadata.run_id, toolCallId: output.tool_call_id, diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 31fd56930e..41e457e5b8 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -1,13 +1,12 @@ -// const { 
HttpsProxyAgent } = require('https-proxy-agent'); -// const { -// Constants, -// ImageDetail, -// EModelEndpoint, -// resolveHeaders, -// validateVisionModel, -// mapModelToAzureConfig, -// } = require('librechat-data-provider'); require('events').EventEmitter.defaultMaxListeners = 100; +const { logger } = require('@librechat/data-schemas'); +const { + sendEvent, + createRun, + Tokenizer, + memoryInstructions, + createMemoryProcessor, +} = require('@librechat/api'); const { Callback, GraphEvents, @@ -19,25 +18,30 @@ const { } = require('@librechat/agents'); const { Constants, + Permissions, VisionModes, ContentTypes, EModelEndpoint, KnownEndpoints, + PermissionTypes, isAgentsEndpoint, AgentCapabilities, bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); +const { DynamicStructuredTool } = require('@langchain/core/tools'); +const { getBufferString, HumanMessage } = require('@langchain/core/messages'); const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config'); const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); +const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); -const { getBufferString, HumanMessage } = require('@langchain/core/messages'); +const { setMemory, deleteMemory, getFormattedMemories } = require('~/models'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const Tokenizer = require('~/server/services/Tokenizer'); +const { checkAccess } = require('~/server/middleware/roles/access'); const BaseClient = require('~/app/clients/BaseClient'); -const { logger, sendEvent } = require('~/config'); -const { createRun } = require('./run'); +const { loadAgent } = require('~/models/Agent'); +const { getMCPManager } = require('~/config'); /** * @param {ServerRequest} 
req @@ -57,12 +61,8 @@ const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deep const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; -// const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory'); -// const { getFormattedMemories } = require('~/models/Memory'); -// const { getCurrentDateTime } = require('~/utils'); - function createTokenCounter(encoding) { - return (message) => { + return function (message) { const countTokens = (text) => Tokenizer.getTokenCount(text, encoding); return getTokenCountForMessage(message, countTokens); }; @@ -123,6 +123,8 @@ class AgentClient extends BaseClient { this.usage; /** @type {Record} */ this.indexTokenCountMap = {}; + /** @type {(messages: BaseMessage[]) => Promise} */ + this.processMemory; } /** @@ -137,55 +139,10 @@ class AgentClient extends BaseClient { } /** - * - * Checks if the model is a vision model based on request attachments and sets the appropriate options: - * - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request. - * - Sets `this.isVisionModel` to `true` if vision request. - * - Deletes `this.modelOptions.stop` if vision request. 
+ * `AgentClient` is not opinionated about vision requests, so we don't do anything here * @param {MongoFile[]} attachments */ - checkVisionRequest(attachments) { - // if (!attachments) { - // return; - // } - // const availableModels = this.options.modelsConfig?.[this.options.endpoint]; - // if (!availableModels) { - // return; - // } - // let visionRequestDetected = false; - // for (const file of attachments) { - // if (file?.type?.includes('image')) { - // visionRequestDetected = true; - // break; - // } - // } - // if (!visionRequestDetected) { - // return; - // } - // this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); - // if (this.isVisionModel) { - // delete this.modelOptions.stop; - // return; - // } - // for (const model of availableModels) { - // if (!validateVisionModel({ model, availableModels })) { - // continue; - // } - // this.modelOptions.model = model; - // this.isVisionModel = true; - // delete this.modelOptions.stop; - // return; - // } - // if (!availableModels.includes(this.defaultVisionModel)) { - // return; - // } - // if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) { - // return; - // } - // this.modelOptions.model = this.defaultVisionModel; - // this.isVisionModel = true; - // delete this.modelOptions.stop; - } + checkVisionRequest() {} getSaveOptions() { // TODO: @@ -269,24 +226,6 @@ class AgentClient extends BaseClient { .filter(Boolean) .join('\n') .trim(); - // this.systemMessage = getCurrentDateTime(); - // const { withKeys, withoutKeys } = await getFormattedMemories({ - // userId: this.options.req.user.id, - // }); - // processMemory({ - // userId: this.options.req.user.id, - // message: this.options.req.body.text, - // parentMessageId, - // memory: withKeys, - // thread_id: this.conversationId, - // }).catch((error) => { - // logger.error('Memory Agent failed to process memory', error); - // }); - - // this.systemMessage += '\n\n' + memoryInstructions; - // 
if (withoutKeys) { - // this.systemMessage += `\n\n# Existing memory about the user:\n${withoutKeys}`; - // } if (this.options.attachments) { const attachments = await this.options.attachments; @@ -370,6 +309,37 @@ class AgentClient extends BaseClient { systemContent = this.augmentedPrompt + systemContent; } + // Inject MCP server instructions if available + const ephemeralAgent = this.options.req.body.ephemeralAgent; + let mcpServers = []; + + // Check for ephemeral agent MCP servers + if (ephemeralAgent && ephemeralAgent.mcp && ephemeralAgent.mcp.length > 0) { + mcpServers = ephemeralAgent.mcp; + } + // Check for regular agent MCP tools + else if (this.options.agent && this.options.agent.tools) { + mcpServers = this.options.agent.tools + .filter( + (tool) => + tool instanceof DynamicStructuredTool && tool.name.includes(Constants.mcp_delimiter), + ) + .map((tool) => tool.name.split(Constants.mcp_delimiter).pop()) + .filter(Boolean); + } + + if (mcpServers.length > 0) { + try { + const mcpInstructions = getMCPManager().formatInstructionsForContext(mcpServers); + if (mcpInstructions) { + systemContent = [systemContent, mcpInstructions].filter(Boolean).join('\n\n'); + logger.debug('[AgentClient] Injected MCP instructions for servers:', mcpServers); + } + } catch (error) { + logger.error('[AgentClient] Failed to inject MCP instructions:', error); + } + } + if (systemContent) { this.options.agent.instructions = systemContent; } @@ -399,9 +369,150 @@ class AgentClient extends BaseClient { opts.getReqData({ promptTokens }); } + const withoutKeys = await this.useMemory(); + if (withoutKeys) { + systemContent += `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`; + } + + if (systemContent) { + this.options.agent.instructions = systemContent; + } + return result; } + /** + * @returns {Promise} + */ + async useMemory() { + const user = this.options.req.user; + if (user.personalization?.memories === false) { + return; + } + const hasAccess = await 
checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]); + + if (!hasAccess) { + logger.debug( + `[api/server/controllers/agents/client.js #useMemory] User ${user.id} does not have USE permission for memories`, + ); + return; + } + /** @type {TCustomConfig['memory']} */ + const memoryConfig = this.options.req?.app?.locals?.memory; + if (!memoryConfig || memoryConfig.disabled === true) { + return; + } + + /** @type {Agent} */ + let prelimAgent; + const allowedProviders = new Set( + this.options.req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders, + ); + try { + if (memoryConfig.agent?.id != null && memoryConfig.agent.id !== this.options.agent.id) { + prelimAgent = await loadAgent({ + req: this.options.req, + agent_id: memoryConfig.agent.id, + endpoint: EModelEndpoint.agents, + }); + } else if ( + memoryConfig.agent?.id == null && + memoryConfig.agent?.model != null && + memoryConfig.agent?.provider != null + ) { + prelimAgent = { id: Constants.EPHEMERAL_AGENT_ID, ...memoryConfig.agent }; + } + } catch (error) { + logger.error( + '[api/server/controllers/agents/client.js #useMemory] Error loading agent for memory', + error, + ); + } + + const agent = await initializeAgent({ + req: this.options.req, + res: this.options.res, + agent: prelimAgent, + allowedProviders, + }); + + if (!agent) { + logger.warn( + '[api/server/controllers/agents/client.js #useMemory] No agent found for memory', + memoryConfig, + ); + return; + } + + const llmConfig = Object.assign( + { + provider: agent.provider, + model: agent.model, + }, + agent.model_parameters, + ); + + /** @type {import('@librechat/api').MemoryConfig} */ + const config = { + validKeys: memoryConfig.validKeys, + instructions: agent.instructions, + llmConfig, + tokenLimit: memoryConfig.tokenLimit, + }; + + const userId = this.options.req.user.id + ''; + const messageId = this.responseMessageId + ''; + const conversationId = this.conversationId + ''; + const [withoutKeys, processMemory] = await 
createMemoryProcessor({ + userId, + config, + messageId, + conversationId, + memoryMethods: { + setMemory, + deleteMemory, + getFormattedMemories, + }, + res: this.options.res, + }); + + this.processMemory = processMemory; + return withoutKeys; + } + + /** + * @param {BaseMessage[]} messages + * @returns {Promise} + */ + async runMemory(messages) { + try { + if (this.processMemory == null) { + return; + } + /** @type {TCustomConfig['memory']} */ + const memoryConfig = this.options.req?.app?.locals?.memory; + const messageWindowSize = memoryConfig?.messageWindowSize ?? 5; + + let messagesToProcess = [...messages]; + if (messages.length > messageWindowSize) { + for (let i = messages.length - messageWindowSize; i >= 0; i--) { + const potentialWindow = messages.slice(i, i + messageWindowSize); + if (potentialWindow[0]?.role === 'user') { + messagesToProcess = [...potentialWindow]; + break; + } + } + + if (messagesToProcess.length === messages.length) { + messagesToProcess = [...messages.slice(-messageWindowSize)]; + } + } + return await this.processMemory(messagesToProcess); + } catch (error) { + logger.error('Memory Agent failed to process memory', error); + } + } + /** @type {sendCompletion} */ async sendCompletion(payload, opts = {}) { await this.chatCompletion({ @@ -544,100 +655,13 @@ class AgentClient extends BaseClient { let config; /** @type {ReturnType} */ let run; + /** @type {Promise<(TAttachment | null)[] | undefined>} */ + let memoryPromise; try { if (!abortController) { abortController = new AbortController(); } - // if (this.options.headers) { - // opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers }; - // } - - // if (this.options.proxy) { - // opts.httpAgent = new HttpsProxyAgent(this.options.proxy); - // } - - // if (this.isVisionModel) { - // modelOptions.max_tokens = 4000; - // } - - // /** @type {TAzureConfig | undefined} */ - // const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; - - // if ( - // 
(this.azure && this.isVisionModel && azureConfig) || - // (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI) - // ) { - // const { modelGroupMap, groupMap } = azureConfig; - // const { - // azureOptions, - // baseURL, - // headers = {}, - // serverless, - // } = mapModelToAzureConfig({ - // modelName: modelOptions.model, - // modelGroupMap, - // groupMap, - // }); - // opts.defaultHeaders = resolveHeaders(headers); - // this.langchainProxy = extractBaseURL(baseURL); - // this.apiKey = azureOptions.azureOpenAIApiKey; - - // const groupName = modelGroupMap[modelOptions.model].group; - // this.options.addParams = azureConfig.groupMap[groupName].addParams; - // this.options.dropParams = azureConfig.groupMap[groupName].dropParams; - // // Note: `forcePrompt` not re-assigned as only chat models are vision models - - // this.azure = !serverless && azureOptions; - // this.azureEndpoint = - // !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); - // } - - // if (this.azure || this.options.azure) { - // /* Azure Bug, extremely short default `max_tokens` response */ - // if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') { - // modelOptions.max_tokens = 4000; - // } - - // /* Azure does not accept `model` in the body, so we need to remove it. */ - // delete modelOptions.model; - - // opts.baseURL = this.langchainProxy - // ? constructAzureURL({ - // baseURL: this.langchainProxy, - // azureOptions: this.azure, - // }) - // : this.azureEndpoint.split(/(? 
{ - // delete modelOptions[param]; - // }); - // logger.debug('[api/server/controllers/agents/client.js #chatCompletion] dropped params', { - // dropParams: this.options.dropParams, - // modelOptions, - // }); - // } - /** @type {TCustomConfig['endpoints']['agents']} */ const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents]; @@ -647,6 +671,7 @@ class AgentClient extends BaseClient { last_agent_index: this.agentConfigs?.size ?? 0, user_id: this.user ?? this.options.req.user?.id, hide_sequential_outputs: this.options.agent.hide_sequential_outputs, + user: this.options.req.user, }, recursionLimit: agentsEConfig?.recursionLimit, signal: abortController.signal, @@ -734,6 +759,10 @@ class AgentClient extends BaseClient { messages = addCacheControl(messages); } + if (i === 0) { + memoryPromise = this.runMemory(messages); + } + run = await createRun({ agent, req: this.options.req, @@ -769,10 +798,9 @@ class AgentClient extends BaseClient { run.Graph.contentData = contentData; } - const encoding = this.getEncoding(); await run.processStream({ messages }, config, { keepContent: i !== 0, - tokenCounter: createTokenCounter(encoding), + tokenCounter: createTokenCounter(this.getEncoding()), indexTokenCountMap: currentIndexCountMap, maxContextTokens: agent.maxContextTokens, callbacks: { @@ -887,6 +915,12 @@ class AgentClient extends BaseClient { }); try { + if (memoryPromise) { + const attachments = await memoryPromise; + if (attachments && attachments.length > 0) { + this.artifactPromises.push(...attachments); + } + } await this.recordCollectedUsage({ context: 'message' }); } catch (err) { logger.error( @@ -895,6 +929,12 @@ class AgentClient extends BaseClient { ); } } catch (err) { + if (memoryPromise) { + const attachments = await memoryPromise; + if (attachments && attachments.length > 0) { + this.artifactPromises.push(...attachments); + } + } logger.error( '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', err, diff --git 
a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index fcee62edc7..24b7822c1f 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -228,7 +228,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { // Save user message if needed if (!client.skipSaveUserMessage) { await saveMessage(req, userMessage, { - context: 'api/server/controllers/agents/request.js - don\'t skip saving user message', + context: "api/server/controllers/agents/request.js - don't skip saving user message", }); } diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js deleted file mode 100644 index 2452e66233..0000000000 --- a/api/server/controllers/agents/run.js +++ /dev/null @@ -1,94 +0,0 @@ -const { Run, Providers } = require('@librechat/agents'); -const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider'); - -/** - * @typedef {import('@librechat/agents').t} t - * @typedef {import('@librechat/agents').StandardGraphConfig} StandardGraphConfig - * @typedef {import('@librechat/agents').StreamEventData} StreamEventData - * @typedef {import('@librechat/agents').EventHandler} EventHandler - * @typedef {import('@librechat/agents').GraphEvents} GraphEvents - * @typedef {import('@librechat/agents').LLMConfig} LLMConfig - * @typedef {import('@librechat/agents').IState} IState - */ - -const customProviders = new Set([ - Providers.XAI, - Providers.OLLAMA, - Providers.DEEPSEEK, - Providers.OPENROUTER, -]); - -/** - * Creates a new Run instance with custom handlers and configuration. - * - * @param {Object} options - The options for creating the Run instance. - * @param {ServerRequest} [options.req] - The server request. - * @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated. - * @param {Agent} options.agent - The agent for this run. 
- * @param {AbortSignal} options.signal - The signal for this run. - * @param {Record | undefined} [options.customHandlers] - Custom event handlers. - * @param {boolean} [options.streaming=true] - Whether to use streaming. - * @param {boolean} [options.streamUsage=true] - Whether to stream usage information. - * @returns {Promise>} A promise that resolves to a new Run instance. - */ -async function createRun({ - runId, - agent, - signal, - customHandlers, - streaming = true, - streamUsage = true, -}) { - const provider = providerEndpointMap[agent.provider] ?? agent.provider; - /** @type {LLMConfig} */ - const llmConfig = Object.assign( - { - provider, - streaming, - streamUsage, - }, - agent.model_parameters, - ); - - /** Resolves issues with new OpenAI usage field */ - if ( - customProviders.has(agent.provider) || - (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider) - ) { - llmConfig.streamUsage = false; - llmConfig.usage = true; - } - - /** @type {'reasoning_content' | 'reasoning'} */ - let reasoningKey; - if ( - llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) || - (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) - ) { - reasoningKey = 'reasoning'; - } - - /** @type {StandardGraphConfig} */ - const graphConfig = { - signal, - llmConfig, - reasoningKey, - tools: agent.tools, - instructions: agent.instructions, - additional_instructions: agent.additional_instructions, - // toolEnd: agent.end_after_tools, - }; - - // TEMPORARY FOR TESTING - if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) { - graphConfig.streamBuffer = 2000; - } - - return Run.create({ - runId, - graphConfig, - customHandlers, - }); -} - -module.exports = { createRun }; diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 1799913b68..38a058b540 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -18,6 
+18,7 @@ const { } = require('~/models/Agent'); const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const { refreshS3Url } = require('~/server/services/Files/S3/crud'); const { updateAction, getActions } = require('~/models/Action'); const { updateAgentProjects } = require('~/models/Agent'); @@ -111,7 +112,7 @@ const getAgentHandler = async (req, res) => { const originalUrl = agent.avatar.filepath; agent.avatar.filepath = await refreshS3Url(agent.avatar); if (originalUrl !== agent.avatar.filepath) { - await updateAgent({ id }, { avatar: agent.avatar }, req.user.id); + await updateAgent({ id }, { avatar: agent.avatar }, { updatingUserId: req.user.id }); } } @@ -168,12 +169,18 @@ const updateAgentHandler = async (req, res) => { }); } + /** @type {boolean} */ + const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0; + let updatedAgent = Object.keys(updateData).length > 0 - ? await updateAgent({ id }, updateData, req.user.id) + ? 
await updateAgent({ id }, updateData, { + updatingUserId: req.user.id, + skipVersioning: isProjectUpdate, + }) : existingAgent; - if (projectIds || removeProjectIds) { + if (isProjectUpdate) { updatedAgent = await updateAgentProjects({ user: req.user, agentId: id, @@ -373,12 +380,27 @@ const uploadAgentAvatarHandler = async (req, res) => { } const buffer = await fs.readFile(req.file.path); - const image = await uploadImageBuffer({ - req, - context: FileContext.avatar, - metadata: { buffer }, + + const fileStrategy = req.app.locals.fileStrategy; + + const resizedBuffer = await resizeAvatar({ + userId: req.user.id, + input: buffer, }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + const avatarUrl = await processAvatar({ + buffer: resizedBuffer, + userId: req.user.id, + manual: 'false', + agentId: agent_id, + }); + + const image = { + filepath: avatarUrl, + source: fileStrategy, + }; + let _avatar; try { const agent = await getAgent({ id: agent_id }); @@ -403,11 +425,15 @@ const uploadAgentAvatarHandler = async (req, res) => { const data = { avatar: { filepath: image.filepath, - source: req.app.locals.fileStrategy, + source: image.source, }, }; - promises.push(await updateAgent({ id: agent_id, author: req.user.id }, data, req.user.id)); + promises.push( + await updateAgent({ id: agent_id, author: req.user.id }, data, { + updatingUserId: req.user.id, + }), + ); const resolved = await Promise.all(promises); res.status(201).json(resolved[0]); diff --git a/api/server/controllers/auth/TwoFactorAuthController.js b/api/server/controllers/auth/TwoFactorAuthController.js index 15cde8122a..b37c89a998 100644 --- a/api/server/controllers/auth/TwoFactorAuthController.js +++ b/api/server/controllers/auth/TwoFactorAuthController.js @@ -1,12 +1,12 @@ const jwt = require('jsonwebtoken'); +const { logger } = require('@librechat/data-schemas'); const { verifyTOTP, - verifyBackupCode, getTOTPSecret, + verifyBackupCode, } = 
require('~/server/services/twoFactorService'); const { setAuthTokens } = require('~/server/services/AuthService'); -const { getUserById } = require('~/models/userMethods'); -const { logger } = require('~/config'); +const { getUserById } = require('~/models'); /** * Verifies the 2FA code during login using a temporary token. diff --git a/api/server/index.js b/api/server/index.js index c7525f9b91..a04c339b0f 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -9,8 +9,9 @@ const passport = require('passport'); const mongoSanitize = require('express-mongo-sanitize'); const fs = require('fs'); const cookieParser = require('cookie-parser'); +const { connectDb, indexSync } = require('~/db'); + const { jwtLogin, passportLogin } = require('~/strategies'); -const { connectDb, indexSync } = require('~/lib/db'); const { isEnabled } = require('~/server/utils'); const { ldapLogin } = require('~/strategies'); const { logger } = require('~/config'); @@ -36,6 +37,7 @@ const startServer = async () => { axios.defaults.headers.common['Accept-Encoding'] = 'gzip'; } await connectDb(); + logger.info('Connected to MongoDB'); await indexSync(); @@ -115,7 +117,7 @@ const startServer = async () => { app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); app.use('/api/bedrock', routes.bedrock); - + app.use('/api/memories', routes.memories); app.use('/api/tags', routes.tags); app.use((req, res) => { diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index bfc28f513d..94d69004bd 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -327,7 +327,7 @@ const handleAbortError = async (res, req, error, data) => { errorText = `{"type":"${ErrorTypes.INVALID_REQUEST}"}`; } - if (error?.message?.includes('does not support \'system\'')) { + if (error?.message?.includes("does not support 'system'")) { errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`; } diff --git 
a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js index 4e0593192a..91c31ab66a 100644 --- a/api/server/middleware/checkBan.js +++ b/api/server/middleware/checkBan.js @@ -1,12 +1,12 @@ const { Keyv } = require('keyv'); const uap = require('ua-parser-js'); +const { logger } = require('@librechat/data-schemas'); const { ViolationTypes } = require('librechat-data-provider'); const { isEnabled, removePorts } = require('~/server/utils'); const keyvMongo = require('~/cache/keyvMongo'); const denyRequest = require('./denyRequest'); const { getLogStores } = require('~/cache'); const { findUser } = require('~/models'); -const { logger } = require('~/config'); const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 }); const message = 'Your account has been temporarily banned due to violations of our service.'; diff --git a/api/server/middleware/checkInviteUser.js b/api/server/middleware/checkInviteUser.js index e1ad271b55..42e1faba5b 100644 --- a/api/server/middleware/checkInviteUser.js +++ b/api/server/middleware/checkInviteUser.js @@ -1,5 +1,5 @@ const { getInvite } = require('~/models/inviteUser'); -const { deleteTokens } = require('~/models/Token'); +const { deleteTokens } = require('~/models'); async function checkInviteUser(req, res, next) { const token = req.body.token; diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/access.js similarity index 100% rename from api/server/middleware/roles/generateCheckAccess.js rename to api/server/middleware/roles/access.js diff --git a/api/server/middleware/roles/checkAdmin.js b/api/server/middleware/roles/admin.js similarity index 100% rename from api/server/middleware/roles/checkAdmin.js rename to api/server/middleware/roles/admin.js diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js index a9fc5b2a08..ebc0043f2f 100644 --- a/api/server/middleware/roles/index.js +++ 
b/api/server/middleware/roles/index.js @@ -1,5 +1,5 @@ -const checkAdmin = require('./checkAdmin'); -const { checkAccess, generateCheckAccess } = require('./generateCheckAccess'); +const checkAdmin = require('./admin'); +const { checkAccess, generateCheckAccess } = require('./access'); module.exports = { checkAdmin, diff --git a/api/server/middleware/setBalanceConfig.js b/api/server/middleware/setBalanceConfig.js index 98d3cf1145..5dd9757965 100644 --- a/api/server/middleware/setBalanceConfig.js +++ b/api/server/middleware/setBalanceConfig.js @@ -1,6 +1,6 @@ +const { logger } = require('@librechat/data-schemas'); const { getBalanceConfig } = require('~/server/services/Config'); -const Balance = require('~/models/Balance'); -const { logger } = require('~/config'); +const { Balance } = require('~/db/models'); /** * Middleware to synchronize user balance settings with current balance configuration. diff --git a/api/server/middleware/validate/convoAccess.js b/api/server/middleware/validate/convoAccess.js index 43cca0097d..afd2aeacef 100644 --- a/api/server/middleware/validate/convoAccess.js +++ b/api/server/middleware/validate/convoAccess.js @@ -1,8 +1,8 @@ +const { isEnabled } = require('@librechat/api'); const { Constants, ViolationTypes, Time } = require('librechat-data-provider'); const { searchConversation } = require('~/models/Conversation'); const denyRequest = require('~/server/middleware/denyRequest'); const { logViolation, getLogStores } = require('~/cache'); -const { isEnabled } = require('~/server/utils'); const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? 
{}; diff --git a/api/server/routes/__tests__/config.spec.js b/api/server/routes/__tests__/config.spec.js index 3280bc3864..054e4726f0 100644 --- a/api/server/routes/__tests__/config.spec.js +++ b/api/server/routes/__tests__/config.spec.js @@ -24,6 +24,12 @@ afterEach(() => { delete process.env.GITHUB_CLIENT_SECRET; delete process.env.DISCORD_CLIENT_ID; delete process.env.DISCORD_CLIENT_SECRET; + delete process.env.SAML_ENTRY_POINT; + delete process.env.SAML_ISSUER; + delete process.env.SAML_CERT; + delete process.env.SAML_SESSION_SECRET; + delete process.env.SAML_BUTTON_LABEL; + delete process.env.SAML_IMAGE_URL; delete process.env.DOMAIN_SERVER; delete process.env.ALLOW_REGISTRATION; delete process.env.ALLOW_SOCIAL_LOGIN; @@ -55,6 +61,12 @@ describe.skip('GET /', () => { process.env.GITHUB_CLIENT_SECRET = 'Test Github client Secret'; process.env.DISCORD_CLIENT_ID = 'Test Discord client Id'; process.env.DISCORD_CLIENT_SECRET = 'Test Discord client Secret'; + process.env.SAML_ENTRY_POINT = 'http://test-server.com'; + process.env.SAML_ISSUER = 'Test SAML Issuer'; + process.env.SAML_CERT = 'saml.pem'; + process.env.SAML_SESSION_SECRET = 'Test Secret'; + process.env.SAML_BUTTON_LABEL = 'Test SAML'; + process.env.SAML_IMAGE_URL = 'http://test-server.com'; process.env.DOMAIN_SERVER = 'http://test-server.com'; process.env.ALLOW_REGISTRATION = 'true'; process.env.ALLOW_SOCIAL_LOGIN = 'true'; @@ -70,7 +82,7 @@ describe.skip('GET /', () => { expect(response.statusCode).toBe(200); expect(response.body).toEqual({ appTitle: 'Test Title', - socialLogins: ['google', 'facebook', 'openid', 'github', 'discord'], + socialLogins: ['google', 'facebook', 'openid', 'github', 'discord', 'saml'], discordLoginEnabled: true, facebookLoginEnabled: true, githubLoginEnabled: true, @@ -78,6 +90,9 @@ describe.skip('GET /', () => { openidLoginEnabled: true, openidLabel: 'Test OpenID', openidImageUrl: 'http://test-server.com', + samlLoginEnabled: true, + samlLabel: 'Test SAML', + samlImageUrl: 
'http://test-server.com', ldap: { enabled: true, }, diff --git a/api/server/routes/actions.js b/api/server/routes/actions.js index dc474d1a67..242e52e4ae 100644 --- a/api/server/routes/actions.js +++ b/api/server/routes/actions.js @@ -53,6 +53,7 @@ router.get('/:action_id/oauth/callback', async (req, res) => { identifier, client_url: flowState.metadata.client_url, redirect_uri: flowState.metadata.redirect_uri, + token_exchange_method: flowState.metadata.token_exchange_method, /** Encrypted values */ encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id, encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret, diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index d5e771970b..89d6a9dc42 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -107,7 +107,15 @@ router.post('/:agent_id', async (req, res) => { .filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id)))) .concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`)); - const updatedAgent = await updateAgent(agentQuery, { tools, actions }, req.user.id); + // Force version update since actions are changing + const updatedAgent = await updateAgent( + agentQuery, + { tools, actions }, + { + updatingUserId: req.user.id, + forceVersion: true, + }, + ); // Only update user field for new actions const actionUpdateData = { metadata, agent_id }; @@ -172,7 +180,12 @@ router.delete('/:agent_id/:action_id', async (req, res) => { const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain))); - await updateAgent(agentQuery, { tools: updatedTools, actions: updatedActions }, req.user.id); + // Force version update since actions are being removed + await updateAgent( + agentQuery, + { tools: updatedTools, actions: updatedActions }, + { updatingUserId: req.user.id, forceVersion: true }, + ); // If admin, can delete any action, otherwise only user's 
actions const actionQuery = admin ? { action_id } : { action_id, user: req.user.id }; await deleteAction(actionQuery); diff --git a/api/server/routes/config.js b/api/server/routes/config.js index e34497688d..a53a636d05 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -37,6 +37,18 @@ router.get('/', async function (req, res) { const ldap = getLdapConfig(); try { + const isOpenIdEnabled = + !!process.env.OPENID_CLIENT_ID && + !!process.env.OPENID_CLIENT_SECRET && + !!process.env.OPENID_ISSUER && + !!process.env.OPENID_SESSION_SECRET; + + const isSamlEnabled = + !!process.env.SAML_ENTRY_POINT && + !!process.env.SAML_ISSUER && + !!process.env.SAML_CERT && + !!process.env.SAML_SESSION_SECRET; + /** @type {TStartupConfig} */ const payload = { appTitle: process.env.APP_TITLE || 'LibreChat', @@ -51,14 +63,13 @@ router.get('/', async function (req, res) { !!process.env.APPLE_TEAM_ID && !!process.env.APPLE_KEY_ID && !!process.env.APPLE_PRIVATE_KEY_PATH, - openidLoginEnabled: - !!process.env.OPENID_CLIENT_ID && - !!process.env.OPENID_CLIENT_SECRET && - !!process.env.OPENID_ISSUER && - !!process.env.OPENID_SESSION_SECRET, + openidLoginEnabled: isOpenIdEnabled, openidLabel: process.env.OPENID_BUTTON_LABEL || 'Continue with OpenID', openidImageUrl: process.env.OPENID_IMAGE_URL, openidAutoRedirect: isEnabled(process.env.OPENID_AUTO_REDIRECT), + samlLoginEnabled: !isOpenIdEnabled && isSamlEnabled, + samlLabel: process.env.SAML_BUTTON_LABEL, + samlImageUrl: process.env.SAML_IMAGE_URL, serverDomain: process.env.DOMAIN_SERVER || 'http://localhost:3080', emailLoginEnabled, registrationEnabled: !ldap?.enabled && isEnabled(process.env.ALLOW_REGISTRATION), diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 87bac6ed29..eb7e2c5c27 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -65,8 +65,14 @@ router.post('/gen_title', async (req, res) => { let title = await titleCache.get(key); if (!title) { - await 
sleep(2500); - title = await titleCache.get(key); + // Retry every 1s for up to 20s + for (let i = 0; i < 20; i++) { + await sleep(1000); + title = await titleCache.get(key); + if (title) { + break; + } + } } if (title) { diff --git a/api/server/routes/files/multer.js b/api/server/routes/files/multer.js index f23ecd2823..257c309fa2 100644 --- a/api/server/routes/files/multer.js +++ b/api/server/routes/files/multer.js @@ -2,8 +2,8 @@ const fs = require('fs'); const path = require('path'); const crypto = require('crypto'); const multer = require('multer'); +const { sanitizeFilename } = require('@librechat/api'); const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider'); -const { sanitizeFilename } = require('~/server/utils/handleText'); const { getCustomConfig } = require('~/server/services/Config'); const storage = multer.diskStorage({ diff --git a/api/server/routes/files/multer.spec.js b/api/server/routes/files/multer.spec.js new file mode 100644 index 0000000000..0324262a71 --- /dev/null +++ b/api/server/routes/files/multer.spec.js @@ -0,0 +1,571 @@ +/* eslint-disable no-unused-vars */ +/* eslint-disable jest/no-done-callback */ +const fs = require('fs'); +const os = require('os'); +const path = require('path'); +const crypto = require('crypto'); +const { createMulterInstance, storage, importFileFilter } = require('./multer'); + +// Mock only the config service that requires external dependencies +jest.mock('~/server/services/Config', () => ({ + getCustomConfig: jest.fn(() => + Promise.resolve({ + fileConfig: { + endpoints: { + openAI: { + supportedMimeTypes: ['image/jpeg', 'image/png', 'application/pdf'], + }, + default: { + supportedMimeTypes: ['image/jpeg', 'image/png', 'text/plain'], + }, + }, + serverFileSizeLimit: 10000000, // 10MB + }, + }), + ), +})); + +describe('Multer Configuration', () => { + let tempDir; + let mockReq; + let mockFile; + + beforeEach(() => { + // Create a temporary directory for each test + tempDir = 
fs.mkdtempSync(path.join(os.tmpdir(), 'multer-test-')); + + mockReq = { + user: { id: 'test-user-123' }, + app: { + locals: { + paths: { + uploads: tempDir, + }, + }, + }, + body: {}, + originalUrl: '/api/files/upload', + }; + + mockFile = { + originalname: 'test-file.jpg', + mimetype: 'image/jpeg', + size: 1024, + }; + + // Clear mocks + jest.clearAllMocks(); + }); + + afterEach(() => { + // Clean up temporary directory + if (fs.existsSync(tempDir)) { + fs.rmSync(tempDir, { recursive: true, force: true }); + } + }); + + describe('Storage Configuration', () => { + describe('destination function', () => { + it('should create the correct destination path', (done) => { + const cb = jest.fn((err, destination) => { + expect(err).toBeNull(); + expect(destination).toBe(path.join(tempDir, 'temp', 'test-user-123')); + expect(fs.existsSync(destination)).toBe(true); + done(); + }); + + storage.getDestination(mockReq, mockFile, cb); + }); + + it("should create directory recursively if it doesn't exist", (done) => { + const deepPath = path.join(tempDir, 'deep', 'nested', 'path'); + mockReq.app.locals.paths.uploads = deepPath; + + const cb = jest.fn((err, destination) => { + expect(err).toBeNull(); + expect(destination).toBe(path.join(deepPath, 'temp', 'test-user-123')); + expect(fs.existsSync(destination)).toBe(true); + done(); + }); + + storage.getDestination(mockReq, mockFile, cb); + }); + }); + + describe('filename function', () => { + it('should generate a UUID for req.file_id', (done) => { + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(mockReq.file_id).toBeDefined(); + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i, + ); + done(); + }); + + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should decode URI components in filename', (done) => { + const encodedFile = { + ...mockFile, + originalname: encodeURIComponent('test file with spaces.jpg'), + }; + + const cb = jest.fn((err, 
filename) => { + expect(err).toBeNull(); + expect(encodedFile.originalname).toBe('test file with spaces.jpg'); + done(); + }); + + storage.getFilename(mockReq, encodedFile, cb); + }); + + it('should call real sanitizeFilename with properly encoded filename', (done) => { + // Test with a properly URI-encoded filename that needs sanitization + const unsafeFile = { + ...mockFile, + originalname: encodeURIComponent('test@#$%^&*()file with spaces!.jpg'), + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + // The actual sanitizeFilename should have cleaned this up after decoding + expect(filename).not.toContain('@'); + expect(filename).not.toContain('#'); + expect(filename).not.toContain('*'); + expect(filename).not.toContain('!'); + // Should still preserve dots and hyphens + expect(filename).toContain('.jpg'); + done(); + }); + + storage.getFilename(mockReq, unsafeFile, cb); + }); + + it('should handle very long filenames with actual crypto', (done) => { + const longFile = { + ...mockFile, + originalname: 'a'.repeat(300) + '.jpg', + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(filename.length).toBeLessThanOrEqual(255); + expect(filename).toMatch(/\.jpg$/); // Should still end with .jpg + // Should contain a hex suffix if truncated + if (filename.length === 255) { + expect(filename).toMatch(/-[a-f0-9]{6}\.jpg$/); + } + done(); + }); + + storage.getFilename(mockReq, longFile, cb); + }); + + it('should generate unique file_id for each call', (done) => { + let firstFileId; + + const firstCb = jest.fn((err, filename) => { + expect(err).toBeNull(); + firstFileId = mockReq.file_id; + + // Reset req for second call + delete mockReq.file_id; + + const secondCb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(mockReq.file_id).toBeDefined(); + expect(mockReq.file_id).not.toBe(firstFileId); + done(); + }); + + storage.getFilename(mockReq, mockFile, secondCb); + }); + + storage.getFilename(mockReq, 
mockFile, firstCb); + }); + }); + }); + + describe('Import File Filter', () => { + it('should accept JSON files by mimetype', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'application/json', + originalname: 'data.json', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + + it('should accept files with .json extension', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'data.json', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + + it('should reject non-JSON files', (done) => { + const textFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'document.txt', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeInstanceOf(Error); + expect(err.message).toBe('Only JSON files are allowed'); + expect(result).toBe(false); + done(); + }); + + importFileFilter(mockReq, textFile, cb); + }); + + it('should handle files with uppercase .JSON extension', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'DATA.JSON', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + }); + + describe('File Filter with Real defaultFileConfig', () => { + it('should use real fileConfig.checkType for validation', async () => { + // Test with actual librechat-data-provider functions + const { + fileConfig, + imageMimeTypes, + applicationMimeTypes, + } = require('librechat-data-provider'); + + // Test that the real checkType function works with regex patterns + expect(fileConfig.checkType('image/jpeg', [imageMimeTypes])).toBe(true); + expect(fileConfig.checkType('video/mp4', [imageMimeTypes])).toBe(false); + 
expect(fileConfig.checkType('application/pdf', [applicationMimeTypes])).toBe(true); + expect(fileConfig.checkType('application/pdf', [])).toBe(false); + }); + + it('should handle audio files for speech-to-text endpoint with real config', async () => { + mockReq.originalUrl = '/api/speech/stt'; + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + }); + + it('should reject unsupported file types using real config', async () => { + // Mock defaultFileConfig for this specific test + const originalCheckType = require('librechat-data-provider').fileConfig.checkType; + const mockCheckType = jest.fn().mockReturnValue(false); + require('librechat-data-provider').fileConfig.checkType = mockCheckType; + + try { + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + + // Test the actual file filter behavior would reject unsupported files + expect(mockCheckType).toBeDefined(); + } finally { + // Restore original function + require('librechat-data-provider').fileConfig.checkType = originalCheckType; + } + }); + + it('should use real mergeFileConfig function', async () => { + const { mergeFileConfig, mbToBytes } = require('librechat-data-provider'); + + // Test with actual merge function - note that it converts MB to bytes + const testConfig = { + serverFileSizeLimit: 5, // 5 MB + endpoints: { + custom: { + supportedMimeTypes: ['text/plain'], + }, + }, + }; + + const result = mergeFileConfig(testConfig); + + // The function converts MB to bytes, so 5 MB becomes 5 * 1024 * 1024 bytes + expect(result.serverFileSizeLimit).toBe(mbToBytes(5)); + expect(result.endpoints.custom.supportedMimeTypes).toBeDefined(); + // Should still have the default endpoints + expect(result.endpoints.default).toBeDefined(); + }); + }); + + describe('createMulterInstance with Real Functions', () => { + it('should create a multer instance with correct 
configuration', async () => { + const multerInstance = await createMulterInstance(); + + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + expect(typeof multerInstance.array).toBe('function'); + expect(typeof multerInstance.fields).toBe('function'); + }); + + it('should use real config merging', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + const multerInstance = await createMulterInstance(); + + expect(getCustomConfig).toHaveBeenCalled(); + expect(multerInstance).toBeDefined(); + }); + + it('should create multer instance with expected interface', async () => { + const multerInstance = await createMulterInstance(); + + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + expect(typeof multerInstance.array).toBe('function'); + expect(typeof multerInstance.fields).toBe('function'); + }); + }); + + describe('Real Crypto Integration', () => { + it('should use actual crypto.randomUUID()', (done) => { + // Spy on crypto.randomUUID to ensure it's called + const uuidSpy = jest.spyOn(crypto, 'randomUUID'); + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(uuidSpy).toHaveBeenCalled(); + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i, + ); + + uuidSpy.mockRestore(); + done(); + }); + + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should generate different UUIDs on subsequent calls', (done) => { + const uuids = []; + let callCount = 0; + const totalCalls = 5; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + uuids.push(mockReq.file_id); + callCount++; + + if (callCount === totalCalls) { + // Check that all UUIDs are unique + const uniqueUuids = new Set(uuids); + expect(uniqueUuids.size).toBe(totalCalls); + done(); + } else { + // Reset for next call + delete mockReq.file_id; + storage.getFilename(mockReq, mockFile, cb); + } + }); + + 
// Start the chain + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should generate cryptographically secure UUIDs', (done) => { + const generatedUuids = new Set(); + let callCount = 0; + const totalCalls = 10; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + + // Verify UUID format and uniqueness + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i, + ); + + generatedUuids.add(mockReq.file_id); + callCount++; + + if (callCount === totalCalls) { + // All UUIDs should be unique + expect(generatedUuids.size).toBe(totalCalls); + done(); + } else { + // Reset for next call + delete mockReq.file_id; + storage.getFilename(mockReq, mockFile, cb); + } + }); + + // Start the chain + storage.getFilename(mockReq, mockFile, cb); + }); + }); + + describe('Error Handling', () => { + it('should handle CVE-2024-28870: empty field name DoS vulnerability', async () => { + // Test for the CVE where empty field name could cause unhandled exception + const multerInstance = await createMulterInstance(); + + // Create a mock request with empty field name (the vulnerability scenario) + const mockReqWithEmptyField = { + ...mockReq, + headers: { + 'content-type': 'multipart/form-data', + }, + }; + + const mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + end: jest.fn(), + }; + + // This should not crash or throw unhandled exceptions + const uploadMiddleware = multerInstance.single(''); // Empty field name + + const mockNext = jest.fn((err) => { + // If there's an error, it should be handled gracefully, not crash + if (err) { + expect(err).toBeInstanceOf(Error); + // The error should be handled, not crash the process + } + }); + + // This should complete without crashing the process + expect(() => { + uploadMiddleware(mockReqWithEmptyField, mockRes, mockNext); + }).not.toThrow(); + }); + + it('should handle file system errors when directory creation fails', (done) => { + // Test 
with a non-existent parent directory to simulate fs issues + const invalidPath = '/nonexistent/path/that/should/not/exist'; + mockReq.app.locals.paths.uploads = invalidPath; + + try { + // Call getDestination which should fail due to permission/path issues + storage.getDestination(mockReq, mockFile, (err, destination) => { + // If callback is reached, we didn't get the expected error + done(new Error('Expected mkdirSync to throw an error but callback was called')); + }); + // If we get here without throwing, something unexpected happened + done(new Error('Expected mkdirSync to throw an error but no error was thrown')); + } catch (error) { + // This is the expected behavior - mkdirSync throws synchronously for invalid paths + expect(error.code).toBe('EACCES'); + done(); + } + }); + + it('should handle malformed filenames with real sanitization', (done) => { + const malformedFile = { + ...mockFile, + originalname: null, // This should be handled gracefully + }; + + const cb = jest.fn((err, filename) => { + // The function should handle this gracefully + expect(typeof err === 'object' || err === null).toBe(true); + done(); + }); + + try { + storage.getFilename(mockReq, malformedFile, cb); + } catch (error) { + // If it throws, that's also acceptable behavior + done(); + } + }); + + it('should handle edge cases in filename sanitization', (done) => { + const edgeCaseFiles = [ + { originalname: '', expected: /_/ }, + { originalname: '.hidden', expected: /^_\.hidden/ }, + { originalname: '../../../etc/passwd', expected: /passwd/ }, + { originalname: 'file\x00name.txt', expected: /file_name\.txt/ }, + ]; + + let testCount = 0; + + const testNextFile = (fileData) => { + const fileToTest = { ...mockFile, originalname: fileData.originalname }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(filename).toMatch(fileData.expected); + + testCount++; + if (testCount === edgeCaseFiles.length) { + done(); + } else { + 
testNextFile(edgeCaseFiles[testCount]); + } + }); + + storage.getFilename(mockReq, fileToTest, cb); + }; + + testNextFile(edgeCaseFiles[0]); + }); + }); + + describe('Real Configuration Testing', () => { + it('should handle missing custom config gracefully with real mergeFileConfig', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + // Mock getCustomConfig to return undefined + getCustomConfig.mockResolvedValueOnce(undefined); + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + }); + + it('should properly integrate real fileConfig with custom endpoints', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + // Mock a custom config with additional endpoints + getCustomConfig.mockResolvedValueOnce({ + fileConfig: { + endpoints: { + anthropic: { + supportedMimeTypes: ['text/plain', 'image/png'], + }, + }, + serverFileSizeLimit: 20, // 20 MB + }, + }); + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + + // Verify that getCustomConfig was called (we can't spy on the actual merge function easily) + expect(getCustomConfig).toHaveBeenCalled(); + }); + }); +}); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 449759383d..06e39d3671 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -4,6 +4,7 @@ const tokenizer = require('./tokenizer'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); +const memories = require('./memories'); const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); @@ -51,6 +52,7 @@ module.exports = { presets, balance, messages, + memories, endpoints, tokenizer, assistants, diff --git a/api/server/routes/memories.js b/api/server/routes/memories.js new file mode 
100644 index 0000000000..86065fecaa --- /dev/null +++ b/api/server/routes/memories.js @@ -0,0 +1,231 @@ +const express = require('express'); +const { Tokenizer } = require('@librechat/api'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + getAllUserMemories, + toggleUserMemories, + createMemory, + setMemory, + deleteMemory, +} = require('~/models'); +const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); + +const router = express.Router(); + +const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.READ, +]); +const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.CREATE, +]); +const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.UPDATE, +]); +const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.UPDATE, +]); +const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.OPT_OUT, +]); + +router.use(requireJwtAuth); + +/** + * GET /memories + * Returns all memories for the authenticated user, sorted by updated_at (newest first). + * Also includes memory usage percentage based on token limit. 
+ */ +router.get('/', checkMemoryRead, async (req, res) => { + try { + const memories = await getAllUserMemories(req.user.id); + + const sortedMemories = memories.sort( + (a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime(), + ); + + const totalTokens = memories.reduce((sum, memory) => { + return sum + (memory.tokenCount || 0); + }, 0); + + const memoryConfig = req.app.locals?.memory; + const tokenLimit = memoryConfig?.tokenLimit; + + let usagePercentage = null; + if (tokenLimit && tokenLimit > 0) { + usagePercentage = Math.min(100, Math.round((totalTokens / tokenLimit) * 100)); + } + + res.json({ + memories: sortedMemories, + totalTokens, + tokenLimit: tokenLimit || null, + usagePercentage, + }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * POST /memories + * Creates a new memory entry for the authenticated user. + * Body: { key: string, value: string } + * Returns 201 and { created: true, memory: } when successful. + */ +router.post('/', checkMemoryCreate, async (req, res) => { + const { key, value } = req.body; + + if (typeof key !== 'string' || key.trim() === '') { + return res.status(400).json({ error: 'Key is required and must be a non-empty string.' }); + } + + if (typeof value !== 'string' || value.trim() === '') { + return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); + } + + try { + const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); + + const memories = await getAllUserMemories(req.user.id); + + // Check token limit + const memoryConfig = req.app.locals?.memory; + const tokenLimit = memoryConfig?.tokenLimit; + + if (tokenLimit) { + const currentTotalTokens = memories.reduce( + (sum, memory) => sum + (memory.tokenCount || 0), + 0, + ); + if (currentTotalTokens + tokenCount > tokenLimit) { + return res.status(400).json({ + error: `Adding this memory would exceed the token limit of ${tokenLimit}. 
Current usage: ${currentTotalTokens} tokens.`, + }); + } + } + + const result = await createMemory({ + userId: req.user.id, + key: key.trim(), + value: value.trim(), + tokenCount, + }); + + if (!result.ok) { + return res.status(500).json({ error: 'Failed to create memory.' }); + } + + const updatedMemories = await getAllUserMemories(req.user.id); + const newMemory = updatedMemories.find((m) => m.key === key.trim()); + + res.status(201).json({ created: true, memory: newMemory }); + } catch (error) { + if (error.message && error.message.includes('already exists')) { + return res.status(409).json({ error: 'Memory with this key already exists.' }); + } + res.status(500).json({ error: error.message }); + } +}); + +/** + * PATCH /memories/preferences + * Updates the user's memory preferences (e.g., enabling/disabling memories). + * Body: { memories: boolean } + * Returns 200 and { updated: true, preferences: { memories: boolean } } when successful. + */ +router.patch('/preferences', checkMemoryOptOut, async (req, res) => { + const { memories } = req.body; + + if (typeof memories !== 'boolean') { + return res.status(400).json({ error: 'memories must be a boolean value.' }); + } + + try { + const updatedUser = await toggleUserMemories(req.user.id, memories); + + if (!updatedUser) { + return res.status(404).json({ error: 'User not found.' }); + } + + res.json({ + updated: true, + preferences: { + memories: updatedUser.personalization?.memories ?? true, + }, + }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * PATCH /memories/:key + * Updates the value of an existing memory entry for the authenticated user. + * Body: { value: string } + * Returns 200 and { updated: true, memory: } when successful. 
+ */ +router.patch('/:key', checkMemoryUpdate, async (req, res) => { + const { key } = req.params; + const { value } = req.body || {}; + + if (typeof value !== 'string' || value.trim() === '') { + return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); + } + + try { + const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); + + const memories = await getAllUserMemories(req.user.id); + const existingMemory = memories.find((m) => m.key === key); + + if (!existingMemory) { + return res.status(404).json({ error: 'Memory not found.' }); + } + + const result = await setMemory({ + userId: req.user.id, + key, + value, + tokenCount, + }); + + if (!result.ok) { + return res.status(500).json({ error: 'Failed to update memory.' }); + } + + const updatedMemories = await getAllUserMemories(req.user.id); + const updatedMemory = updatedMemories.find((m) => m.key === key); + + res.json({ updated: true, memory: updatedMemory }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * DELETE /memories/:key + * Deletes a memory entry for the authenticated user. + * Returns 200 and { deleted: true } when successful. + */ +router.delete('/:key', checkMemoryDelete, async (req, res) => { + const { key } = req.params; + + try { + const result = await deleteMemory({ userId: req.user.id, key }); + + if (!result.ok) { + return res.status(404).json({ error: 'Memory not found.' 
}); + } + + res.json({ deleted: true }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +module.exports = router; diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index d5980ae55b..356dd25097 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,4 +1,5 @@ const express = require('express'); +const { logger } = require('@librechat/data-schemas'); const { ContentTypes } = require('librechat-data-provider'); const { saveConvo, @@ -13,8 +14,7 @@ const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); const { getConvosQueried } = require('~/models/Conversation'); const { countTokens } = require('~/server/utils'); -const { Message } = require('~/models/Message'); -const { logger } = require('~/config'); +const { Message } = require('~/db/models'); const router = express.Router(); router.use(requireJwtAuth); @@ -40,7 +40,11 @@ router.get('/', async (req, res) => { const sortOrder = sortDirection === 'asc' ? 1 : -1; if (conversationId && messageId) { - const message = await Message.findOne({ conversationId, messageId, user: user }).lean(); + const message = await Message.findOne({ + conversationId, + messageId, + user: user, + }).lean(); response = { messages: message ? 
[message] : [], nextCursor: null }; } else if (conversationId) { const filter = { conversationId, user: user }; @@ -253,6 +257,31 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = } }); +router.put('/:conversationId/:messageId/feedback', validateMessageReq, async (req, res) => { + try { + const { conversationId, messageId } = req.params; + const { feedback } = req.body; + + const updatedMessage = await updateMessage( + req, + { + messageId, + feedback: feedback || null, + }, + { context: 'updateFeedback' }, + ); + + res.json({ + messageId, + conversationId, + feedback: updatedMessage.feedback, + }); + } catch (error) { + logger.error('Error updating message feedback:', error); + res.status(500).json({ error: 'Failed to update feedback' }); + } +}); + router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => { try { const { messageId } = req.params; diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 9915390a5d..bc8d120ef5 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -189,4 +189,24 @@ router.post( oauthHandler, ); +/** + * SAML Routes + */ +router.get( + '/saml', + passport.authenticate('saml', { + session: false, + }), +); + +router.post( + '/saml/callback', + passport.authenticate('saml', { + failureRedirect: `${domains.client}/oauth/error`, + failureMessage: true, + session: false, + }), + oauthHandler, +); + module.exports = router; diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js index 17768c7de6..aefbfcec0c 100644 --- a/api/server/routes/roles.js +++ b/api/server/routes/roles.js @@ -1,6 +1,7 @@ const express = require('express'); const { promptPermissionsSchema, + memoryPermissionsSchema, agentPermissionsSchema, PermissionTypes, roleDefaults, @@ -118,4 +119,43 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => { } }); +/** + * PUT /api/roles/:roleName/memories + * Update memory permissions for a specific role + 
*/ +router.put('/:roleName/memories', checkAdmin, async (req, res) => { + const { roleName: _r } = req.params; + // TODO: TEMP, use a better parsing for roleName + const roleName = _r.toUpperCase(); + /** @type {TRole['permissions']['MEMORIES']} */ + const updates = req.body; + + try { + const parsedUpdates = memoryPermissionsSchema.partial().parse(updates); + + const role = await getRoleByName(roleName); + if (!role) { + return res.status(404).send({ message: 'Role not found' }); + } + + const currentPermissions = + role.permissions?.[PermissionTypes.MEMORIES] || role[PermissionTypes.MEMORIES] || {}; + + const mergedUpdates = { + permissions: { + ...role.permissions, + [PermissionTypes.MEMORIES]: { + ...currentPermissions, + ...parsedUpdates, + }, + }, + }; + + const updatedRole = await updateRoleByName(roleName, mergedUpdates); + res.status(200).send(updatedRole); + } catch (error) { + return res.status(400).send({ message: 'Invalid memory permissions.', error: error.errors }); + } +}); + module.exports = router; diff --git a/api/server/routes/share.js b/api/server/routes/share.js index e551f4a354..14c25271fc 100644 --- a/api/server/routes/share.js +++ b/api/server/routes/share.js @@ -1,15 +1,15 @@ const express = require('express'); - +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { - getSharedLink, getSharedMessages, createSharedLink, updateSharedLink, - getSharedLinks, deleteSharedLink, -} = require('~/models/Share'); + getSharedLinks, + getSharedLink, +} = require('~/models'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); -const { isEnabled } = require('~/server/utils'); const router = express.Router(); /** @@ -35,6 +35,7 @@ if (allowSharedLinks) { res.status(404).end(); } } catch (error) { + logger.error('Error getting shared messages:', error); res.status(500).json({ message: 'Error getting shared messages' }); } }, @@ -54,9 +55,7 @@ router.get('/', requireJwtAuth, 
async (req, res) => { sortDirection: ['asc', 'desc'].includes(req.query.sortDirection) ? req.query.sortDirection : 'desc', - search: req.query.search - ? decodeURIComponent(req.query.search.trim()) - : undefined, + search: req.query.search ? decodeURIComponent(req.query.search.trim()) : undefined, }; const result = await getSharedLinks( @@ -75,7 +74,7 @@ router.get('/', requireJwtAuth, async (req, res) => { hasNextPage: result.hasNextPage, }); } catch (error) { - console.error('Error getting shared links:', error); + logger.error('Error getting shared links:', error); res.status(500).json({ message: 'Error getting shared links', error: error.message, @@ -93,6 +92,7 @@ router.get('/link/:conversationId', requireJwtAuth, async (req, res) => { conversationId: req.params.conversationId, }); } catch (error) { + logger.error('Error getting shared link:', error); res.status(500).json({ message: 'Error getting shared link' }); } }); @@ -106,6 +106,7 @@ router.post('/:conversationId', requireJwtAuth, async (req, res) => { res.status(404).end(); } } catch (error) { + logger.error('Error creating shared link:', error); res.status(500).json({ message: 'Error creating shared link' }); } }); @@ -119,6 +120,7 @@ router.patch('/:shareId', requireJwtAuth, async (req, res) => { res.status(404).end(); } } catch (error) { + logger.error('Error updating shared link:', error); res.status(500).json({ message: 'Error updating shared link' }); } }); @@ -133,7 +135,8 @@ router.delete('/:shareId', requireJwtAuth, async (req, res) => { return res.status(200).json(result); } catch (error) { - return res.status(400).json({ message: error.message }); + logger.error('Error deleting shared link:', error); + return res.status(400).json({ message: 'Error deleting shared link' }); } }); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index c8a7955427..9bf7491543 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ 
-1,7 +1,9 @@ const jwt = require('jsonwebtoken'); const { nanoid } = require('nanoid'); const { tool } = require('@langchain/core/tools'); +const { logger } = require('@librechat/data-schemas'); const { GraphEvents, sleep } = require('@librechat/agents'); +const { sendEvent, logAxiosError } = require('@librechat/api'); const { Time, CacheKeys, @@ -13,13 +15,12 @@ const { actionDomainSeparator, } = require('librechat-data-provider'); const { refreshAccessToken } = require('~/server/services/TokenService'); -const { logger, getFlowStateManager, sendEvent } = require('~/config'); const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); const { getActions, deleteActions } = require('~/models/Action'); const { deleteAssistant } = require('~/models/Assistant'); -const { findToken } = require('~/models/Token'); -const { logAxiosError } = require('~/utils'); +const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); +const { findToken } = require('~/models'); const JWT_SECRET = process.env.JWT_SECRET; const toolNameRegex = /^[a-zA-Z0-9_-]+$/; @@ -207,7 +208,8 @@ async function createActionTool({ state: stateToken, userId: userId, client_url: metadata.auth.client_url, - redirect_uri: `${process.env.DOMAIN_CLIENT}/api/actions/${action_id}/oauth/callback`, + redirect_uri: `${process.env.DOMAIN_SERVER}/api/actions/${action_id}/oauth/callback`, + token_exchange_method: metadata.auth.token_exchange_method, /** Encrypted values */ encrypted_oauth_client_id: encrypted.oauth_client_id, encrypted_oauth_client_secret: encrypted.oauth_client_secret, @@ -262,6 +264,7 @@ async function createActionTool({ refresh_token, client_url: metadata.auth.client_url, encrypted_oauth_client_id: encrypted.oauth_client_id, + token_exchange_method: metadata.auth.token_exchange_method, encrypted_oauth_client_secret: encrypted.oauth_client_secret, }); const flowsCache = getLogStores(CacheKeys.FLOWS); diff --git 
a/api/server/services/AppService.interface.spec.js b/api/server/services/AppService.interface.spec.js index 0bf9d07dcc..90168d4778 100644 --- a/api/server/services/AppService.interface.spec.js +++ b/api/server/services/AppService.interface.spec.js @@ -1,5 +1,7 @@ -jest.mock('~/models/Role', () => ({ +jest.mock('~/models', () => ({ initializeRoles: jest.fn(), +})); +jest.mock('~/models/Role', () => ({ updateAccessPermissions: jest.fn(), getRoleByName: jest.fn(), updateRoleByName: jest.fn(), diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index 6a1cdfc695..2e5a0e586b 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -3,9 +3,11 @@ const { loadOCRConfig, processMCPEnv, EModelEndpoint, + loadMemoryConfig, getConfigDefaults, loadWebSearchConfig, } = require('librechat-data-provider'); +const { agentsConfigSetup } = require('@librechat/api'); const { checkHealth, checkConfig, @@ -24,9 +26,8 @@ const { azureConfigSetup } = require('./start/azureOpenAI'); const { processModelSpecs } = require('./start/modelSpecs'); const { initializeS3 } = require('./Files/S3/initialize'); const { loadAndFormatTools } = require('./ToolService'); -const { agentsConfigSetup } = require('./start/agents'); -const { initializeRoles } = require('~/models/Role'); const { isEnabled } = require('~/server/utils'); +const { initializeRoles } = require('~/models'); const { getMCPManager } = require('~/config'); const paths = require('~/config/paths'); @@ -44,6 +45,7 @@ const AppService = async (app) => { const ocr = loadOCRConfig(config.ocr); const webSearch = loadWebSearchConfig(config.webSearch); checkWebSearchConfig(webSearch); + const memory = loadMemoryConfig(config.memory); const filteredTools = config.filteredTools; const includedTools = config.includedTools; const fileStrategy = config.fileStrategy ?? 
configDefaults.fileStrategy; @@ -88,6 +90,7 @@ const AppService = async (app) => { const defaultLocals = { ocr, paths, + memory, webSearch, fileStrategy, socialLogins, @@ -100,8 +103,13 @@ const AppService = async (app) => { balance, }; + const agentsDefaults = agentsConfigSetup(config); + if (!Object.keys(config).length) { - app.locals = defaultLocals; + app.locals = { + ...defaultLocals, + [EModelEndpoint.agents]: agentsDefaults, + }; return; } @@ -136,9 +144,7 @@ const AppService = async (app) => { ); } - if (endpoints?.[EModelEndpoint.agents]) { - endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config); - } + endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config, agentsDefaults); const endpointKeys = [ EModelEndpoint.openAI, diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 0c7fac6ed3..70a405ccdb 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -2,8 +2,10 @@ const { FileSources, EModelEndpoint, EImageOutputType, + AgentCapabilities, defaultSocialLogins, validateAzureGroups, + defaultAgentCapabilities, deprecatedAzureVariables, conflictingAzureVariables, } = require('librechat-data-provider'); @@ -24,8 +26,10 @@ jest.mock('./Config/loadCustomConfig', () => { jest.mock('./Files/Firebase/initialize', () => ({ initializeFirebase: jest.fn(), })); -jest.mock('~/models/Role', () => ({ +jest.mock('~/models', () => ({ initializeRoles: jest.fn(), +})); +jest.mock('~/models/Role', () => ({ updateAccessPermissions: jest.fn(), })); jest.mock('./ToolService', () => ({ @@ -149,6 +153,11 @@ describe('AppService', () => { safeSearch: 1, serperApiKey: '${SERPER_API_KEY}', }, + memory: undefined, + agents: { + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }, }); }); @@ -266,6 +275,71 @@ describe('AppService', () => { ); }); + it('should correctly configure Agents endpoint based on custom config', async () => { + 
require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.agents]: { + disableBuilder: true, + recursionLimit: 10, + maxRecursionLimit: 20, + allowedProviders: ['openai', 'anthropic'], + capabilities: [AgentCapabilities.tools, AgentCapabilities.actions], + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: true, + recursionLimit: 10, + maxRecursionLimit: 20, + allowedProviders: expect.arrayContaining(['openai', 'anthropic']), + capabilities: expect.arrayContaining([AgentCapabilities.tools, AgentCapabilities.actions]), + }), + ); + }); + + it('should configure Agents endpoint with defaults when no config is provided', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({})); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }), + ); + }); + + it('should configure Agents endpoint with defaults when endpoints exist but agents is not defined', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.openAI]: { + titleConvo: true, + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }), + ); + }); + it('should correctly configure minimum Azure OpenAI Assistant values', async () => { const assistantGroups = [azureGroups[0], { ...azureGroups[1], assistants: true }]; 
require('./Config/loadCustomConfig').mockImplementationOnce(() => diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index ac13172128..2c285512ee 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -3,24 +3,23 @@ const { webcrypto } = require('node:crypto'); const { SystemRoles, errorsToString } = require('librechat-data-provider'); const { findUser, - countUsers, createUser, updateUser, - getUserById, - generateToken, - deleteUserById, -} = require('~/models/userMethods'); -const { - createToken, findToken, - deleteTokens, + countUsers, + getUserById, findSession, + createToken, + deleteTokens, deleteSession, createSession, + generateToken, + deleteUserById, generateRefreshToken, } = require('~/models'); const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils'); const { isEmailDomainAllowed } = require('~/server/services/domains'); +const { getBalanceConfig } = require('~/server/services/Config'); const { registerSchema } = require('~/strategies/validators'); const { logger } = require('~/config'); @@ -146,6 +145,7 @@ const verifyEmail = async (req) => { } const updatedUser = await updateUser(emailVerificationData.userId, { emailVerified: true }); + if (!updatedUser) { logger.warn(`[verifyEmail] [User update failed] [Email: ${decodedEmail}]`); return new Error('Failed to update user verification status'); @@ -155,6 +155,7 @@ const verifyEmail = async (req) => { logger.info(`[verifyEmail] Email verification successful [Email: ${decodedEmail}]`); return { message: 'Email verification was successful', status: 'success' }; }; + /** * Register a new user. 
* @param {MongoUser} user @@ -216,7 +217,9 @@ const registerUser = async (user, additionalData = {}) => { const emailEnabled = checkEmailConfig(); const disableTTL = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN); - const newUser = await createUser(newUserData, disableTTL, true); + const balanceConfig = await getBalanceConfig(); + + const newUser = await createUser(newUserData, balanceConfig, disableTTL, true); newUserId = newUser._id; if (emailEnabled && !newUser.emailVerified) { await sendVerificationEmail({ @@ -389,6 +392,7 @@ const setAuthTokens = async (userId, res, sessionId = null) => { throw error; } }; + /** * @function setOpenIDAuthTokens * Set OpenID Authentication Tokens @@ -405,7 +409,9 @@ const setOpenIDAuthTokens = (tokenset, res) => { return; } const { REFRESH_TOKEN_EXPIRY } = process.env ?? {}; - const expiryInMilliseconds = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default + const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY + ? eval(REFRESH_TOKEN_EXPIRY) + : 1000 * 60 * 60 * 24 * 7; // 7 days default const expirationDate = new Date(Date.now() + expiryInMilliseconds); if (tokenset == null) { logger.error('[setOpenIDAuthTokens] No tokenset found in request'); diff --git a/api/server/services/Endpoints/agents/agent.js b/api/server/services/Endpoints/agents/agent.js new file mode 100644 index 0000000000..13a42140db --- /dev/null +++ b/api/server/services/Endpoints/agents/agent.js @@ -0,0 +1,196 @@ +const { Providers } = require('@librechat/agents'); +const { primeResources, optionalChainWithEmptyCheck } = require('@librechat/api'); +const { + ErrorTypes, + EModelEndpoint, + EToolResources, + replaceSpecialVars, + providerEndpointMap, +} = require('librechat-data-provider'); +const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); +const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); +const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); +const 
initCustom = require('~/server/services/Endpoints/custom/initialize'); +const initGoogle = require('~/server/services/Endpoints/google/initialize'); +const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); +const { getCustomEndpointConfig } = require('~/server/services/Config'); +const { processFiles } = require('~/server/services/Files/process'); +const { getConvoFiles } = require('~/models/Conversation'); +const { getToolFilesByIds } = require('~/models/File'); +const { getModelMaxTokens } = require('~/utils'); +const { getFiles } = require('~/models/File'); + +const providerConfigMap = { + [Providers.XAI]: initCustom, + [Providers.OLLAMA]: initCustom, + [Providers.DEEPSEEK]: initCustom, + [Providers.OPENROUTER]: initCustom, + [EModelEndpoint.openAI]: initOpenAI, + [EModelEndpoint.google]: initGoogle, + [EModelEndpoint.azureOpenAI]: initOpenAI, + [EModelEndpoint.anthropic]: initAnthropic, + [EModelEndpoint.bedrock]: getBedrockOptions, +}; + +/** + * @param {object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {Agent} params.agent + * @param {string | null} [params.conversationId] + * @param {Array} [params.requestFiles] + * @param {typeof import('~/server/services/ToolService').loadAgentTools | undefined} [params.loadTools] + * @param {TEndpointOption} [params.endpointOption] + * @param {Set} [params.allowedProviders] + * @param {boolean} [params.isInitialAgent] + * @returns {Promise, toolContextMap: Record, maxContextTokens: number }>} + */ +const initializeAgent = async ({ + req, + res, + agent, + loadTools, + requestFiles, + conversationId, + endpointOption, + allowedProviders, + isInitialAgent = false, +}) => { + if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) { + throw new Error( + `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`, + ); + } + let currentFiles; + + if ( + isInitialAgent && + conversationId != null && + 
(agent.model_parameters?.resendFiles ?? true) === true + ) { + const fileIds = (await getConvoFiles(conversationId)) ?? []; + /** @type {Set} */ + const toolResourceSet = new Set(); + for (const tool of agent.tools) { + if (EToolResources[tool]) { + toolResourceSet.add(EToolResources[tool]); + } + } + const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet); + if (requestFiles.length || toolFiles.length) { + currentFiles = await processFiles(requestFiles.concat(toolFiles)); + } + } else if (isInitialAgent && requestFiles.length) { + currentFiles = await processFiles(requestFiles); + } + + const { attachments, tool_resources } = await primeResources({ + req, + getFiles, + attachments: currentFiles, + tool_resources: agent.tool_resources, + requestFileSet: new Set(requestFiles?.map((file) => file.file_id)), + }); + + const provider = agent.provider; + const { tools, toolContextMap } = + (await loadTools?.({ + req, + res, + provider, + agentId: agent.id, + tools: agent.tools, + model: agent.model, + tool_resources, + })) ?? {}; + + agent.endpoint = provider; + let getOptions = providerConfigMap[provider]; + if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { + agent.provider = provider.toLowerCase(); + getOptions = providerConfigMap[agent.provider]; + } else if (!getOptions) { + const customEndpointConfig = await getCustomEndpointConfig(provider); + if (!customEndpointConfig) { + throw new Error(`Provider ${provider} not supported`); + } + getOptions = initCustom; + agent.provider = Providers.OPENAI; + } + const model_parameters = Object.assign( + {}, + agent.model_parameters ?? { model: agent.model }, + isInitialAgent === true ? endpointOption?.model_parameters : {}, + ); + const _endpointOption = + isInitialAgent === true + ? 
Object.assign({}, endpointOption, { model_parameters }) + : { model_parameters }; + + const options = await getOptions({ + req, + res, + optionsOnly: true, + overrideEndpoint: provider, + overrideModel: agent.model, + endpointOption: _endpointOption, + }); + + if ( + agent.endpoint === EModelEndpoint.azureOpenAI && + options.llmConfig?.azureOpenAIApiInstanceName == null + ) { + agent.provider = Providers.OPENAI; + } + + if (options.provider != null) { + agent.provider = options.provider; + } + + /** @type {import('@librechat/agents').ClientOptions} */ + agent.model_parameters = Object.assign(model_parameters, options.llmConfig); + if (options.configOptions) { + agent.model_parameters.configuration = options.configOptions; + } + + if (!agent.model_parameters.model) { + agent.model_parameters.model = agent.model; + } + + if (agent.instructions && agent.instructions !== '') { + agent.instructions = replaceSpecialVars({ + text: agent.instructions, + user: req.user, + }); + } + + if (typeof agent.artifacts === 'string' && agent.artifacts !== '') { + agent.additional_instructions = generateArtifactsPrompt({ + endpoint: agent.provider, + artifacts: agent.artifacts, + }); + } + + const tokensModel = + agent.provider === EModelEndpoint.azureOpenAI ? 
agent.model : agent.model_parameters.model; + const maxTokens = optionalChainWithEmptyCheck( + agent.model_parameters.maxOutputTokens, + agent.model_parameters.maxTokens, + 0, + ); + const maxContextTokens = optionalChainWithEmptyCheck( + agent.model_parameters.maxContextTokens, + agent.max_context_tokens, + getModelMaxTokens(tokensModel, providerEndpointMap[provider]), + 4096, + ); + return { + ...agent, + tools, + attachments, + toolContextMap, + maxContextTokens: (maxContextTokens - maxTokens) * 0.9, + }; +}; + +module.exports = { initializeAgent }; diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index c9e363e815..e3154fe13a 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -1,294 +1,41 @@ -const { createContentAggregator, Providers } = require('@librechat/agents'); -const { - Constants, - ErrorTypes, - EModelEndpoint, - EToolResources, - getResponseSender, - AgentCapabilities, - replaceSpecialVars, - providerEndpointMap, -} = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); +const { createContentAggregator } = require('@librechat/agents'); +const { Constants, EModelEndpoint, getResponseSender } = require('librechat-data-provider'); const { getDefaultHandlers, createToolEndCallback, } = require('~/server/controllers/agents/callbacks'); -const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); -const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); -const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const initCustom = require('~/server/services/Endpoints/custom/initialize'); -const initGoogle = require('~/server/services/Endpoints/google/initialize'); -const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); -const { getCustomEndpointConfig } = require('~/server/services/Config'); 
-const { processFiles } = require('~/server/services/Files/process'); +const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { loadAgentTools } = require('~/server/services/ToolService'); const AgentClient = require('~/server/controllers/agents/client'); -const { getConvoFiles } = require('~/models/Conversation'); -const { getToolFilesByIds } = require('~/models/File'); -const { getModelMaxTokens } = require('~/utils'); const { getAgent } = require('~/models/Agent'); -const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); -const providerConfigMap = { - [Providers.XAI]: initCustom, - [Providers.OLLAMA]: initCustom, - [Providers.DEEPSEEK]: initCustom, - [Providers.OPENROUTER]: initCustom, - [EModelEndpoint.openAI]: initOpenAI, - [EModelEndpoint.google]: initGoogle, - [EModelEndpoint.azureOpenAI]: initOpenAI, - [EModelEndpoint.anthropic]: initAnthropic, - [EModelEndpoint.bedrock]: getBedrockOptions, -}; - -/** - * @param {Object} params - * @param {ServerRequest} params.req - * @param {Promise> | undefined} [params.attachments] - * @param {Set} params.requestFileSet - * @param {AgentToolResources | undefined} [params.tool_resources] - * @returns {Promise<{ attachments: Array | undefined, tool_resources: AgentToolResources | undefined }>} - */ -const primeResources = async ({ - req, - attachments: _attachments, - tool_resources: _tool_resources, - requestFileSet, -}) => { - try { - /** @type {Array | undefined} */ - let attachments; - const tool_resources = _tool_resources ?? {}; - const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes( - AgentCapabilities.ocr, - ); - if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) { - const context = await getFiles( - { - file_id: { $in: tool_resources.ocr.file_ids }, - }, - {}, - {}, - ); - attachments = (attachments ?? 
[]).concat(context); +function createToolLoader() { + /** + * @param {object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {string} params.agentId + * @param {string[]} params.tools + * @param {string} params.provider + * @param {string} params.model + * @param {AgentToolResources} params.tool_resources + * @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record } | undefined>} + */ + return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) { + const agent = { id: agentId, tools, provider, model }; + try { + return await loadAgentTools({ + req, + res, + agent, + tool_resources, + }); + } catch (error) { + logger.error('Error loading tools for agent ' + agentId, error); } - if (!_attachments) { - return { attachments, tool_resources }; - } - /** @type {Array | undefined} */ - const files = await _attachments; - if (!attachments) { - /** @type {Array} */ - attachments = []; - } - - for (const file of files) { - if (!file) { - continue; - } - if (file.metadata?.fileIdentifier) { - const execute_code = tool_resources[EToolResources.execute_code] ?? {}; - if (!execute_code.files) { - tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] }; - } - tool_resources[EToolResources.execute_code].files.push(file); - } else if (file.embedded === true) { - const file_search = tool_resources[EToolResources.file_search] ?? {}; - if (!file_search.files) { - tool_resources[EToolResources.file_search] = { ...file_search, files: [] }; - } - tool_resources[EToolResources.file_search].files.push(file); - } else if ( - requestFileSet.has(file.file_id) && - file.type.startsWith('image') && - file.height && - file.width - ) { - const image_edit = tool_resources[EToolResources.image_edit] ?? 
{}; - if (!image_edit.files) { - tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] }; - } - tool_resources[EToolResources.image_edit].files.push(file); - } - - attachments.push(file); - } - return { attachments, tool_resources }; - } catch (error) { - logger.error('Error priming resources', error); - return { attachments: _attachments, tool_resources: _tool_resources }; - } -}; - -/** - * @param {...string | number} values - * @returns {string | number | undefined} - */ -function optionalChainWithEmptyCheck(...values) { - for (const value of values) { - if (value !== undefined && value !== null && value !== '') { - return value; - } - } - return values[values.length - 1]; -} - -/** - * @param {object} params - * @param {ServerRequest} params.req - * @param {ServerResponse} params.res - * @param {Agent} params.agent - * @param {Set} [params.allowedProviders] - * @param {object} [params.endpointOption] - * @param {boolean} [params.isInitialAgent] - * @returns {Promise} - */ -const initializeAgentOptions = async ({ - req, - res, - agent, - endpointOption, - allowedProviders, - isInitialAgent = false, -}) => { - if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) { - throw new Error( - `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`, - ); - } - let currentFiles; - /** @type {Array} */ - const requestFiles = req.body.files ?? []; - if ( - isInitialAgent && - req.body.conversationId != null && - (agent.model_parameters?.resendFiles ?? true) === true - ) { - const fileIds = (await getConvoFiles(req.body.conversationId)) ?? 
[]; - /** @type {Set} */ - const toolResourceSet = new Set(); - for (const tool of agent.tools) { - if (EToolResources[tool]) { - toolResourceSet.add(EToolResources[tool]); - } - } - const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet); - if (requestFiles.length || toolFiles.length) { - currentFiles = await processFiles(requestFiles.concat(toolFiles)); - } - } else if (isInitialAgent && requestFiles.length) { - currentFiles = await processFiles(requestFiles); - } - - const { attachments, tool_resources } = await primeResources({ - req, - attachments: currentFiles, - tool_resources: agent.tool_resources, - requestFileSet: new Set(requestFiles.map((file) => file.file_id)), - }); - - const provider = agent.provider; - const { tools, toolContextMap } = await loadAgentTools({ - req, - res, - agent: { - id: agent.id, - tools: agent.tools, - provider, - model: agent.model, - }, - tool_resources, - }); - - agent.endpoint = provider; - let getOptions = providerConfigMap[provider]; - if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { - agent.provider = provider.toLowerCase(); - getOptions = providerConfigMap[agent.provider]; - } else if (!getOptions) { - const customEndpointConfig = await getCustomEndpointConfig(provider); - if (!customEndpointConfig) { - throw new Error(`Provider ${provider} not supported`); - } - getOptions = initCustom; - agent.provider = Providers.OPENAI; - } - const model_parameters = Object.assign( - {}, - agent.model_parameters ?? { model: agent.model }, - isInitialAgent === true ? endpointOption?.model_parameters : {}, - ); - const _endpointOption = - isInitialAgent === true - ? 
Object.assign({}, endpointOption, { model_parameters }) - : { model_parameters }; - - const options = await getOptions({ - req, - res, - optionsOnly: true, - overrideEndpoint: provider, - overrideModel: agent.model, - endpointOption: _endpointOption, - }); - - if ( - agent.endpoint === EModelEndpoint.azureOpenAI && - options.llmConfig?.azureOpenAIApiInstanceName == null - ) { - agent.provider = Providers.OPENAI; - } - - if (options.provider != null) { - agent.provider = options.provider; - } - - /** @type {import('@librechat/agents').ClientOptions} */ - agent.model_parameters = Object.assign(model_parameters, options.llmConfig); - if (options.configOptions) { - agent.model_parameters.configuration = options.configOptions; - } - - if (!agent.model_parameters.model) { - agent.model_parameters.model = agent.model; - } - - if (agent.instructions && agent.instructions !== '') { - agent.instructions = replaceSpecialVars({ - text: agent.instructions, - user: req.user, - }); - } - - if (typeof agent.artifacts === 'string' && agent.artifacts !== '') { - agent.additional_instructions = generateArtifactsPrompt({ - endpoint: agent.provider, - artifacts: agent.artifacts, - }); - } - - const tokensModel = - agent.provider === EModelEndpoint.azureOpenAI ? 
agent.model : agent.model_parameters.model; - const maxTokens = optionalChainWithEmptyCheck( - agent.model_parameters.maxOutputTokens, - agent.model_parameters.maxTokens, - 0, - ); - const maxContextTokens = optionalChainWithEmptyCheck( - agent.model_parameters.maxContextTokens, - agent.max_context_tokens, - getModelMaxTokens(tokensModel, providerEndpointMap[provider]), - 4096, - ); - return { - ...agent, - tools, - attachments, - toolContextMap, - maxContextTokens: (maxContextTokens - maxTokens) * 0.9, }; -}; +} const initializeClient = async ({ req, res, endpointOption }) => { if (!endpointOption) { @@ -313,7 +60,6 @@ const initializeClient = async ({ req, res, endpointOption }) => { throw new Error('No agent promise provided'); } - // Initialize primary agent const primaryAgent = await endpointOption.agent; if (!primaryAgent) { throw new Error('Agent not found'); @@ -323,10 +69,18 @@ const initializeClient = async ({ req, res, endpointOption }) => { /** @type {Set} */ const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders); - // Handle primary agent - const primaryConfig = await initializeAgentOptions({ + const loadTools = createToolLoader(); + /** @type {Array} */ + const requestFiles = req.body.files ?? 
[]; + /** @type {string} */ + const conversationId = req.body.conversationId; + + const primaryConfig = await initializeAgent({ req, res, + loadTools, + requestFiles, + conversationId, agent: primaryAgent, endpointOption, allowedProviders, @@ -340,10 +94,13 @@ const initializeClient = async ({ req, res, endpointOption }) => { if (!agent) { throw new Error(`Agent ${agentId} not found`); } - const config = await initializeAgentOptions({ + const config = await initializeAgent({ req, res, agent, + loadTools, + requestFiles, + conversationId, endpointOption, allowedProviders, }); diff --git a/api/server/services/Endpoints/azureAssistants/initialize.js b/api/server/services/Endpoints/azureAssistants/initialize.js index fc8024af07..88acef23e5 100644 --- a/api/server/services/Endpoints/azureAssistants/initialize.js +++ b/api/server/services/Endpoints/azureAssistants/initialize.js @@ -1,5 +1,6 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); +const { constructAzureURL, isUserProvided } = require('@librechat/api'); const { ErrorTypes, EModelEndpoint, @@ -12,8 +13,6 @@ const { checkUserKeyExpiry, } = require('~/server/services/UserService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); -const { isUserProvided } = require('~/server/utils'); -const { constructAzureURL } = require('~/utils'); class Files { constructor(client) { diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js index da332060e9..fc5536abbf 100644 --- a/api/server/services/Endpoints/bedrock/options.js +++ b/api/server/services/Endpoints/bedrock/options.js @@ -1,4 +1,5 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); +const { createHandleLLMNewToken } = require('@librechat/api'); const { AuthType, Constants, @@ -8,7 +9,6 @@ const { removeNullishValues, } = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); -const 
{ createHandleLLMNewToken } = require('~/app/clients/generators'); const getOptions = async ({ req, overrideModel, endpointOption }) => { const { diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js index 39def8d0d5..754abef5a8 100644 --- a/api/server/services/Endpoints/custom/initialize.js +++ b/api/server/services/Endpoints/custom/initialize.js @@ -6,10 +6,9 @@ const { extractEnvVariable, } = require('librechat-data-provider'); const { Providers } = require('@librechat/agents'); +const { getOpenAIConfig, createHandleLLMNewToken } = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm'); const { getCustomEndpointConfig } = require('~/server/services/Config'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); const { fetchModels } = require('~/server/services/ModelService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); const { isUserProvided } = require('~/server/utils'); @@ -144,7 +143,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid clientOptions, ); clientOptions.modelOptions.user = req.user.id; - const options = getLLMConfig(apiKey, clientOptions, endpoint); + const options = getOpenAIConfig(apiKey, clientOptions, endpoint); if (!customOptions.streamRate) { return options; } diff --git a/api/server/services/Endpoints/google/initialize.js b/api/server/services/Endpoints/google/initialize.js index b7419a8a87..b6bc2d6a79 100644 --- a/api/server/services/Endpoints/google/initialize.js +++ b/api/server/services/Endpoints/google/initialize.js @@ -25,9 +25,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio const credentials = isUserProvided ? 
userKey : { - [AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey, - [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, - }; + [AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey, + [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, + }; let clientOptions = {}; diff --git a/api/server/services/Endpoints/google/llm.js b/api/server/services/Endpoints/google/llm.js index a64b33480b..235e1e3df9 100644 --- a/api/server/services/Endpoints/google/llm.js +++ b/api/server/services/Endpoints/google/llm.js @@ -94,7 +94,7 @@ function getLLMConfig(credentials, options = {}) { // Extract from credentials const serviceKeyRaw = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {}; const serviceKey = - typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : serviceKeyRaw ?? {}; + typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : (serviceKeyRaw ?? {}); const project_id = serviceKey?.project_id ?? null; const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null; @@ -156,10 +156,6 @@ function getLLMConfig(credentials, options = {}) { } if (authHeader) { - /** - * NOTE: NOT SUPPORTED BY LANGCHAIN GENAI CLIENT, - * REQUIRES PR IN https://github.com/langchain-ai/langchainjs - */ llmConfig.customHeaders = { Authorization: `Bearer ${apiKey}`, }; diff --git a/api/server/services/Endpoints/gptPlugins/initialize.js b/api/server/services/Endpoints/gptPlugins/initialize.js index 7bfb43f004..d2af6c757e 100644 --- a/api/server/services/Endpoints/gptPlugins/initialize.js +++ b/api/server/services/Endpoints/gptPlugins/initialize.js @@ -1,11 +1,10 @@ const { EModelEndpoint, - mapModelToAzureConfig, resolveHeaders, + mapModelToAzureConfig, } = require('librechat-data-provider'); +const { isEnabled, isUserProvided, getAzureCredentials } = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { isEnabled, isUserProvided } = require('~/server/utils'); -const { getAzureCredentials } = require('~/utils'); const { PluginsClient } = require('~/app'); const initializeClient = 
async ({ req, res, endpointOption }) => { diff --git a/api/server/services/Endpoints/gptPlugins/initialize.spec.js b/api/server/services/Endpoints/gptPlugins/initialize.spec.js index 02199c9397..f9cb2750a4 100644 --- a/api/server/services/Endpoints/gptPlugins/initialize.spec.js +++ b/api/server/services/Endpoints/gptPlugins/initialize.spec.js @@ -114,11 +114,11 @@ describe('gptPlugins/initializeClient', () => { test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => { process.env.AZURE_API_KEY = 'test-azure-api-key'; (process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_VERSION = 'some-value'), - (process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'), - (process.env.PLUGINS_USE_AZURE = 'true'); + (process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'), + (process.env.AZURE_OPENAI_API_VERSION = 'some-value'), + (process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'), + (process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'), + (process.env.PLUGINS_USE_AZURE = 'true'); process.env.DEBUG_PLUGINS = 'false'; process.env.OPENAI_SUMMARIZE = 'false'; diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index 714ed5a1e6..bc0907b3de 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -4,12 +4,15 @@ const { resolveHeaders, mapModelToAzureConfig, } = require('librechat-data-provider'); +const { + isEnabled, + isUserProvided, + getOpenAIConfig, + getAzureCredentials, + createHandleLLMNewToken, +} = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = 
require('~/server/services/Endpoints/openAI/llm'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); -const { isEnabled, isUserProvided } = require('~/server/utils'); const OpenAIClient = require('~/app/clients/OpenAIClient'); -const { getAzureCredentials } = require('~/utils'); const initializeClient = async ({ req, @@ -140,7 +143,7 @@ const initializeClient = async ({ modelOptions.model = modelName; clientOptions = Object.assign({ modelOptions }, clientOptions); clientOptions.modelOptions.user = req.user.id; - const options = getLLMConfig(apiKey, clientOptions); + const options = getOpenAIConfig(apiKey, clientOptions); const streamRate = clientOptions.streamRate; if (!streamRate) { return options; diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js deleted file mode 100644 index c1fd090b28..0000000000 --- a/api/server/services/Endpoints/openAI/llm.js +++ /dev/null @@ -1,170 +0,0 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); -const { KnownEndpoints } = require('librechat-data-provider'); -const { sanitizeModelName, constructAzureURL } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); - -/** - * Generates configuration options for creating a language model (LLM) instance. - * @param {string} apiKey - The API key for authentication. - * @param {Object} options - Additional options for configuring the LLM. - * @param {Object} [options.modelOptions] - Model-specific options. - * @param {string} [options.modelOptions.model] - The name of the model to use. - * @param {string} [options.modelOptions.user] - The user ID - * @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2). - * @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1). - * @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2). 
- * @param {number} [options.modelOptions.presence_penalty] - Encourages discussing new topics (-2 to 2). - * @param {number} [options.modelOptions.max_tokens] - The maximum number of tokens to generate. - * @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens. - * @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used. - * @param {boolean} [options.useOpenRouter] - Flag to use OpenRouter API. - * @param {Object} [options.headers] - Additional headers for API requests. - * @param {string} [options.proxy] - Proxy server URL. - * @param {Object} [options.azure] - Azure-specific configurations. - * @param {boolean} [options.streaming] - Whether to use streaming mode. - * @param {Object} [options.addParams] - Additional parameters to add to the model options. - * @param {string[]} [options.dropParams] - Parameters to remove from the model options. - * @param {string|null} [endpoint=null] - The endpoint name - * @returns {Object} Configuration options for creating an LLM instance. 
- */ -function getLLMConfig(apiKey, options = {}, endpoint = null) { - let { - modelOptions = {}, - reverseProxyUrl, - defaultQuery, - headers, - proxy, - azure, - streaming = true, - addParams, - dropParams, - } = options; - - /** @type {OpenAIClientOptions} */ - let llmConfig = { - streaming, - }; - - Object.assign(llmConfig, modelOptions); - - if (addParams && typeof addParams === 'object') { - Object.assign(llmConfig, addParams); - } - /** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */ - if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) { - const searchExcludeParams = [ - 'frequency_penalty', - 'presence_penalty', - 'temperature', - 'top_p', - 'top_k', - 'stop', - 'logit_bias', - 'seed', - 'response_format', - 'n', - 'logprobs', - 'user', - ]; - - dropParams = dropParams || []; - dropParams = [...new Set([...dropParams, ...searchExcludeParams])]; - } - - if (dropParams && Array.isArray(dropParams)) { - dropParams.forEach((param) => { - if (llmConfig[param]) { - llmConfig[param] = undefined; - } - }); - } - - let useOpenRouter; - /** @type {OpenAIClientOptions['configuration']} */ - const configOptions = {}; - if ( - (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) || - (endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) - ) { - useOpenRouter = true; - llmConfig.include_reasoning = true; - configOptions.baseURL = reverseProxyUrl; - configOptions.defaultHeaders = Object.assign( - { - 'HTTP-Referer': 'https://librechat.ai', - 'X-Title': 'LibreChat', - }, - headers, - ); - } else if (reverseProxyUrl) { - configOptions.baseURL = reverseProxyUrl; - if (headers) { - configOptions.defaultHeaders = headers; - } - } - - if (defaultQuery) { - configOptions.defaultQuery = defaultQuery; - } - - if (proxy) { - const proxyAgent = new HttpsProxyAgent(proxy); - Object.assign(configOptions, { - httpAgent: proxyAgent, - httpsAgent: proxyAgent, - }); - } - - if (azure) { - 
const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME); - azure.azureOpenAIApiDeploymentName = useModelName - ? sanitizeModelName(llmConfig.model) - : azure.azureOpenAIApiDeploymentName; - - if (process.env.AZURE_OPENAI_DEFAULT_MODEL) { - llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL; - } - - if (configOptions.baseURL) { - const azureURL = constructAzureURL({ - baseURL: configOptions.baseURL, - azureOptions: azure, - }); - azure.azureOpenAIBasePath = azureURL.split(`/${azure.azureOpenAIApiDeploymentName}`)[0]; - } - - Object.assign(llmConfig, azure); - llmConfig.model = llmConfig.azureOpenAIApiDeploymentName; - } else { - llmConfig.apiKey = apiKey; - // Object.assign(llmConfig, { - // configuration: { apiKey }, - // }); - } - - if (process.env.OPENAI_ORGANIZATION && this.azure) { - llmConfig.organization = process.env.OPENAI_ORGANIZATION; - } - - if (useOpenRouter && llmConfig.reasoning_effort != null) { - llmConfig.reasoning = { - effort: llmConfig.reasoning_effort, - }; - delete llmConfig.reasoning_effort; - } - - if (llmConfig?.['max_tokens'] != null) { - /** @type {number} */ - llmConfig.maxTokens = llmConfig['max_tokens']; - delete llmConfig['max_tokens']; - } - - return { - /** @type {OpenAIClientOptions} */ - llmConfig, - /** @type {OpenAIClientOptions['configuration']} */ - configOptions, - }; -} - -module.exports = { getLLMConfig }; diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index d6c8cc4146..49a800336b 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -2,9 +2,9 @@ const axios = require('axios'); const fs = require('fs').promises; const FormData = require('form-data'); const { Readable } = require('stream'); +const { genAzureEndpoint } = require('@librechat/api'); const { extractEnvVariable, STTProviders } = require('librechat-data-provider'); const { getCustomConfig } = 
require('~/server/services/Config'); -const { genAzureEndpoint } = require('~/utils'); const { logger } = require('~/config'); /** diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index cd718fdfc1..34d8202156 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -1,8 +1,8 @@ const axios = require('axios'); +const { genAzureEndpoint } = require('@librechat/api'); const { extractEnvVariable, TTSProviders } = require('librechat-data-provider'); const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio'); const { getCustomConfig } = require('~/server/services/Config'); -const { genAzureEndpoint } = require('~/utils'); const { logger } = require('~/config'); /** diff --git a/api/server/services/Files/Azure/images.js b/api/server/services/Files/Azure/images.js index a83b700af3..80d5e76290 100644 --- a/api/server/services/Files/Azure/images.js +++ b/api/server/services/Files/Azure/images.js @@ -1,10 +1,9 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); +const { logger } = require('@librechat/data-schemas'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = require('~/models/userMethods'); -const { updateFile } = require('~/models/File'); -const { logger } = require('~/config'); +const { updateUser, updateFile } = require('~/models'); const { saveBufferToAzure } = require('./crud'); /** @@ -92,24 +91,44 @@ async function prepareAzureImageURL(req, file) { * @param {Buffer} params.buffer - The avatar image buffer. * @param {string} params.userId - The user's id. * @param {string} params.manual - Flag to indicate manual update. + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @param {string} [params.basePath='images'] - The base folder within the container. 
* @param {string} [params.containerName] - The Azure Blob container name. * @returns {Promise} The URL of the avatar. */ -async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) { +async function processAzureAvatar({ + buffer, + userId, + manual, + agentId, + basePath = 'images', + containerName, +}) { try { + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 'gif' : 'png'; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; + const downloadURL = await saveBufferToAzure({ userId, buffer, - fileName: 'avatar.png', + fileName, basePath, containerName, }); const isManual = manual === 'true'; const url = `${downloadURL}?manual=${isManual}`; - if (isManual) { + + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } + return url; } catch (error) { logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error); diff --git a/api/server/services/Files/Code/crud.js b/api/server/services/Files/Code/crud.js index caea9ab30a..c696eae0c4 100644 --- a/api/server/services/Files/Code/crud.js +++ b/api/server/services/Files/Code/crud.js @@ -1,7 +1,6 @@ const FormData = require('form-data'); const { getCodeBaseURL } = require('@librechat/agents'); -const { createAxiosInstance } = require('~/config'); -const { logAxiosError } = require('~/utils'); +const { createAxiosInstance, logAxiosError } = require('@librechat/api'); const axios = createAxiosInstance(); diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index c92e628589..cf65154983 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -1,6 +1,8 @@ const 
path = require('path'); const { v4 } = require('uuid'); const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { getCodeBaseURL } = require('@librechat/agents'); const { Tools, @@ -12,8 +14,6 @@ const { const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); const { createFile, getFiles, updateFile } = require('~/models/File'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Process OpenAI image files, convert to target format, save and return file metadata. diff --git a/api/server/services/Files/Firebase/images.js b/api/server/services/Files/Firebase/images.js index 7345f30df1..8b0866b5d0 100644 --- a/api/server/services/Files/Firebase/images.js +++ b/api/server/services/Files/Firebase/images.js @@ -1,11 +1,10 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); +const { logger } = require('@librechat/data-schemas'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = require('~/models/userMethods'); +const { updateUser, updateFile } = require('~/models'); const { saveBufferToFirebase } = require('./crud'); -const { updateFile } = require('~/models/File'); -const { logger } = require('~/config'); /** * Converts an image file to the target format. The function first resizes the image based on the specified @@ -83,22 +82,32 @@ async function prepareImageURL(req, file) { * @param {Buffer} params.buffer - The Buffer containing the avatar image. * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. 
* @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processFirebaseAvatar({ buffer, userId, manual }) { +async function processFirebaseAvatar({ buffer, userId, manual, agentId }) { try { + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 'gif' : 'png'; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; + const downloadURL = await saveBufferToFirebase({ userId, buffer, - fileName: 'avatar.png', + fileName, }); const isManual = manual === 'true'; - const url = `${downloadURL}?manual=${isManual}`; - if (isManual) { + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index 783230f2f6..7df528c5e1 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -201,6 +201,10 @@ const unlinkFile = async (filepath) => { */ const deleteLocalFile = async (req, file) => { const { publicPath, uploads } = req.app.locals.paths; + + /** Filepath stripped of query parameters (e.g., ?manual=true) */ + const cleanFilepath = file.filepath.split('?')[0]; + if (file.embedded && process.env.RAG_API_URL) { const jwtToken = req.headers.authorization.split(' ')[1]; axios.delete(`${process.env.RAG_API_URL}/documents`, { @@ -213,32 +217,32 @@ const deleteLocalFile = async (req, file) => { }); } - if (file.filepath.startsWith(`/uploads/${req.user.id}`)) { + if (cleanFilepath.startsWith(`/uploads/${req.user.id}`)) { const userUploadDir = path.join(uploads, req.user.id); - const basePath = file.filepath.split(`/uploads/${req.user.id}/`)[1]; + const basePath = 
cleanFilepath.split(`/uploads/${req.user.id}/`)[1]; if (!basePath) { - throw new Error(`Invalid file path: ${file.filepath}`); + throw new Error(`Invalid file path: ${cleanFilepath}`); } const filepath = path.join(userUploadDir, basePath); const rel = path.relative(userUploadDir, filepath); if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { - throw new Error(`Invalid file path: ${file.filepath}`); + throw new Error(`Invalid file path: ${cleanFilepath}`); } await unlinkFile(filepath); return; } - const parts = file.filepath.split(path.sep); + const parts = cleanFilepath.split(path.sep); const subfolder = parts[1]; if (!subfolder && parts[0] === EModelEndpoint.agents) { logger.warn(`Agent File ${file.file_id} is missing filepath, may have been deleted already`); return; } - const filepath = path.join(publicPath, file.filepath); + const filepath = path.join(publicPath, cleanFilepath); if (!isValidPath(req, publicPath, subfolder, filepath)) { throw new Error('Invalid file path'); diff --git a/api/server/services/Files/Local/images.js b/api/server/services/Files/Local/images.js index 1305505381..ea3af87c70 100644 --- a/api/server/services/Files/Local/images.js +++ b/api/server/services/Files/Local/images.js @@ -2,8 +2,7 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = require('~/models/userMethods'); -const { updateFile } = require('~/models/File'); +const { updateUser, updateFile } = require('~/models'); /** * Converts an image file to the target format. The function first resizes the image based on the specified @@ -113,10 +112,11 @@ async function prepareImagesLocal(req, file) { * @param {Buffer} params.buffer - The Buffer containing the avatar image. * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). 
+ * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processLocalAvatar({ buffer, userId, manual }) { +async function processLocalAvatar({ buffer, userId, manual, agentId }) { const userDir = path.resolve( __dirname, '..', @@ -130,7 +130,14 @@ async function processLocalAvatar({ buffer, userId, manual }) { userId, ); - const fileName = `avatar-${new Date().getTime()}.png`; + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 'gif' : 'png'; + + const timestamp = new Date().getTime(); + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; const urlRoute = `/images/${userId}/${fileName}`; const avatarPath = path.join(userDir, fileName); @@ -140,7 +147,8 @@ async function processLocalAvatar({ buffer, userId, manual }) { const isManual = manual === 'true'; let url = `${urlRoute}?manual=${isManual}`; - if (isManual) { + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } diff --git a/api/server/services/Files/MistralOCR/crud.js b/api/server/services/Files/MistralOCR/crud.js deleted file mode 100644 index cc01d803b0..0000000000 --- a/api/server/services/Files/MistralOCR/crud.js +++ /dev/null @@ -1,237 +0,0 @@ -// ~/server/services/Files/MistralOCR/crud.js -const fs = require('fs'); -const path = require('path'); -const FormData = require('form-data'); -const { - FileSources, - envVarRegex, - extractEnvVariable, - extractVariableName, -} = require('librechat-data-provider'); -const { loadAuthValues } = require('~/server/services/Tools/credentials'); -const { logger, 
createAxiosInstance } = require('~/config'); -const { logAxiosError } = require('~/utils/axios'); - -const axios = createAxiosInstance(); - -/** - * Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory - * - * @param {Object} params Upload parameters - * @param {string} params.filePath The path to the file on disk - * @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath) - * @param {string} params.apiKey Mistral API key - * @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL - * @returns {Promise} The response from Mistral API - */ -async function uploadDocumentToMistral({ - filePath, - fileName = '', - apiKey, - baseURL = 'https://api.mistral.ai/v1', -}) { - const form = new FormData(); - form.append('purpose', 'ocr'); - const actualFileName = fileName || path.basename(filePath); - const fileStream = fs.createReadStream(filePath); - form.append('file', fileStream, { filename: actualFileName }); - - return axios - .post(`${baseURL}/files`, form, { - headers: { - Authorization: `Bearer ${apiKey}`, - ...form.getHeaders(), - }, - maxBodyLength: Infinity, - maxContentLength: Infinity, - }) - .then((res) => res.data) - .catch((error) => { - throw error; - }); -} - -async function getSignedUrl({ - apiKey, - fileId, - expiry = 24, - baseURL = 'https://api.mistral.ai/v1', -}) { - return axios - .get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, { - headers: { - Authorization: `Bearer ${apiKey}`, - }, - }) - .then((res) => res.data) - .catch((error) => { - logger.error('Error fetching signed URL:', error.message); - throw error; - }); -} - -/** - * @param {Object} params - * @param {string} params.apiKey - * @param {string} params.url - The document or image URL - * @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url' - * @param {string} [params.model] - * @param {string} [params.baseURL] - * @returns {Promise} - */ -async 
function performOCR({ - apiKey, - url, - documentType = 'document_url', - model = 'mistral-ocr-latest', - baseURL = 'https://api.mistral.ai/v1', -}) { - const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url'; - return axios - .post( - `${baseURL}/ocr`, - { - model, - include_image_base64: false, - document: { - type: documentType, - [documentKey]: url, - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${apiKey}`, - }, - }, - ) - .then((res) => res.data) - .catch((error) => { - logger.error('Error performing OCR:', error.message); - throw error; - }); -} - -/** - * Uploads a file to the Mistral OCR API and processes the OCR result. - * - * @param {Object} params - The params object. - * @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id` - * representing the user - * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should - * have a `mimetype` property that tells us the file type - * @param {string} params.file_id - The file ID. - * @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency. - * @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used), - * along with the `filename` and `bytes` properties. 
- */ -const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { - try { - /** @type {TCustomConfig['ocr']} */ - const ocrConfig = req.app.locals?.ocr; - - const apiKeyConfig = ocrConfig.apiKey || ''; - const baseURLConfig = ocrConfig.baseURL || ''; - - const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig); - const isBaseURLEnvVar = envVarRegex.test(baseURLConfig); - - const isApiKeyEmpty = !apiKeyConfig.trim(); - const isBaseURLEmpty = !baseURLConfig.trim(); - - let apiKey, baseURL; - - if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) { - const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY'; - const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL'; - - const authValues = await loadAuthValues({ - userId: req.user.id, - authFields: [baseURLVarName, apiKeyVarName], - optional: new Set([baseURLVarName]), - }); - - apiKey = authValues[apiKeyVarName]; - baseURL = authValues[baseURLVarName]; - } else { - apiKey = apiKeyConfig; - baseURL = baseURLConfig; - } - - const mistralFile = await uploadDocumentToMistral({ - filePath: file.path, - fileName: file.originalname, - apiKey, - baseURL, - }); - - const modelConfig = ocrConfig.mistralModel || ''; - const model = envVarRegex.test(modelConfig) - ? extractEnvVariable(modelConfig) - : modelConfig.trim() || 'mistral-ocr-latest'; - - const signedUrlResponse = await getSignedUrl({ - apiKey, - baseURL, - fileId: mistralFile.id, - }); - - const mimetype = (file.mimetype || '').toLowerCase(); - const originalname = file.originalname || ''; - const isImage = - mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname); - const documentType = isImage ? 
'image_url' : 'document_url'; - - const ocrResult = await performOCR({ - apiKey, - baseURL, - model, - url: signedUrlResponse.url, - documentType, - }); - - let aggregatedText = ''; - const images = []; - ocrResult.pages.forEach((page, index) => { - if (ocrResult.pages.length > 1) { - aggregatedText += `# PAGE ${index + 1}\n`; - } - - aggregatedText += page.markdown + '\n\n'; - - if (page.images && page.images.length > 0) { - page.images.forEach((image) => { - if (image.image_base64) { - images.push(image.image_base64); - } - }); - } - }); - - return { - filename: file.originalname, - bytes: aggregatedText.length * 4, - filepath: FileSources.mistral_ocr, - text: aggregatedText, - images, - }; - } catch (error) { - let message = 'Error uploading document to Mistral OCR API'; - const detail = error?.response?.data?.detail; - if (detail && detail !== '') { - message = detail; - } - - const responseMessage = error?.response?.data?.message; - throw new Error( - `${logAxiosError({ error, message })}${responseMessage && responseMessage !== '' ? 
` - ${responseMessage}` : ''}`, - ); - } -}; - -module.exports = { - uploadDocumentToMistral, - uploadMistralOCR, - getSignedUrl, - performOCR, -}; diff --git a/api/server/services/Files/MistralOCR/crud.spec.js b/api/server/services/Files/MistralOCR/crud.spec.js deleted file mode 100644 index 8cc63cade2..0000000000 --- a/api/server/services/Files/MistralOCR/crud.spec.js +++ /dev/null @@ -1,846 +0,0 @@ -const fs = require('fs'); - -const mockAxios = { - interceptors: { - request: { use: jest.fn(), eject: jest.fn() }, - response: { use: jest.fn(), eject: jest.fn() }, - }, - create: jest.fn().mockReturnValue({ - defaults: { - proxy: null, - }, - get: jest.fn().mockResolvedValue({ data: {} }), - post: jest.fn().mockResolvedValue({ data: {} }), - put: jest.fn().mockResolvedValue({ data: {} }), - delete: jest.fn().mockResolvedValue({ data: {} }), - }), - get: jest.fn().mockResolvedValue({ data: {} }), - post: jest.fn().mockResolvedValue({ data: {} }), - put: jest.fn().mockResolvedValue({ data: {} }), - delete: jest.fn().mockResolvedValue({ data: {} }), - reset: jest.fn().mockImplementation(function () { - this.get.mockClear(); - this.post.mockClear(); - this.put.mockClear(); - this.delete.mockClear(); - this.create.mockClear(); - }), -}; - -jest.mock('axios', () => mockAxios); -jest.mock('fs'); -jest.mock('~/config', () => ({ - logger: { - error: jest.fn(), - }, - createAxiosInstance: () => mockAxios, -})); -jest.mock('~/server/services/Tools/credentials', () => ({ - loadAuthValues: jest.fn(), -})); - -const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud'); - -describe('MistralOCR Service', () => { - afterEach(() => { - mockAxios.reset(); - jest.clearAllMocks(); - }); - - describe('uploadDocumentToMistral', () => { - beforeEach(() => { - // Create a more complete mock for file streams that FormData can work with - const mockReadStream = { - on: jest.fn().mockImplementation(function (event, handler) { - // Simulate immediate 
'end' event to make FormData complete processing - if (event === 'end') { - handler(); - } - return this; - }), - pipe: jest.fn().mockImplementation(function () { - return this; - }), - pause: jest.fn(), - resume: jest.fn(), - emit: jest.fn(), - once: jest.fn(), - destroy: jest.fn(), - }; - - fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); - - // Mock FormData's append to avoid actual stream processing - jest.mock('form-data', () => { - const mockFormData = function () { - return { - append: jest.fn(), - getHeaders: jest - .fn() - .mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }), - getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')), - getLength: jest.fn().mockReturnValue(100), - }; - }; - return mockFormData; - }); - }); - - it('should upload a document to Mistral API using file streaming', async () => { - const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await uploadDocumentToMistral({ - filePath: '/path/to/test.pdf', - fileName: 'test.pdf', - apiKey: 'test-api-key', - }); - - // Check that createReadStream was called with the correct file path - expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf'); - - // Since we're mocking FormData, we'll just check that axios was called correctly - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/files', - expect.anything(), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer test-api-key', - }), - maxBodyLength: Infinity, - maxContentLength: Infinity, - }), - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors during document upload', async () => { - const errorMessage = 'API error'; - mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - uploadDocumentToMistral({ - filePath: '/path/to/test.pdf', - fileName: 'test.pdf', - apiKey: 
'test-api-key', - }), - ).rejects.toThrow(errorMessage); - }); - }); - - describe('getSignedUrl', () => { - it('should fetch signed URL from Mistral API', async () => { - const mockResponse = { data: { url: 'https://document-url.com' } }; - mockAxios.get.mockResolvedValueOnce(mockResponse); - - const result = await getSignedUrl({ - fileId: 'file-123', - apiKey: 'test-api-key', - }); - - expect(mockAxios.get).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/files/file-123/url?expiry=24', - { - headers: { - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors when fetching signed URL', async () => { - const errorMessage = 'API error'; - mockAxios.get.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - getSignedUrl({ - fileId: 'file-123', - apiKey: 'test-api-key', - }), - ).rejects.toThrow(); - - const { logger } = require('~/config'); - expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage); - }); - }); - - describe('performOCR', () => { - it('should perform OCR using Mistral API (document_url)', async () => { - const mockResponse = { - data: { - pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], - }, - }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await performOCR({ - apiKey: 'test-api-key', - url: 'https://document-url.com', - model: 'mistral-ocr-latest', - documentType: 'document_url', - }); - - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - { - model: 'mistral-ocr-latest', - include_image_base64: false, - document: { - type: 'document_url', - document_url: 'https://document-url.com', - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should perform OCR using Mistral API (image_url)', async () => { - const mockResponse = { - 
data: { - pages: [{ markdown: 'Image OCR content' }], - }, - }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await performOCR({ - apiKey: 'test-api-key', - url: 'https://image-url.com/image.png', - model: 'mistral-ocr-latest', - documentType: 'image_url', - }); - - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - { - model: 'mistral-ocr-latest', - include_image_base64: false, - document: { - type: 'image_url', - image_url: 'https://image-url.com/image.png', - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors during OCR processing', async () => { - const errorMessage = 'OCR processing error'; - mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - performOCR({ - apiKey: 'test-api-key', - url: 'https://document-url.com', - }), - ).rejects.toThrow(); - - const { logger } = require('~/config'); - expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage); - }); - }); - - describe('uploadMistralOCR', () => { - beforeEach(() => { - const mockReadStream = { - on: jest.fn().mockImplementation(function (event, handler) { - if (event === 'end') { - handler(); - } - return this; - }), - pipe: jest.fn().mockImplementation(function () { - return this; - }), - pause: jest.fn(), - resume: jest.fn(), - emit: jest.fn(), - once: jest.fn(), - destroy: jest.fn(), - }; - - fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); - }); - - it('should process OCR for a file with standard configuration', async () => { - // Setup mocks - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', - }); - - // Mock file upload response - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', 
purpose: 'ocr' }, - }); - - // Mock signed URL response - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - - // Mock OCR response with text and images - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [ - { - markdown: 'Page 1 content', - images: [{ image_base64: 'base64image1' }], - }, - { - markdown: 'Page 2 content', - images: [{ image_base64: 'base64image2' }], - }, - ], - }, - }); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Use environment variable syntax to ensure loadAuthValues is called - apiKey: '${OCR_API_KEY}', - baseURL: '${OCR_BASEURL}', - mistralModel: 'mistral-medium', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - mimetype: 'application/pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Verify OCR result - expect(result).toEqual({ - filename: 'document.pdf', - bytes: expect.any(Number), - filepath: 'mistral_ocr', - text: expect.stringContaining('# PAGE 1'), - images: ['base64image1', 'base64image2'], - }); - }); - - it('should process OCR for an image file and use image_url type', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', - }); - - // Mock file upload response - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-456', purpose: 'ocr' }, - }); - - // Mock signed URL response - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com/image.png' }, - }); - - // Mock OCR response for image - 
mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [ - { - markdown: 'Image OCR result', - images: [{ image_base64: 'imgbase64' }], - }, - ], - }, - }); - - const req = { - user: { id: 'user456' }, - app: { - locals: { - ocr: { - apiKey: '${OCR_API_KEY}', - baseURL: '${OCR_BASEURL}', - mistralModel: 'mistral-medium', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/image.png', - originalname: 'image.png', - mimetype: 'image/png', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file456', - entity_id: 'entity456', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png'); - - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user456', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Check that the OCR API was called with image_url type - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - expect.objectContaining({ - document: expect.objectContaining({ - type: 'image_url', - image_url: 'https://signed-url.com/image.png', - }), - }), - expect.any(Object), - ); - - expect(result).toEqual({ - filename: 'image.png', - bytes: expect.any(Number), - filepath: 'mistral_ocr', - text: expect.stringContaining('Image OCR result'), - images: ['imgbase64'], - }); - }); - - it('should process variable references in configuration', async () => { - // Setup mocks with environment variables - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - CUSTOM_API_KEY: 'custom-api-key', - CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1', - }); - - // Mock API responses - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', purpose: 'ocr' }, - }); - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [{ markdown: 'Content from custom API' }], - }, - }); - - const req = { - user: { 
id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: '${CUSTOM_API_KEY}', - baseURL: '${CUSTOM_BASEURL}', - mistralModel: '${CUSTOM_MODEL}', - }, - }, - }, - }; - - // Set environment variable for model - process.env.CUSTOM_MODEL = 'mistral-large'; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify that custom environment variables were extracted and used - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'], - optional: expect.any(Set), - }); - - // Check that mistral-large was used in the OCR API call - expect(mockAxios.post).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - model: 'mistral-large', - }), - expect.anything(), - ); - - expect(result.text).toEqual('Content from custom API\n\n'); - }); - - it('should fall back to default values when variables are not properly formatted', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'default-api-key', - OCR_BASEURL: undefined, // Testing optional parameter - }); - - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', purpose: 'ocr' }, - }); - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [{ markdown: 'Default API result' }], - }, - }); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Use environment variable syntax to ensure loadAuthValues is called - apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name - baseURL: '${OCR_BASEURL}', // Using valid env var format - mistralModel: 'mistral-ocr-latest', // Plain string value 
- }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Should use the default values - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'INVALID_FORMAT'], - optional: expect.any(Set), - }); - - // Should use the default model when not using environment variable format - expect(mockAxios.post).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - model: 'mistral-ocr-latest', - }), - expect.anything(), - ); - }); - - it('should handle API errors during OCR process', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - }); - - // Mock file upload to fail - mockAxios.post.mockRejectedValueOnce(new Error('Upload failed')); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: 'OCR_API_KEY', - baseURL: 'OCR_BASEURL', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - await expect( - uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }), - ).rejects.toThrow('Error uploading document to Mistral OCR API'); - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - }); - - it('should handle single page documents without page numbering', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included - }); - - // Clear all previous mocks - mockAxios.post.mockClear(); - mockAxios.get.mockClear(); - - // 1. 
First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Single page content' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: 'OCR_API_KEY', - baseURL: 'OCR_BASEURL', - mistralModel: 'mistral-ocr-latest', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'single-page.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify that single page documents don't include page numbering - expect(result.text).not.toContain('# PAGE'); - expect(result.text).toEqual('Single page content\n\n'); - }); - - it('should use literal values in configuration when provided directly', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - // We'll still mock this but it should not be used for literal values - loadAuthValues.mockResolvedValue({}); - - // Clear all previous mocks - mockAxios.post.mockClear(); - mockAxios.get.mockClear(); - - // 1. First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. 
Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Processed with literal config values' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Direct values that should be used as-is, without variable substitution - apiKey: 'actual-api-key-value', - baseURL: 'https://direct-api-url.mistral.ai/v1', - mistralModel: 'mistral-direct-model', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'direct-values.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify the correct URL was used with the direct baseURL value - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://direct-api-url.mistral.ai/v1/files', - expect.any(Object), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer actual-api-key-value', - }), - }), - ); - - // Check the OCR call was made with the direct model value - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://direct-api-url.mistral.ai/v1/ocr', - expect.objectContaining({ - model: 'mistral-direct-model', - }), - expect.any(Object), - ); - - // Verify the result - expect(result.text).toEqual('Processed with literal config values\n\n'); - - // Verify loadAuthValues was never called since we used direct values - expect(loadAuthValues).not.toHaveBeenCalled(); - }); - - it('should handle empty configuration values and use defaults', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - // Set up the mock values to be returned by loadAuthValues - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'default-from-env-key', - OCR_BASEURL: 'https://default-from-env.mistral.ai/v1', - }); - - // Clear all previous mocks - mockAxios.post.mockClear(); - 
mockAxios.get.mockClear(); - - // 1. First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Content from default configuration' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Empty string values - should fall back to defaults - apiKey: '', - baseURL: '', - mistralModel: '', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'empty-config.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify loadAuthValues was called with the default variable names - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Verify the API calls used the default values from loadAuthValues - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://default-from-env.mistral.ai/v1/files', - expect.any(Object), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer default-from-env-key', - }), - }), - ); - - // Verify the OCR model defaulted to mistral-ocr-latest - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://default-from-env.mistral.ai/v1/ocr', - expect.objectContaining({ - model: 'mistral-ocr-latest', - }), - expect.any(Object), - ); - - // Check result - expect(result.text).toEqual('Content from default configuration\n\n'); - }); - }); -}); diff --git a/api/server/services/Files/MistralOCR/index.js 
b/api/server/services/Files/MistralOCR/index.js deleted file mode 100644 index a6223d1ee5..0000000000 --- a/api/server/services/Files/MistralOCR/index.js +++ /dev/null @@ -1,5 +0,0 @@ -const crud = require('./crud'); - -module.exports = { - ...crud, -}; diff --git a/api/server/services/Files/S3/images.js b/api/server/services/Files/S3/images.js index 378212cb5e..688d5eb68b 100644 --- a/api/server/services/Files/S3/images.js +++ b/api/server/services/Files/S3/images.js @@ -1,11 +1,10 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); +const { logger } = require('@librechat/data-schemas'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = require('~/models/userMethods'); +const { updateUser, updateFile } = require('~/models'); const { saveBufferToS3 } = require('./crud'); -const { updateFile } = require('~/models/File'); -const { logger } = require('~/config'); const defaultBasePath = 'images'; @@ -95,15 +94,28 @@ async function prepareImageURLS3(req, file) { * @param {Buffer} params.buffer - Avatar image buffer. * @param {string} params.userId - User's unique identifier. * @param {string} params.manual - 'true' or 'false' flag for manual update. + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @param {string} [params.basePath='images'] - Base path in the bucket. * @returns {Promise} Signed URL of the uploaded avatar. */ -async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) { +async function processS3Avatar({ buffer, userId, manual, agentId, basePath = defaultBasePath }) { try { - const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath }); - if (manual === 'true') { + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 
'gif' : 'png'; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; + + const downloadURL = await saveBufferToS3({ userId, buffer, fileName, basePath }); + + // Only update user record if this is a user avatar (manual === 'true') + if (manual === 'true' && !agentId) { await updateUser(userId, { avatar: downloadURL }); } + return downloadURL; } catch (error) { logger.error('[processS3Avatar] Error processing S3 avatar:', error.message); diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index 37a1e81487..1aeabc6c46 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -1,9 +1,9 @@ const fs = require('fs'); const axios = require('axios'); const FormData = require('form-data'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { FileSources } = require('librechat-data-provider'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Deletes a file from the vector database. This function takes a file object, constructs the full path, and diff --git a/api/server/services/Files/images/avatar.js b/api/server/services/Files/images/avatar.js index 3c1068a453..8e81dea26c 100644 --- a/api/server/services/Files/images/avatar.js +++ b/api/server/services/Files/images/avatar.js @@ -44,8 +44,25 @@ async function resizeAvatar({ userId, input, desiredFormat = EImageOutputType.PN throw new Error('Invalid input type. 
Expected URL, Buffer, or File.'); } - const { width, height } = await sharp(imageBuffer).metadata(); + const metadata = await sharp(imageBuffer).metadata(); + const { width, height } = metadata; const minSize = Math.min(width, height); + + if (metadata.format === 'gif') { + const resizedBuffer = await sharp(imageBuffer, { animated: true }) + .extract({ + left: Math.floor((width - minSize) / 2), + top: Math.floor((height - minSize) / 2), + width: minSize, + height: minSize, + }) + .resize(250, 250) + .gif() + .toBuffer(); + + return resizedBuffer; + } + const squaredBuffer = await sharp(imageBuffer) .extract({ left: Math.floor((width - minSize) / 2), diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 154941fd89..e87654b378 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -1,4 +1,5 @@ const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); const { FileSources, VisionModes, @@ -7,8 +8,6 @@ const { EModelEndpoint, } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Converts a readable stream to a base64 encoded string. diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 94b1bc4dad..8910163047 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -522,7 +522,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { throw new Error('OCR capability is not enabled for Agents'); } - const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions( + const { handleFileUpload: uploadOCR } = getStrategyFunctions( req.app.locals?.ocr?.strategy ?? 
FileSources.mistral_ocr, ); const { file_id, temp_file_id } = metadata; @@ -534,7 +534,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { images, filename, filepath: ocrFileURL, - } = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath }); + } = await uploadOCR({ req, file, loadAuthValues }); const fileInfo = removeNullishValues({ text, diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index c6cfe77069..41dcd5518a 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -1,4 +1,5 @@ const { FileSources } = require('librechat-data-provider'); +const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api'); const { getFirebaseURL, prepareImageURL, @@ -46,7 +47,6 @@ const { const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI'); const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code'); const { uploadVectors, deleteVectors } = require('./VectorDB'); -const { uploadMistralOCR } = require('./MistralOCR'); /** * Firebase Storage Strategy Functions @@ -202,6 +202,26 @@ const mistralOCRStrategy = () => ({ handleFileUpload: uploadMistralOCR, }); +const azureMistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadAzureMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource 
=== FileSources.firebase) { @@ -222,6 +242,8 @@ const getStrategyFunctions = (fileSource) => { return codeOutputStrategy(); } else if (fileSource === FileSources.mistral_ocr) { return mistralOCRStrategy(); + } else if (fileSource === FileSources.azure_mistral_ocr) { + return azureMistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index b9baef462e..357913e519 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,6 +1,6 @@ const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); -const { normalizeServerName } = require('librechat-mcp'); +const { normalizeServerName } = require('@librechat/api'); const { Constants: AgentConstants, Providers } = require('@librechat/agents'); const { Constants, @@ -50,9 +50,10 @@ async function createMCPTool({ req, toolKey, provider: _provider }) { /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise} */ const _call = async (toolArguments, config) => { + const userId = config?.configurable?.user?.id || config?.configurable?.user_id; try { const derivedSignal = config?.signal ? 
AbortSignal.any([config.signal]) : undefined; - const mcpManager = getMCPManager(config?.configurable?.user_id); + const mcpManager = getMCPManager(userId); const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); const result = await mcpManager.callTool({ serverName, @@ -60,8 +61,8 @@ async function createMCPTool({ req, toolKey, provider: _provider }) { provider, toolArguments, options: { - userId: config?.configurable?.user_id, signal: derivedSignal, + user: config?.configurable?.user, }, }); @@ -74,7 +75,7 @@ async function createMCPTool({ req, toolKey, provider: _provider }) { return result; } catch (error) { logger.error( - `[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`, + `[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`, error, ); throw new Error( diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index a1ccd7643b..0db13ec318 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -1,12 +1,13 @@ const axios = require('axios'); const { Providers } = require('@librechat/agents'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider'); -const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils'); +const { inputSchema, extractBaseURL, processModelData } = require('~/utils'); const { OllamaClient } = require('~/app/clients/OllamaClient'); const { isUserProvided } = require('~/server/utils'); const getLogStores = require('~/cache/getLogStores'); -const { logger } = require('~/config'); /** * Splits a string by commas and trims each resulting value. 
diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index fb4481f840..33ab9a7aaf 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -1,6 +1,6 @@ const axios = require('axios'); +const { logger } = require('@librechat/data-schemas'); const { EModelEndpoint, defaultModels } = require('librechat-data-provider'); -const { logger } = require('~/config'); const { fetchModels, @@ -28,7 +28,8 @@ jest.mock('~/cache/getLogStores', () => set: jest.fn().mockResolvedValue(true), })), ); -jest.mock('~/config', () => ({ +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), logger: { error: jest.fn(), }, diff --git a/api/server/services/PluginService.js b/api/server/services/PluginService.js index 03e90bce41..7463e0814e 100644 --- a/api/server/services/PluginService.js +++ b/api/server/services/PluginService.js @@ -1,5 +1,5 @@ -const PluginAuth = require('~/models/schema/pluginAuthSchema'); -const { encrypt, decrypt } = require('~/server/utils/'); +const { encrypt, decrypt } = require('~/server/utils/crypto'); +const { PluginAuth } = require('~/db/models'); const { logger } = require('~/config'); /** diff --git a/api/server/services/Runs/methods.js b/api/server/services/Runs/methods.js index 3c18e9969b..167b9cc2ba 100644 --- a/api/server/services/Runs/methods.js +++ b/api/server/services/Runs/methods.js @@ -1,6 +1,6 @@ const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); -const { logAxiosError } = require('~/utils'); /** * @typedef {Object} RetrieveOptions diff --git a/api/server/services/TokenService.js b/api/server/services/TokenService.js index 3dd2e79ffa..ec74844197 100644 --- a/api/server/services/TokenService.js +++ b/api/server/services/TokenService.js @@ -1,8 +1,9 @@ const axios = require('axios'); +const { logAxiosError } = 
require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const { TokenExchangeMethodEnum } = require('librechat-data-provider'); const { handleOAuthToken } = require('~/models/Token'); const { decryptV2 } = require('~/server/utils/crypto'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Processes the access tokens and stores them in the database. @@ -49,6 +50,7 @@ async function processAccessTokens(tokenData, { userId, identifier }) { * @param {string} fields.client_url - The URL of the OAuth provider. * @param {string} fields.identifier - The identifier for the token. * @param {string} fields.refresh_token - The refresh token to use. + * @param {string} fields.token_exchange_method - The token exchange method ('default_post' or 'basic_auth_header'). * @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider. * @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider. 
* @returns {Promise<{ @@ -63,26 +65,36 @@ const refreshAccessToken = async ({ client_url, identifier, refresh_token, + token_exchange_method, encrypted_oauth_client_id, encrypted_oauth_client_secret, }) => { try { const oauth_client_id = await decryptV2(encrypted_oauth_client_id); const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret); + + const headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }; + const params = new URLSearchParams({ - client_id: oauth_client_id, - client_secret: oauth_client_secret, grant_type: 'refresh_token', refresh_token, }); + if (token_exchange_method === TokenExchangeMethodEnum.BasicAuthHeader) { + const basicAuth = Buffer.from(`${oauth_client_id}:${oauth_client_secret}`).toString('base64'); + headers['Authorization'] = `Basic ${basicAuth}`; + } else { + params.append('client_id', oauth_client_id); + params.append('client_secret', oauth_client_secret); + } + const response = await axios({ method: 'POST', url: client_url, - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, + headers, data: params.toString(), }); await processAccessTokens(response.data, { @@ -110,6 +122,7 @@ const refreshAccessToken = async ({ * @param {string} fields.identifier - The identifier for the token. * @param {string} fields.client_url - The URL of the OAuth provider. * @param {string} fields.redirect_uri - The redirect URI for the OAuth provider. + * @param {string} fields.token_exchange_method - The token exchange method ('default_post' or 'basic_auth_header'). * @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider. * @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider. 
* @returns {Promise<{ @@ -125,27 +138,37 @@ const getAccessToken = async ({ identifier, client_url, redirect_uri, + token_exchange_method, encrypted_oauth_client_id, encrypted_oauth_client_secret, }) => { const oauth_client_id = await decryptV2(encrypted_oauth_client_id); const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret); + + const headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + Accept: 'application/json', + }; + const params = new URLSearchParams({ code, - client_id: oauth_client_id, - client_secret: oauth_client_secret, grant_type: 'authorization_code', redirect_uri, }); + if (token_exchange_method === TokenExchangeMethodEnum.BasicAuthHeader) { + const basicAuth = Buffer.from(`${oauth_client_id}:${oauth_client_secret}`).toString('base64'); + headers['Authorization'] = `Basic ${basicAuth}`; + } else { + params.append('client_id', oauth_client_id); + params.append('client_secret', oauth_client_secret); + } + try { const response = await axios({ method: 'POST', url: client_url, - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, + headers, data: params.toString(), }); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 8dd2fbf865..9172c25e96 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -500,6 +500,8 @@ async function processRequiredActions(client, requiredActions) { async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) { if (!agent.tools || agent.tools.length === 0) { return {}; + } else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) { + return {}; } const endpointsConfig = await getEndpointsConfig(req); diff --git a/api/server/services/UserService.js b/api/server/services/UserService.js index 91d772477b..b729607f69 100644 --- a/api/server/services/UserService.js +++ b/api/server/services/UserService.js @@ -1,7 +1,8 @@ 
+const { logger } = require('@librechat/data-schemas'); const { ErrorTypes } = require('librechat-data-provider'); -const { encrypt, decrypt } = require('~/server/utils'); -const { updateUser, Key } = require('~/models'); -const { logger } = require('~/config'); +const { encrypt, decrypt } = require('~/server/utils/crypto'); +const { updateUser } = require('~/models'); +const { Key } = require('~/db/models'); /** * Updates the plugins for a user based on the action specified (install/uninstall). diff --git a/api/server/services/signPayload.js b/api/server/services/signPayload.js deleted file mode 100644 index a7bb0c64fc..0000000000 --- a/api/server/services/signPayload.js +++ /dev/null @@ -1,26 +0,0 @@ -const jwt = require('jsonwebtoken'); - -/** - * Signs a given payload using either the `jose` library (for Bun runtime) or `jsonwebtoken`. - * - * @async - * @function - * @param {Object} options - The options for signing the payload. - * @param {Object} options.payload - The payload to be signed. - * @param {string} options.secret - The secret key used for signing. - * @param {number} options.expirationTime - The expiration time in seconds. - * @returns {Promise} Returns a promise that resolves to the signed JWT. - * @throws {Error} Throws an error if there's an issue during signing. 
- * - * @example - * const signedPayload = await signPayload({ - * payload: { userId: 123 }, - * secret: 'my-secret-key', - * expirationTime: 3600 - * }); - */ -async function signPayload({ payload, secret, expirationTime }) { - return jwt.sign(payload, secret, { expiresIn: expirationTime }); -} - -module.exports = signPayload; diff --git a/api/server/services/start/agents.js b/api/server/services/start/agents.js deleted file mode 100644 index 10653f3fb6..0000000000 --- a/api/server/services/start/agents.js +++ /dev/null @@ -1,14 +0,0 @@ -const { EModelEndpoint, agentsEndpointSChema } = require('librechat-data-provider'); - -/** - * Sets up the Agents configuration from the config (`librechat.yaml`) file. - * @param {TCustomConfig} config - The loaded custom configuration. - * @returns {Partial} The Agents endpoint configuration. - */ -function agentsConfigSetup(config) { - const agentsConfig = config.endpoints[EModelEndpoint.agents]; - const parsedConfig = agentsEndpointSChema.parse(agentsConfig); - return parsedConfig; -} - -module.exports = { agentsConfigSetup }; diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index 7578c036b2..c98fdb60bc 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -2,6 +2,7 @@ const { SystemRoles, Permissions, PermissionTypes, + isMemoryEnabled, removeNullishValues, } = require('librechat-data-provider'); const { updateAccessPermissions } = require('~/models/Role'); @@ -20,6 +21,14 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol const hasModelSpecs = config?.modelSpecs?.list?.length > 0; const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0; + const memoryConfig = config?.memory; + const memoryEnabled = isMemoryEnabled(memoryConfig); + /** Only disable memories if memory config is present but disabled/invalid */ + const shouldDisableMemories = memoryConfig && !memoryEnabled; + /** Check 
if personalization is enabled (defaults to true if memory is configured and enabled) */ + const isPersonalizationEnabled = + memoryConfig && memoryEnabled && memoryConfig.personalize !== false; + /** @type {TCustomConfig['interface']} */ const loadedInterface = removeNullishValues({ endpointsMenu: @@ -33,6 +42,7 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol privacyPolicy: interfaceConfig?.privacyPolicy ?? defaults.privacyPolicy, termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService, bookmarks: interfaceConfig?.bookmarks ?? defaults.bookmarks, + memories: shouldDisableMemories ? false : (interfaceConfig?.memories ?? defaults.memories), prompts: interfaceConfig?.prompts ?? defaults.prompts, multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo, agents: interfaceConfig?.agents ?? defaults.agents, @@ -45,6 +55,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol await updateAccessPermissions(roleName, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + [PermissionTypes.MEMORIES]: { + [Permissions.USE]: loadedInterface.memories, + [Permissions.OPT_OUT]: isPersonalizationEnabled, + }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, @@ -54,6 +68,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol await updateAccessPermissions(SystemRoles.ADMIN, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + [PermissionTypes.MEMORIES]: { + [Permissions.USE]: loadedInterface.memories, + [Permissions.OPT_OUT]: isPersonalizationEnabled, + }, 
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, diff --git a/api/server/services/start/interface.spec.js b/api/server/services/start/interface.spec.js index d0dcfaf55f..1a05c9cf12 100644 --- a/api/server/services/start/interface.spec.js +++ b/api/server/services/start/interface.spec.js @@ -12,6 +12,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: true, multiConvo: true, agents: true, temporaryChat: true, @@ -26,6 +27,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -39,6 +41,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: false, bookmarks: false, + memories: false, multiConvo: false, agents: false, temporaryChat: false, @@ -53,6 +56,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: false }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false }, @@ -70,6 +74,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined 
}, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -83,6 +88,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: undefined, bookmarks: undefined, + memories: undefined, multiConvo: undefined, agents: undefined, temporaryChat: undefined, @@ -97,6 +103,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -110,6 +117,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: undefined, agents: true, temporaryChat: undefined, @@ -124,6 +132,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -138,6 +147,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: true, multiConvo: true, agents: true, temporaryChat: true, @@ -151,6 +161,7 @@ describe('loadDefaultInterface', () => { 
expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -168,6 +179,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -185,6 +197,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -202,6 +215,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -215,6 +229,7 @@ 
describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: true, agents: false, temporaryChat: true, @@ -228,6 +243,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -242,6 +258,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: false, multiConvo: false, agents: undefined, temporaryChat: undefined, @@ -255,6 +272,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -268,6 +286,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: true, agents: false, temporaryChat: true, @@ -281,6 +300,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, 
diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js index d000c8fcfc..0274842367 100644 --- a/api/server/services/twoFactorService.js +++ b/api/server/services/twoFactorService.js @@ -1,6 +1,6 @@ const { webcrypto } = require('node:crypto'); -const { decryptV3, decryptV2 } = require('../utils/crypto'); -const { hashBackupCode } = require('~/server/utils/crypto'); +const { hashBackupCode, decryptV3, decryptV2 } = require('~/server/utils/crypto'); +const { updateUser } = require('~/models'); // Base32 alphabet for TOTP secret encoding. const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'; @@ -172,7 +172,6 @@ const verifyBackupCode = async ({ user, backupCode }) => { : codeObj, ); // Update the user record with the marked backup code. - const { updateUser } = require('~/models'); await updateUser(user._id, { backupCodes: updatedBackupCodes }); return true; } diff --git a/api/server/socialLogins.js b/api/server/socialLogins.js index ba335018dc..9b9541cdcd 100644 --- a/api/server/socialLogins.js +++ b/api/server/socialLogins.js @@ -10,6 +10,7 @@ const { discordLogin, facebookLogin, appleLogin, + setupSaml, openIdJwtLogin, } = require('~/strategies'); const { isEnabled } = require('~/server/utils'); @@ -70,6 +71,34 @@ const configureSocialLogins = async (app) => { } logger.info('OpenID Connect configured.'); } + if ( + process.env.SAML_ENTRY_POINT && + process.env.SAML_ISSUER && + process.env.SAML_CERT && + process.env.SAML_SESSION_SECRET + ) { + logger.info('Configuring SAML Connect...'); + const sessionOptions = { + secret: process.env.SAML_SESSION_SECRET, + resave: false, + saveUninitialized: false, + }; + if (isEnabled(process.env.USE_REDIS)) { + logger.debug('Using Redis for session storage in SAML...'); + const keyv = new Keyv({ store: keyvRedis }); + const client = keyv.opts.store.client; + sessionOptions.store = new RedisStore({ client, prefix: 'saml_session' }); + } else { + sessionOptions.store = new 
MemoryStore({ + checkPeriod: 86400000, // prune expired entries every 24h + }); + } + app.use(session(sessionOptions)); + app.use(passport.session()); + setupSaml(); + + logger.info('SAML Connect configured.'); + } }; module.exports = configureSocialLogins; diff --git a/api/server/utils/crypto.js b/api/server/utils/crypto.js index 333cd7573a..2f176fedee 100644 --- a/api/server/utils/crypto.js +++ b/api/server/utils/crypto.js @@ -106,12 +106,6 @@ function decryptV3(encryptedValue) { return decrypted.toString('utf8'); } -async function hashToken(str) { - const data = new TextEncoder().encode(str); - const hashBuffer = await webcrypto.subtle.digest('SHA-256', data); - return Buffer.from(hashBuffer).toString('hex'); -} - async function getRandomValues(length) { if (!Number.isInteger(length) || length <= 0) { throw new Error('Length must be a positive integer'); @@ -141,7 +135,6 @@ module.exports = { decryptV2, encryptV3, decryptV3, - hashToken, hashBackupCode, getRandomValues, }; diff --git a/api/server/utils/emails/passwordReset.handlebars b/api/server/utils/emails/passwordReset.handlebars index 9076b92edb..6b735f53fd 100644 --- a/api/server/utils/emails/passwordReset.handlebars +++ b/api/server/utils/emails/passwordReset.handlebars @@ -22,17 +22,71 @@ diff --git a/api/server/utils/emails/requestPasswordReset.handlebars b/api/server/utils/emails/requestPasswordReset.handlebars index 2600b5a9d3..b7005254ba 100644 --- a/api/server/utils/emails/requestPasswordReset.handlebars +++ b/api/server/utils/emails/requestPasswordReset.handlebars @@ -22,18 +22,78 @@ diff --git a/api/server/utils/emails/verifyEmail.handlebars b/api/server/utils/emails/verifyEmail.handlebars index 63b52e79be..fa77575053 100644 --- a/api/server/utils/emails/verifyEmail.handlebars +++ b/api/server/utils/emails/verifyEmail.handlebars @@ -22,18 +22,75 @@ diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 86c17f1dda..680da5da44 100644 --- a/api/server/utils/handleText.js 
+++ b/api/server/utils/handleText.js @@ -1,5 +1,3 @@ -const path = require('path'); -const crypto = require('crypto'); const { Capabilities, EModelEndpoint, @@ -218,38 +216,6 @@ function normalizeEndpointName(name = '') { return name.toLowerCase() === Providers.OLLAMA ? Providers.OLLAMA : name; } -/** - * Sanitize a filename by removing any directory components, replacing non-alphanumeric characters - * @param {string} inputName - * @returns {string} - */ -function sanitizeFilename(inputName) { - // Remove any directory components - let name = path.basename(inputName); - - // Replace any non-alphanumeric characters except for '.' and '-' - name = name.replace(/[^a-zA-Z0-9.-]/g, '_'); - - // Ensure the name doesn't start with a dot (hidden file in Unix-like systems) - if (name.startsWith('.') || name === '') { - name = '_' + name; - } - - // Limit the length of the filename - const MAX_LENGTH = 255; - if (name.length > MAX_LENGTH) { - const ext = path.extname(name); - const nameWithoutExt = path.basename(name, ext); - name = - nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) + - '-' + - crypto.randomBytes(3).toString('hex') + - ext; - } - - return name; -} - module.exports = { isEnabled, handleText, @@ -260,6 +226,5 @@ module.exports = { generateConfig, addSpaceIfNeeded, createOnProgress, - sanitizeFilename, normalizeEndpointName, }; diff --git a/api/server/utils/handleText.spec.js b/api/server/utils/handleText.spec.js deleted file mode 100644 index 8b1b6eef8d..0000000000 --- a/api/server/utils/handleText.spec.js +++ /dev/null @@ -1,99 +0,0 @@ -const { isEnabled, sanitizeFilename } = require('./handleText'); - -describe('isEnabled', () => { - test('should return true when input is "true"', () => { - expect(isEnabled('true')).toBe(true); - }); - - test('should return true when input is "TRUE"', () => { - expect(isEnabled('TRUE')).toBe(true); - }); - - test('should return true when input is true', () => { - expect(isEnabled(true)).toBe(true); - }); - - 
test('should return false when input is "false"', () => { - expect(isEnabled('false')).toBe(false); - }); - - test('should return false when input is false', () => { - expect(isEnabled(false)).toBe(false); - }); - - test('should return false when input is null', () => { - expect(isEnabled(null)).toBe(false); - }); - - test('should return false when input is undefined', () => { - expect(isEnabled()).toBe(false); - }); - - test('should return false when input is an empty string', () => { - expect(isEnabled('')).toBe(false); - }); - - test('should return false when input is a whitespace string', () => { - expect(isEnabled(' ')).toBe(false); - }); - - test('should return false when input is a number', () => { - expect(isEnabled(123)).toBe(false); - }); - - test('should return false when input is an object', () => { - expect(isEnabled({})).toBe(false); - }); - - test('should return false when input is an array', () => { - expect(isEnabled([])).toBe(false); - }); -}); - -jest.mock('crypto', () => ({ - randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')), -})); - -describe('sanitizeFilename', () => { - test('removes directory components (1/2)', () => { - expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt'); - }); - - test('removes directory components (2/2)', () => { - expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt'); - }); - - test('replaces non-alphanumeric characters', () => { - expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt'); - }); - - test('preserves dots and hyphens', () => { - expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt'); - }); - - test('prepends underscore to filenames starting with a dot', () => { - expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile'); - }); - - test('truncates long filenames', () => { - const longName = 'a'.repeat(300) + '.txt'; - const result = sanitizeFilename(longName); - expect(result.length).toBe(255); - 
expect(result).toMatch(/^a+-abc123\.txt$/); - }); - - test('handles filenames with no extension', () => { - const longName = 'a'.repeat(300); - const result = sanitizeFilename(longName); - expect(result.length).toBe(255); - expect(result).toMatch(/^a+-abc123$/); - }); - - test('handles empty input', () => { - expect(sanitizeFilename('')).toBe('_'); - }); - - test('handles input with only special characters', () => { - expect(sanitizeFilename('@#$%^&*')).toBe('_______'); - }); -}); diff --git a/api/server/utils/index.js b/api/server/utils/index.js index b79b42f00d..aa432ec379 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -13,12 +13,19 @@ const math = require('./math'); * @returns {Boolean} */ function checkEmailConfig() { - return ( + // Check if Mailgun is configured + const hasMailgunConfig = + !!process.env.MAILGUN_API_KEY && !!process.env.MAILGUN_DOMAIN && !!process.env.EMAIL_FROM; + + // Check if SMTP is configured + const hasSMTPConfig = (!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) && !!process.env.EMAIL_USERNAME && !!process.env.EMAIL_PASSWORD && - !!process.env.EMAIL_FROM - ); + !!process.env.EMAIL_FROM; + + // Return true if either Mailgun or SMTP is properly configured + return hasMailgunConfig || hasSMTPConfig; } module.exports = { diff --git a/api/server/utils/sendEmail.js b/api/server/utils/sendEmail.js index 59d75830f4..c0afd0eebe 100644 --- a/api/server/utils/sendEmail.js +++ b/api/server/utils/sendEmail.js @@ -1,9 +1,69 @@ const fs = require('fs'); const path = require('path'); +const axios = require('axios'); +const FormData = require('form-data'); const nodemailer = require('nodemailer'); const handlebars = require('handlebars'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { isEnabled } = require('~/server/utils/handleText'); -const logger = require('~/config/winston'); + +/** + * Sends an email using Mailgun API. 
+ * + * @async + * @function sendEmailViaMailgun + * @param {Object} params - The parameters for sending the email. + * @param {string} params.to - The recipient's email address. + * @param {string} params.from - The sender's email address. + * @param {string} params.subject - The subject of the email. + * @param {string} params.html - The HTML content of the email. + * @returns {Promise<Object>} - A promise that resolves to the response data from the Mailgun API. + */ +const sendEmailViaMailgun = async ({ to, from, subject, html }) => { + const mailgunApiKey = process.env.MAILGUN_API_KEY; + const mailgunDomain = process.env.MAILGUN_DOMAIN; + const mailgunHost = process.env.MAILGUN_HOST || 'https://api.mailgun.net'; + + if (!mailgunApiKey || !mailgunDomain) { + throw new Error('Mailgun API key and domain are required'); + } + + const formData = new FormData(); + formData.append('from', from); + formData.append('to', to); + formData.append('subject', subject); + formData.append('html', html); + formData.append('o:tracking-clicks', 'no'); + + try { + const response = await axios.post(`${mailgunHost}/v3/${mailgunDomain}/messages`, formData, { + headers: { + ...formData.getHeaders(), + Authorization: `Basic ${Buffer.from(`api:${mailgunApiKey}`).toString('base64')}`, + }, + }); + + return response.data; + } catch (error) { + throw new Error(logAxiosError({ error, message: 'Failed to send email via Mailgun' })); + } +}; + +/** + * Sends an email using SMTP via Nodemailer. + * + * @async + * @function sendEmailViaSMTP + * @param {Object} params - The parameters for sending the email. + * @param {Object} params.transporterOptions - The transporter configuration options. + * @param {Object} params.mailOptions - The email options. + * @returns {Promise<Object>} - A promise that resolves to the info object of the sent email. 
+ */ +const sendEmailViaSMTP = async ({ transporterOptions, mailOptions }) => { + const transporter = nodemailer.createTransport(transporterOptions); + return await transporter.sendMail(mailOptions); +}; /** * Sends an email using the specified template, subject, and payload. @@ -34,6 +94,30 @@ const logger = require('~/config/winston'); */ const sendEmail = async ({ email, subject, payload, template, throwError = true }) => { try { + // Read and compile the email template + const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8'); + const compiledTemplate = handlebars.compile(source); + const html = compiledTemplate(payload); + + // Prepare common email data + const fromName = process.env.EMAIL_FROM_NAME || process.env.APP_TITLE; + const fromEmail = process.env.EMAIL_FROM; + const fromAddress = `"${fromName}" <${fromEmail}>`; + const toAddress = `"${payload.name}" <${email}>`; + + // Check if Mailgun is configured + if (process.env.MAILGUN_API_KEY && process.env.MAILGUN_DOMAIN) { + logger.debug('[sendEmail] Using Mailgun provider'); + return await sendEmailViaMailgun({ + from: fromAddress, + to: toAddress, + subject: subject, + html: html, + }); + } + + // Default to SMTP + logger.debug('[sendEmail] Using SMTP provider'); const transporterOptions = { // Use STARTTLS by default instead of obligatory TLS secure: process.env.EMAIL_ENCRYPTION === 'tls', @@ -62,30 +146,21 @@ const sendEmail = async ({ email, subject, payload, template, throwError = true transporterOptions.port = process.env.EMAIL_PORT ?? 
25; } - const transporter = nodemailer.createTransport(transporterOptions); - - const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8'); - const compiledTemplate = handlebars.compile(source); - const options = () => { - return { - // Header address should contain name-addr - from: - `"${process.env.EMAIL_FROM_NAME || process.env.APP_TITLE}"` + - `<${process.env.EMAIL_FROM}>`, - to: `"${payload.name}" <${email}>`, - envelope: { - // Envelope from should contain addr-spec - // Mistake in the Nodemailer documentation? - from: process.env.EMAIL_FROM, - to: email, - }, - subject: subject, - html: compiledTemplate(payload), - }; + const mailOptions = { + // Header address should contain name-addr + from: fromAddress, + to: toAddress, + envelope: { + // Envelope from should contain addr-spec + // Mistake in the Nodemailer documentation? + from: fromEmail, + to: email, + }, + subject: subject, + html: html, }; - // Send email - return await transporter.sendMail(options()); + return await sendEmailViaSMTP({ transporterOptions, mailOptions }); } catch (error) { if (throwError) { throw error; diff --git a/api/server/utils/staticCache.js b/api/server/utils/staticCache.js index 5925a56be5..e885273223 100644 --- a/api/server/utils/staticCache.js +++ b/api/server/utils/staticCache.js @@ -1,3 +1,4 @@ +const path = require('path'); const expressStaticGzip = require('express-static-gzip'); const oneDayInSeconds = 24 * 60 * 60; @@ -5,16 +6,45 @@ const oneDayInSeconds = 24 * 60 * 60; const sMaxAge = process.env.STATIC_CACHE_S_MAX_AGE || oneDayInSeconds; const maxAge = process.env.STATIC_CACHE_MAX_AGE || oneDayInSeconds * 2; -const staticCache = (staticPath) => - expressStaticGzip(staticPath, { - enableBrotli: false, // disable Brotli, only using gzip +/** + * Creates an Express static middleware with gzip compression and configurable caching + * + * @param {string} staticPath - The file system path to serve static files from + * @param {Object} [options={}] - 
Configuration options + * @param {boolean} [options.noCache=false] - If true, disables caching entirely for all files + * @returns {ReturnType} Express middleware function for serving static files + */ +function staticCache(staticPath, options = {}) { + const { noCache = false } = options; + return expressStaticGzip(staticPath, { + enableBrotli: false, orderPreference: ['gz'], - setHeaders: (res, _path) => { - if (process.env.NODE_ENV?.toLowerCase() === 'production') { + setHeaders: (res, filePath) => { + if (process.env.NODE_ENV?.toLowerCase() !== 'production') { + return; + } + if (noCache) { + res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate'); + return; + } + if (filePath.includes('/dist/images/')) { + return; + } + const fileName = path.basename(filePath); + + if ( + fileName === 'index.html' || + fileName.endsWith('.webmanifest') || + fileName === 'manifest.json' || + fileName === 'sw.js' + ) { + res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate'); + } else { res.setHeader('Cache-Control', `public, max-age=${maxAge}, s-maxage=${sMaxAge}`); } }, index: false, }); +} module.exports = staticCache; diff --git a/api/strategies/appleStrategy.js b/api/strategies/appleStrategy.js index a45f10fc62..4dbac2e364 100644 --- a/api/strategies/appleStrategy.js +++ b/api/strategies/appleStrategy.js @@ -18,17 +18,13 @@ const getProfileDetails = ({ idToken, profile }) => { const decoded = jwt.decode(idToken); - logger.debug( - `Decoded Apple JWT: ${JSON.stringify(decoded, null, 2)}`, - ); + logger.debug(`Decoded Apple JWT: ${JSON.stringify(decoded, null, 2)}`); return { email: decoded.email, id: decoded.sub, avatarUrl: null, // Apple does not provide an avatar URL - username: decoded.email - ? decoded.email.split('@')[0].toLowerCase() - : `user_${decoded.sub}`, + username: decoded.email ? decoded.email.split('@')[0].toLowerCase() : `user_${decoded.sub}`, name: decoded.name ? 
`${decoded.name.firstName} ${decoded.name.lastName}` : profile.displayName || null, diff --git a/api/strategies/appleStrategy.test.js b/api/strategies/appleStrategy.test.js index c457e15fdc..65a961bd4d 100644 --- a/api/strategies/appleStrategy.test.js +++ b/api/strategies/appleStrategy.test.js @@ -1,22 +1,25 @@ -const mongoose = require('mongoose'); -const { MongoMemoryServer } = require('mongodb-memory-server'); const jwt = require('jsonwebtoken'); +const mongoose = require('mongoose'); +const { logger } = require('@librechat/data-schemas'); const { Strategy: AppleStrategy } = require('passport-apple'); -const socialLogin = require('./socialLogin'); -const User = require('~/models/User'); -const { logger } = require('~/config'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const { createSocialUser, handleExistingUser } = require('./process'); const { isEnabled } = require('~/server/utils'); +const socialLogin = require('./socialLogin'); const { findUser } = require('~/models'); +const { User } = require('~/db/models'); -// Mocking external dependencies jest.mock('jsonwebtoken'); -jest.mock('~/config', () => ({ - logger: { - error: jest.fn(), - debug: jest.fn(), - }, -})); +jest.mock('@librechat/data-schemas', () => { + const actualModule = jest.requireActual('@librechat/data-schemas'); + return { + ...actualModule, + logger: { + error: jest.fn(), + debug: jest.fn(), + }, + }; +}); jest.mock('./process', () => ({ createSocialUser: jest.fn(), handleExistingUser: jest.fn(), @@ -64,7 +67,6 @@ describe('Apple Login Strategy', () => { // Define getProfileDetails within the test scope getProfileDetails = ({ idToken, profile }) => { - console.log('getProfileDetails called with idToken:', idToken); if (!idToken) { logger.error('idToken is missing'); throw new Error('idToken is missing'); @@ -84,9 +86,7 @@ describe('Apple Login Strategy', () => { email: decoded.email, id: decoded.sub, avatarUrl: null, // Apple does not provide an avatar URL - username: 
decoded.email - ? decoded.email.split('@')[0].toLowerCase() - : `user_${decoded.sub}`, + username: decoded.email ? decoded.email.split('@')[0].toLowerCase() : `user_${decoded.sub}`, name: decoded.name ? `${decoded.name.firstName} ${decoded.name.lastName}` : profile.displayName || null, @@ -96,8 +96,12 @@ describe('Apple Login Strategy', () => { // Mock isEnabled based on environment variable isEnabled.mockImplementation((flag) => { - if (flag === 'true') { return true; } - if (flag === 'false') { return false; } + if (flag === 'true') { + return true; + } + if (flag === 'false') { + return false; + } return false; }); @@ -154,9 +158,7 @@ describe('Apple Login Strategy', () => { }); expect(jwt.decode).toHaveBeenCalledWith('fake_id_token'); - expect(logger.debug).toHaveBeenCalledWith( - expect.stringContaining('Decoded Apple JWT'), - ); + expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Decoded Apple JWT')); expect(profileDetails).toEqual({ email: 'john.doe@example.com', id: 'apple-sub-1234', @@ -209,7 +211,7 @@ describe('Apple Login Strategy', () => { beforeEach(() => { jwt.decode.mockReturnValue(decodedToken); - findUser.mockImplementation(({ email }) => User.findOne({ email })); + findUser.mockResolvedValue(null); }); it('should create a new user if one does not exist and registration is allowed', async () => { @@ -248,7 +250,7 @@ describe('Apple Login Strategy', () => { }); it('should handle existing user and update avatarUrl', async () => { - // Create an existing user + // Create an existing user without saving to database const existingUser = new User({ email: 'jane.doe@example.com', username: 'jane.doe', @@ -257,15 +259,15 @@ describe('Apple Login Strategy', () => { providerId: 'apple-sub-9012', avatarUrl: 'old_avatar.png', }); - await existingUser.save(); // Mock findUser to return the existing user findUser.mockResolvedValue(existingUser); - // Mock handleExistingUser to update avatarUrl + // Mock handleExistingUser to update avatarUrl 
without saving to database handleExistingUser.mockImplementation(async (user, avatarUrl) => { user.avatarUrl = avatarUrl; - await user.save(); + // Don't call save() to avoid database operations + return user; }); const mockVerifyCallback = jest.fn(); @@ -297,7 +299,7 @@ describe('Apple Login Strategy', () => { appleStrategyInstance._verify( fakeAccessToken, fakeRefreshToken, - null, // idToken is missing + null, // idToken is missing mockProfile, (err, user) => { mockVerifyCallback(err, user); diff --git a/api/strategies/index.js b/api/strategies/index.js index dbb1bd8703..725e04224a 100644 --- a/api/strategies/index.js +++ b/api/strategies/index.js @@ -7,6 +7,7 @@ const facebookLogin = require('./facebookStrategy'); const { setupOpenId, getOpenIdConfig } = require('./openidStrategy'); const jwtLogin = require('./jwtStrategy'); const ldapLogin = require('./ldapStrategy'); +const { setupSaml } = require('./samlStrategy'); const openIdJwtLogin = require('./openIdJwtStrategy'); module.exports = { @@ -20,5 +21,6 @@ module.exports = { setupOpenId, getOpenIdConfig, ldapLogin, + setupSaml, openIdJwtLogin, }; diff --git a/api/strategies/jwtStrategy.js b/api/strategies/jwtStrategy.js index eb4b34fd85..6793873ee8 100644 --- a/api/strategies/jwtStrategy.js +++ b/api/strategies/jwtStrategy.js @@ -1,7 +1,7 @@ +const { logger } = require('@librechat/data-schemas'); const { SystemRoles } = require('librechat-data-provider'); const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt'); const { getUserById, updateUser } = require('~/models'); -const { logger } = require('~/config'); // JWT strategy const jwtLogin = () => diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index beb9b8c2fd..434534c264 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -1,10 +1,10 @@ const fs = require('fs'); const LdapStrategy = require('passport-ldapauth'); const { SystemRoles } = require('librechat-data-provider'); -const { 
findUser, createUser, updateUser } = require('~/models/userMethods'); -const { countUsers } = require('~/models/userMethods'); +const { logger } = require('@librechat/data-schemas'); +const { createUser, findUser, updateUser, countUsers } = require('~/models'); +const { getBalanceConfig } = require('~/server/services/Config'); const { isEnabled } = require('~/server/utils'); -const logger = require('~/utils/logger'); const { LDAP_URL, @@ -124,7 +124,8 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { name: fullName, role: isFirstRegisteredUser ? SystemRoles.ADMIN : SystemRoles.USER, }; - const userId = await createUser(user); + const balanceConfig = await getBalanceConfig(); + const userId = await createUser(user, balanceConfig); user._id = userId; } else { // Users registered in LDAP are assumed to have their user information managed in LDAP, diff --git a/api/strategies/localStrategy.js b/api/strategies/localStrategy.js index bffb4f845f..bc84e7c6b5 100644 --- a/api/strategies/localStrategy.js +++ b/api/strategies/localStrategy.js @@ -1,9 +1,9 @@ +const { logger } = require('@librechat/data-schemas'); const { errorsToString } = require('librechat-data-provider'); const { Strategy: PassportLocalStrategy } = require('passport-local'); -const { findUser, comparePassword, updateUser } = require('~/models'); const { isEnabled, checkEmailConfig } = require('~/server/utils'); +const { findUser, comparePassword, updateUser } = require('~/models'); const { loginSchema } = require('./validators'); -const logger = require('~/utils/logger'); // Unix timestamp for 2024-06-07 15:20:18 Eastern Time const verificationEnabledTimestamp = 1717788018; @@ -29,6 +29,12 @@ async function passportLogin(req, email, password, done) { return done(null, false, { message: 'Email does not exist.' 
}); } + if (!user.password) { + logError('Passport Local Strategy - User has no password', { email }); + logger.error(`[Login] [Login failed] [Username: ${email}] [Request-IP: ${req.ip}]`); + return done(null, false, { message: 'Email does not exist.' }); + } + const isMatch = await comparePassword(user, password); if (!isMatch) { logError('Passport Local Strategy - Password does not match', { isMatch }); diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index ea109358d7..2449872a9d 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -1,15 +1,16 @@ -const { CacheKeys } = require('librechat-data-provider'); +const undici = require('undici'); const fetch = require('node-fetch'); const passport = require('passport'); -const jwtDecode = require('jsonwebtoken/decode'); -const { HttpsProxyAgent } = require('https-proxy-agent'); const client = require('openid-client'); +const jwtDecode = require('jsonwebtoken/decode'); +const { CacheKeys } = require('librechat-data-provider'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { hashToken, logger } = require('@librechat/data-schemas'); const { Strategy: OpenIDStrategy } = require('openid-client/passport'); +const { isEnabled, safeStringify, logHeaders } = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { findUser, createUser, updateUser } = require('~/models/userMethods'); -const { hashToken } = require('~/server/utils/crypto'); -const { isEnabled } = require('~/server/utils'); -const { logger } = require('~/config'); +const { findUser, createUser, updateUser } = require('~/models'); +const { getBalanceConfig } = require('~/server/services/Config'); const getLogStores = require('~/cache/getLogStores'); /** @@ -17,6 +18,76 @@ const getLogStores = require('~/cache/getLogStores'); * @typedef {import('openid-client').Configuration} Configuration **/ +/** + * @param {string} url + 
* @param {client.CustomFetchOptions} options + */ +async function customFetch(url, options) { + const urlStr = url.toString(); + logger.debug(`[openidStrategy] Request to: ${urlStr}`); + const debugOpenId = isEnabled(process.env.DEBUG_OPENID_REQUESTS); + if (debugOpenId) { + logger.debug(`[openidStrategy] Request method: ${options.method || 'GET'}`); + logger.debug(`[openidStrategy] Request headers: ${logHeaders(options.headers)}`); + if (options.body) { + let bodyForLogging = ''; + if (options.body instanceof URLSearchParams) { + bodyForLogging = options.body.toString(); + } else if (typeof options.body === 'string') { + bodyForLogging = options.body; + } else { + bodyForLogging = safeStringify(options.body); + } + logger.debug(`[openidStrategy] Request body: ${bodyForLogging}`); + } + } + + try { + /** @type {undici.RequestInit} */ + let fetchOptions = options; + if (process.env.PROXY) { + logger.info(`[openidStrategy] proxy agent configured: ${process.env.PROXY}`); + fetchOptions = { + ...options, + dispatcher: new HttpsProxyAgent(process.env.PROXY), + }; + } + + const response = await undici.fetch(url, fetchOptions); + + if (debugOpenId) { + logger.debug(`[openidStrategy] Response status: ${response.status} ${response.statusText}`); + logger.debug(`[openidStrategy] Response headers: ${logHeaders(response.headers)}`); + } + + if (response.status === 200 && response.headers.has('www-authenticate')) { + const wwwAuth = response.headers.get('www-authenticate'); + logger.warn(`[openidStrategy] Non-standard WWW-Authenticate header found in successful response (200 OK): ${wwwAuth}. +This violates RFC 7235 and may cause issues with strict OAuth clients. 
Removing header for compatibility.`); + + /** Cloned response without the WWW-Authenticate header */ + const responseBody = await response.arrayBuffer(); + const newHeaders = new Headers(); + for (const [key, value] of response.headers.entries()) { + if (key.toLowerCase() !== 'www-authenticate') { + newHeaders.append(key, value); + } + } + + return new Response(responseBody, { + status: response.status, + statusText: response.statusText, + headers: newHeaders, + }); + } + + return response; + } catch (error) { + logger.error(`[openidStrategy] Fetch error: ${error.message}`); + throw error; + } +} + /** @typedef {Configuration | null} */ let openidConfig = null; @@ -208,14 +279,12 @@ async function setupOpenId() { new URL(process.env.OPENID_ISSUER), process.env.OPENID_CLIENT_ID, clientMetadata, + undefined, + { + [client.customFetch]: customFetch, + }, ); - if (process.env.PROXY) { - const proxyAgent = new HttpsProxyAgent(process.env.PROXY); - openidConfig[client.customFetch] = (...args) => { - return fetch(args[0], { ...args[1], agent: proxyAgent }); - }; - logger.info(`[openidStrategy] proxy agent added: ${process.env.PROXY}`); - } + const requiredRole = process.env.OPENID_REQUIRED_ROLE; const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; @@ -297,7 +366,10 @@ async function setupOpenId() { emailVerified: userinfo.email_verified || false, name: fullName, }; - user = await createUser(user, true, true); + + const balanceConfig = await getBalanceConfig(); + + user = await createUser(user, balanceConfig, true, true); } else { user.provider = 'openid'; user.openidId = userinfo.sub; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index e70dfa5529..3e52ad01f1 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -1,7 +1,7 @@ const fetch = require('node-fetch'); const jwtDecode = 
require('jsonwebtoken/decode'); -const { findUser, createUser, updateUser } = require('~/models/userMethods'); const { setupOpenId } = require('./openidStrategy'); +const { findUser, createUser, updateUser } = require('~/models'); // --- Mocks --- jest.mock('node-fetch'); @@ -11,7 +11,12 @@ jest.mock('~/server/services/Files/strategies', () => ({ saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), })), })); -jest.mock('~/models/userMethods', () => ({ +jest.mock('~/server/services/Config', () => ({ + getBalanceConfig: jest.fn(() => ({ + enabled: false, + })), +})); +jest.mock('~/models', () => ({ findUser: jest.fn(), createUser: jest.fn(), updateUser: jest.fn(), @@ -36,11 +41,6 @@ jest.mock('~/cache/getLogStores', () => set: jest.fn(), })), ); -jest.mock('librechat-data-provider', () => ({ - CacheKeys: { - OPENID_EXCHANGED_TOKENS: 'openid-exchanged-tokens', - }, -})); // Mock the openid-client module and all its dependencies jest.mock('openid-client', () => { @@ -174,6 +174,7 @@ describe('setupOpenId', () => { email: userinfo.email, name: `${userinfo.given_name} ${userinfo.family_name}`, }), + { enabled: false }, true, true, ); @@ -193,6 +194,7 @@ describe('setupOpenId', () => { expect(user.username).toBe(expectUsername); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ username: expectUsername }), + { enabled: false }, true, true, ); @@ -212,6 +214,7 @@ describe('setupOpenId', () => { expect(user.username).toBe(expectUsername); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ username: expectUsername }), + { enabled: false }, true, true, ); @@ -229,6 +232,7 @@ describe('setupOpenId', () => { expect(user.username).toBe(userinfo.sub); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ username: userinfo.sub }), + { enabled: false }, true, true, ); diff --git a/api/strategies/process.js b/api/strategies/process.js index e9a908ffd0..1f7e7c81d2 100644 --- a/api/strategies/process.js +++ 
b/api/strategies/process.js @@ -1,7 +1,8 @@ const { FileSources } = require('librechat-data-provider'); -const { createUser, updateUser, getUserById } = require('~/models/userMethods'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { resizeAvatar } = require('~/server/services/Files/images/avatar'); +const { updateUser, createUser, getUserById } = require('~/models'); +const { getBalanceConfig } = require('~/server/services/Config'); /** * Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter @@ -30,7 +31,7 @@ const handleExistingUser = async (oldUser, avatarUrl) => { input: avatarUrl, }); const { processAvatar } = getStrategyFunctions(fileStrategy); - updatedAvatar = await processAvatar({ buffer: resizedBuffer, userId }); + updatedAvatar = await processAvatar({ buffer: resizedBuffer, userId, manual: 'false' }); } if (updatedAvatar) { @@ -78,7 +79,8 @@ const createSocialUser = async ({ emailVerified, }; - const newUserId = await createUser(update); + const balanceConfig = await getBalanceConfig(); + const newUserId = await createUser(update, balanceConfig); const fileStrategy = process.env.CDN_PROVIDER; const isLocal = fileStrategy === FileSources.local; @@ -88,7 +90,11 @@ const createSocialUser = async ({ input: avatarUrl, }); const { processAvatar } = getStrategyFunctions(fileStrategy); - const avatar = await processAvatar({ buffer: resizedBuffer, userId: newUserId }); + const avatar = await processAvatar({ + buffer: resizedBuffer, + userId: newUserId, + manual: 'false', + }); await updateUser(newUserId, { avatar }); } diff --git a/api/strategies/samlStrategy.js b/api/strategies/samlStrategy.js new file mode 100644 index 0000000000..376434f733 --- /dev/null +++ b/api/strategies/samlStrategy.js @@ -0,0 +1,277 @@ +const fs = require('fs'); +const path = require('path'); +const fetch = require('node-fetch'); +const passport = require('passport'); +const { hashToken, logger } 
= require('@librechat/data-schemas'); +const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { findUser, createUser, updateUser } = require('~/models'); +const { getBalanceConfig } = require('~/server/services/Config'); +const paths = require('~/config/paths'); + +let crypto; +try { + crypto = require('node:crypto'); +} catch (err) { + logger.error('[samlStrategy] crypto support is disabled!', err); +} + +/** + * Retrieves the certificate content from the given value. + * + * This function determines whether the provided value is a certificate string (RFC7468 format or + * base64-encoded without a header) or a valid file path. If the value matches one of these formats, + * the certificate content is returned. Otherwise, an error is thrown. + * + * @see https://github.com/node-saml/node-saml/tree/master?tab=readme-ov-file#configuration-option-idpcert + * @param {string} value - The certificate string or file path. + * @returns {string} The certificate content if valid. + * @throws {Error} If the value is not a valid certificate string or file path. 
+ */ +function getCertificateContent(value) { + if (typeof value !== 'string') { + throw new Error('Invalid input: SAML_CERT must be a string.'); + } + + // Check if it's an RFC7468 formatted PEM certificate + const pemRegex = new RegExp( + '-----BEGIN (CERTIFICATE|PUBLIC KEY)-----\n' + // header + '([A-Za-z0-9+/=]{64}\n)+' + // base64 content (64 characters per line) + '[A-Za-z0-9+/=]{1,64}\n' + // base64 content (last line) + '-----END (CERTIFICATE|PUBLIC KEY)-----', // footer + ); + if (pemRegex.test(value)) { + logger.info('[samlStrategy] Detected RFC7468-formatted certificate string.'); + return value; + } + + // Check if it's a Base64-encoded certificate (no header) + if (/^[A-Za-z0-9+/=]+$/.test(value) && value.length % 4 === 0) { + logger.info('[samlStrategy] Detected base64-encoded certificate string (no header).'); + return value; + } + + // Check if file exists and is readable + const certPath = path.normalize(path.isAbsolute(value) ? value : path.join(paths.root, value)); + if (fs.existsSync(certPath) && fs.statSync(certPath).isFile()) { + try { + logger.info(`[samlStrategy] Loading certificate from file: ${certPath}`); + return fs.readFileSync(certPath, 'utf8').trim(); + } catch (error) { + throw new Error(`Error reading certificate file: ${error.message}`); + } + } + + throw new Error('Invalid cert: SAML_CERT must be a valid file path or certificate string.'); +} + +/** + * Retrieves a SAML claim from a profile object based on environment configuration. + * @param {object} profile - Saml profile + * @param {string} envVar - Environment variable name (SAML_*) + * @param {string} defaultKey - Default key to use if the environment variable is not set + * @returns {string} + */ +function getSamlClaim(profile, envVar, defaultKey) { + const claimKey = process.env[envVar]; + + // Avoids accessing `profile[""]` when the environment variable is empty string. + if (claimKey) { + return profile[claimKey] ?? 
profile[defaultKey]; + } + return profile[defaultKey]; +} + +function getEmail(profile) { + return getSamlClaim(profile, 'SAML_EMAIL_CLAIM', 'email'); +} + +function getUserName(profile) { + return getSamlClaim(profile, 'SAML_USERNAME_CLAIM', 'username'); +} + +function getGivenName(profile) { + return getSamlClaim(profile, 'SAML_GIVEN_NAME_CLAIM', 'given_name'); +} + +function getFamilyName(profile) { + return getSamlClaim(profile, 'SAML_FAMILY_NAME_CLAIM', 'family_name'); +} + +function getPicture(profile) { + return getSamlClaim(profile, 'SAML_PICTURE_CLAIM', 'picture'); +} + +/** + * Downloads an image from a URL. + * @param {string} url + * @returns {Promise<Buffer|null>} The image buffer, or null if the download fails + */ +const downloadImage = async (url) => { + try { + const response = await fetch(url); + if (response.ok) { + return await response.buffer(); + } else { + throw new Error(`${response.statusText} (HTTP ${response.status})`); + } + } catch (error) { + logger.error(`[samlStrategy] Error downloading image at URL "${url}": ${error}`); + return null; + } +}; + +/** + * Determines the full name of a user based on the SAML profile and environment configuration. + * + * @param {Object} profile - The user profile object from the SAML response + * @returns {string} The determined full name of the user + */ +function getFullName(profile) { + if (process.env.SAML_NAME_CLAIM) { + logger.info( + `[samlStrategy] Using SAML_NAME_CLAIM: ${process.env.SAML_NAME_CLAIM}, profile: ${profile[process.env.SAML_NAME_CLAIM]}`, + ); + return profile[process.env.SAML_NAME_CLAIM]; + } + + const givenName = getGivenName(profile); + const familyName = getFamilyName(profile); + + if (givenName && familyName) { + return `${givenName} ${familyName}`; + } + + if (givenName) { + return givenName; + } + if (familyName) { + return familyName; + } + + return getUserName(profile) || getEmail(profile); +} + +/** + * Converts an input into a string suitable for a username. + * If the input is a string, it will be returned as is. 
+ * If the input is an array, elements will be joined with underscores. + * In case of undefined or other falsy values, a default value will be returned. + * + * @param {string | string[] | undefined} input - The input value to be converted into a username. + * @param {string} [defaultValue=''] - The default value to return if the input is falsy. + * @returns {string} The processed input as a string suitable for a username. + */ +function convertToUsername(input, defaultValue = '') { + if (typeof input === 'string') { + return input; + } else if (Array.isArray(input)) { + return input.join('_'); + } + + return defaultValue; +} + +async function setupSaml() { + try { + const samlConfig = { + entryPoint: process.env.SAML_ENTRY_POINT, + issuer: process.env.SAML_ISSUER, + callbackUrl: process.env.SAML_CALLBACK_URL, + idpCert: getCertificateContent(process.env.SAML_CERT), + wantAssertionsSigned: process.env.SAML_USE_AUTHN_RESPONSE_SIGNED === 'true' ? false : true, + wantAuthnResponseSigned: process.env.SAML_USE_AUTHN_RESPONSE_SIGNED === 'true' ? true : false, + }; + + passport.use( + 'saml', + new SamlStrategy(samlConfig, async (profile, done) => { + try { + logger.info(`[samlStrategy] SAML authentication received for NameID: ${profile.nameID}`); + logger.debug('[samlStrategy] SAML profile:', profile); + + let user = await findUser({ samlId: profile.nameID }); + logger.info( + `[samlStrategy] User ${user ? 'found' : 'not found'} with SAML ID: ${profile.nameID}`, + ); + + if (!user) { + const email = getEmail(profile) || ''; + user = await findUser({ email }); + logger.info( + `[samlStrategy] User ${user ? 
'found' : 'not found'} with email: ${email}`, + ); + } + + const fullName = getFullName(profile); + + const username = convertToUsername( + getUserName(profile) || getGivenName(profile) || getEmail(profile), + ); + + if (!user) { + user = { + provider: 'saml', + samlId: profile.nameID, + username, + email: getEmail(profile) || '', + emailVerified: true, + name: fullName, + }; + const balanceConfig = await getBalanceConfig(); + user = await createUser(user, balanceConfig, true, true); + } else { + user.provider = 'saml'; + user.samlId = profile.nameID; + user.username = username; + user.name = fullName; + } + + const picture = getPicture(profile); + if (picture && !user.avatar?.includes('manual=true')) { + const imageBuffer = await downloadImage(picture); + if (imageBuffer) { + let fileName; + if (crypto) { + fileName = (await hashToken(profile.nameID)) + '.png'; + } else { + fileName = profile.nameID + '.png'; + } + + const { saveBuffer } = getStrategyFunctions(process.env.CDN_PROVIDER); + const imagePath = await saveBuffer({ + fileName, + userId: user._id.toString(), + buffer: imageBuffer, + }); + user.avatar = imagePath ?? 
''; + } + } + + user = await updateUser(user._id, user); + + logger.info( + `[samlStrategy] Login success SAML ID: ${user.samlId} | email: ${user.email} | username: ${user.username}`, + { + user: { + samlId: user.samlId, + username: user.username, + email: user.email, + name: user.name, + }, + }, + ); + + done(null, user); + } catch (err) { + logger.error('[samlStrategy] Login failed', err); + done(err); + } + }), + ); + } catch (err) { + logger.error('[samlStrategy]', err); + } +} + +module.exports = { setupSaml, getCertificateContent }; diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js new file mode 100644 index 0000000000..675bdc998b --- /dev/null +++ b/api/strategies/samlStrategy.spec.js @@ -0,0 +1,418 @@ +const fs = require('fs'); +const path = require('path'); +const fetch = require('node-fetch'); +const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); +const { findUser, createUser, updateUser } = require('~/models'); +const { setupSaml, getCertificateContent } = require('./samlStrategy'); + +// --- Mocks --- +jest.mock('fs'); +jest.mock('path'); +jest.mock('node-fetch'); +jest.mock('@node-saml/passport-saml'); +jest.mock('~/models', () => ({ + findUser: jest.fn(), + createUser: jest.fn(), + updateUser: jest.fn(), +})); +jest.mock('~/server/services/Config', () => ({ + config: { + registration: { + socialLogins: ['saml'], + }, + }, + getBalanceConfig: jest.fn().mockResolvedValue({ + tokenCredits: 1000, + startingBalance: 1000, + }), +})); +jest.mock('~/server/services/Config/EndpointService', () => ({ + config: {}, +})); +jest.mock('~/server/utils', () => ({ + isEnabled: jest.fn(() => false), + isUserProvided: jest.fn(() => false), +})); +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(() => ({ + saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), + })), +})); +jest.mock('~/server/utils/crypto', () => ({ + hashToken: 
jest.fn().mockResolvedValue('hashed-token'), +})); +jest.mock('~/config', () => ({ + logger: { + info: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, +})); + +// To capture the verify callback from the strategy, we grab it from the mock constructor +let verifyCallback; +SamlStrategy.mockImplementation((options, verify) => { + verifyCallback = verify; + return { name: 'saml', options, verify }; +}); + +describe('getCertificateContent', () => { + const certWithHeader = `-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUKhXaFJGJJPx466rlwYORIsqCq7MwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNTAzMDQwODUxNTJaFw0yNjAz +MDQwODUxNTJaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCWP09NZg0xaRiLpNygCVgV3M+4RFW2S0c5X/fg/uFT +O5MfaVYzG5GxzhXzWRB8RtNPsxX/nlbPsoUroeHbz+SABkOsNEv6JuKRH4VXRH34 +VzjazVkPAwj+N4WqsC/Wo4EGGpKIGeGi8Zed4yvMqoTyE3mrS19fY0nMHT62wUwS +GMm2pAQdAQePZ9WY7A5XOA1IoxW2Zh2Oxaf1p59epBkZDhoxSMu8GoSkvK27Km4A +4UXftzdg/wHNPrNirmcYouioHdmrOtYxPjrhUBQ74AmE1/QK45B6wEgirKH1A1AW +6C+ApLwpBMvy9+8Gbyvc8G18W3CjdEVKmAeWb9JUedSXAgMBAAGjUzBRMB0GA1Ud +DgQWBBRxpaqBx8VDLLc8IkHATujj8IOs6jAfBgNVHSMEGDAWgBRxpaqBx8VDLLc8 +IkHATujj8IOs6jAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBc +Puk6i+yowwGccB3LhfxZ+Fz6s6/Lfx6bP/Hy4NYOxmx2/awGBgyfp1tmotjaS9Cf +FWd67LuEru4TYtz12RNMDBF5ypcEfibvb3I8O6igOSQX/Jl5D2pMChesZxhmCift +Qp09T41MA8PmHf1G9oMG0A3ZnjKDG5ebaJNRFImJhMHsgh/TP7V3uZy7YHTgopKX +Hv63V3Uo3Oihav29Q7urwmf7Ly7X7J2WE86/w3vRHi5dhaWWqEqxmnAXl+H+sG4V +meeVRI332bg1Nuy8KnnX8v3ZeJzMBkAhzvSr6Ri96R0/Un/oEFwVC5jDTq8sXVn6 +u7wlOSk+oFzDIO/UILIA +-----END CERTIFICATE-----`; + + const certWithoutHeader = certWithHeader + .replace(/-----BEGIN CERTIFICATE-----/g, '') + .replace(/-----END CERTIFICATE-----/g, '') + .replace(/\s+/g, ''); + + it('should throw an error if SAML_CERT is not set', () => { + delete process.env.SAML_CERT; 
+ expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow( + 'Invalid input: SAML_CERT must be a string.', + ); + }); + + it('should throw an error if SAML_CERT is empty', () => { + process.env.SAML_CERT = ''; + expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow( + 'Invalid cert: SAML_CERT must be a valid file path or certificate string.', + ); + }); + + it('should load cert from an environment variable if it is a single-line string(with header)', () => { + process.env.SAML_CERT = certWithHeader; + + const actual = getCertificateContent(process.env.SAML_CERT); + expect(actual).toBe(certWithHeader); + }); + + it('should load cert from an environment variable if it is a single-line string(with no header)', () => { + process.env.SAML_CERT = certWithoutHeader; + + const actual = getCertificateContent(process.env.SAML_CERT); + expect(actual).toBe(certWithoutHeader); + }); + + it('should throw an error if SAML_CERT is a single-line string (with header, no newline characters)', () => { + process.env.SAML_CERT = certWithHeader.replace(/\n/g, ''); + expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow( + 'Invalid cert: SAML_CERT must be a valid file path or certificate string.', + ); + }); + + it('should load cert from a relative file path if SAML_CERT is valid', () => { + process.env.SAML_CERT = 'test.pem'; + const resolvedPath = '/absolute/path/to/test.pem'; + + path.isAbsolute.mockReturnValue(false); + path.join.mockReturnValue(resolvedPath); + path.normalize.mockReturnValue(resolvedPath); + + fs.existsSync.mockReturnValue(true); + fs.statSync.mockReturnValue({ isFile: () => true }); + fs.readFileSync.mockReturnValue(certWithHeader); + + const actual = getCertificateContent(process.env.SAML_CERT); + expect(actual).toBe(certWithHeader); + }); + + it('should load cert from an absolute file path if SAML_CERT is valid', () => { + process.env.SAML_CERT = '/absolute/path/to/test.pem'; + + path.isAbsolute.mockReturnValue(true); + 
path.normalize.mockReturnValue(process.env.SAML_CERT); + + fs.existsSync.mockReturnValue(true); + fs.statSync.mockReturnValue({ isFile: () => true }); + fs.readFileSync.mockReturnValue(certWithHeader); + + const actual = getCertificateContent(process.env.SAML_CERT); + expect(actual).toBe(certWithHeader); + }); + + it('should throw an error if the file does not exist', () => { + process.env.SAML_CERT = 'missing.pem'; + const resolvedPath = '/absolute/path/to/missing.pem'; + + path.isAbsolute.mockReturnValue(false); + path.join.mockReturnValue(resolvedPath); + path.normalize.mockReturnValue(resolvedPath); + + fs.existsSync.mockReturnValue(false); + + expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow( + 'Invalid cert: SAML_CERT must be a valid file path or certificate string.', + ); + }); + + it('should throw an error if the file is not readable', () => { + process.env.SAML_CERT = 'unreadable.pem'; + const resolvedPath = '/absolute/path/to/unreadable.pem'; + + path.isAbsolute.mockReturnValue(false); + path.join.mockReturnValue(resolvedPath); + path.normalize.mockReturnValue(resolvedPath); + + fs.existsSync.mockReturnValue(true); + fs.statSync.mockReturnValue({ isFile: () => true }); + fs.readFileSync.mockImplementation(() => { + throw new Error('Permission denied'); + }); + + expect(() => getCertificateContent(process.env.SAML_CERT)).toThrow( + 'Error reading certificate file: Permission denied', + ); + }); +}); + +describe('setupSaml', () => { + // Helper to wrap the verify callback in a promise + const validate = (profile) => + new Promise((resolve, reject) => { + verifyCallback(profile, (err, user, details) => { + if (err) { + reject(err); + } else { + resolve({ user, details }); + } + }); + }); + + const baseProfile = { + nameID: 'saml-1234', + email: 'test@example.com', + given_name: 'First', + family_name: 'Last', + name: 'My Full Name', + username: 'flast', + picture: 'https://example.com/avatar.png', + custom_name: 'custom', + }; + + 
beforeEach(async () => { + jest.clearAllMocks(); + + // Configure mocks + const { findUser, createUser, updateUser } = require('~/models'); + findUser.mockResolvedValue(null); + createUser.mockImplementation(async (userData) => ({ + _id: 'mock-user-id', + ...userData, + })); + updateUser.mockImplementation(async (id, userData) => ({ + _id: id, + ...userData, + })); + + const cert = ` +-----BEGIN CERTIFICATE----- +MIIDazCCAlOgAwIBAgIUKhXaFJGJJPx466rlwYORIsqCq7MwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNTAzMDQwODUxNTJaFw0yNjAz +MDQwODUxNTJaMEUxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEB +AQUAA4IBDwAwggEKAoIBAQCWP09NZg0xaRiLpNygCVgV3M+4RFW2S0c5X/fg/uFT +O5MfaVYzG5GxzhXzWRB8RtNPsxX/nlbPsoUroeHbz+SABkOsNEv6JuKRH4VXRH34 +VzjazVkPAwj+N4WqsC/Wo4EGGpKIGeGi8Zed4yvMqoTyE3mrS19fY0nMHT62wUwS +GMm2pAQdAQePZ9WY7A5XOA1IoxW2Zh2Oxaf1p59epBkZDhoxSMu8GoSkvK27Km4A +4UXftzdg/wHNPrNirmcYouioHdmrOtYxPjrhUBQ74AmE1/QK45B6wEgirKH1A1AW +6C+ApLwpBMvy9+8Gbyvc8G18W3CjdEVKmAeWb9JUedSXAgMBAAGjUzBRMB0GA1Ud +DgQWBBRxpaqBx8VDLLc8IkHATujj8IOs6jAfBgNVHSMEGDAWgBRxpaqBx8VDLLc8 +IkHATujj8IOs6jAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBc +Puk6i+yowwGccB3LhfxZ+Fz6s6/Lfx6bP/Hy4NYOxmx2/awGBgyfp1tmotjaS9Cf +FWd67LuEru4TYtz12RNMDBF5ypcEfibvb3I8O6igOSQX/Jl5D2pMChesZxhmCift +Qp09T41MA8PmHf1G9oMG0A3ZnjKDG5ebaJNRFImJhMHsgh/TP7V3uZy7YHTgopKX +Hv63V3Uo3Oihav29Q7urwmf7Ly7X7J2WE86/w3vRHi5dhaWWqEqxmnAXl+H+sG4V +meeVRI332bg1Nuy8KnnX8v3ZeJzMBkAhzvSr6Ri96R0/Un/oEFwVC5jDTq8sXVn6 +u7wlOSk+oFzDIO/UILIA +-----END CERTIFICATE----- + `; + + // Reset environment variables + process.env.SAML_ENTRY_POINT = 'https://example.com/saml'; + process.env.SAML_ISSUER = 'saml-issuer'; + process.env.SAML_CERT = cert; + process.env.SAML_CALLBACK_URL = '/oauth/saml/callback'; + delete process.env.SAML_EMAIL_CLAIM; + delete process.env.SAML_USERNAME_CLAIM; + delete 
process.env.SAML_GIVEN_NAME_CLAIM; + delete process.env.SAML_FAMILY_NAME_CLAIM; + delete process.env.SAML_PICTURE_CLAIM; + delete process.env.SAML_NAME_CLAIM; + + // Simulate image download + const fakeBuffer = Buffer.from('fake image'); + fetch.mockResolvedValue({ + ok: true, + buffer: jest.fn().mockResolvedValue(fakeBuffer), + }); + + await setupSaml(); + }); + + it('should create a new user with correct username when username claim exists', async () => { + const profile = { ...baseProfile }; + const { user } = await validate(profile); + + expect(user.username).toBe(profile.username); + expect(user.provider).toBe('saml'); + expect(user.samlId).toBe(profile.nameID); + expect(user.email).toBe(profile.email); + expect(user.name).toBe(`${profile.given_name} ${profile.family_name}`); + }); + + it('should use given_name as username when username claim is missing', async () => { + const profile = { ...baseProfile }; + delete profile.username; + const expectUsername = profile.given_name; + + const { user } = await validate(profile); + + expect(user.username).toBe(expectUsername); + expect(user.provider).toBe('saml'); + }); + + it('should use email as username when username and given_name are missing', async () => { + const profile = { ...baseProfile }; + delete profile.username; + delete profile.given_name; + const expectUsername = profile.email; + + const { user } = await validate(profile); + + expect(user.username).toBe(expectUsername); + expect(user.provider).toBe('saml'); + }); + + it('should override username with SAML_USERNAME_CLAIM when set', async () => { + process.env.SAML_USERNAME_CLAIM = 'nameID'; + const profile = { ...baseProfile }; + + const { user } = await validate(profile); + + expect(user.username).toBe(profile.nameID); + expect(user.provider).toBe('saml'); + }); + + it('should set the full name correctly when given_name and family_name exist', async () => { + const profile = { ...baseProfile }; + const expectedFullName = `${profile.given_name} 
${profile.family_name}`; + + const { user } = await validate(profile); + + expect(user.name).toBe(expectedFullName); + }); + + it('should set the full name correctly when given_name exists', async () => { + const profile = { ...baseProfile }; + delete profile.family_name; + const expectedFullName = profile.given_name; + + const { user } = await validate(profile); + + expect(user.name).toBe(expectedFullName); + }); + + it('should set the full name correctly when family_name exists', async () => { + const profile = { ...baseProfile }; + delete profile.given_name; + const expectedFullName = profile.family_name; + + const { user } = await validate(profile); + + expect(user.name).toBe(expectedFullName); + }); + + it('should set the full name correctly when username exists', async () => { + const profile = { ...baseProfile }; + delete profile.family_name; + delete profile.given_name; + const expectedFullName = profile.username; + + const { user } = await validate(profile); + + expect(user.name).toBe(expectedFullName); + }); + + it('should set the full name correctly when only email exists', async () => { + const profile = { ...baseProfile }; + delete profile.family_name; + delete profile.given_name; + delete profile.username; + const expectedFullName = profile.email; + + const { user } = await validate(profile); + + expect(user.name).toBe(expectedFullName); + }); + + it('should set the full name correctly with SAML_NAME_CLAIM when set', async () => { + process.env.SAML_NAME_CLAIM = 'custom_name'; + const profile = { ...baseProfile }; + const expectedFullName = profile.custom_name; + + const { user } = await validate(profile); + + expect(user.name).toBe(expectedFullName); + }); + + it('should update an existing user on login', async () => { + // Set up findUser to return an existing user + const { findUser } = require('~/models'); + const existingUser = { + _id: 'existing-user-id', + provider: 'local', + email: baseProfile.email, + samlId: '', + username: 'oldusername', + 
name: 'Old Name', + }; + findUser.mockResolvedValue(existingUser); + + const profile = { ...baseProfile }; + const { user } = await validate(profile); + + expect(user.provider).toBe('saml'); + expect(user.samlId).toBe(baseProfile.nameID); + expect(user.username).toBe(baseProfile.username); + expect(user.name).toBe(`${baseProfile.given_name} ${baseProfile.family_name}`); + expect(user.email).toBe(baseProfile.email); + }); + + it('should attempt to download and save the avatar if picture is provided', async () => { + const profile = { ...baseProfile }; + + const { user } = await validate(profile); + + expect(fetch).toHaveBeenCalled(); + expect(user.avatar).toBe('/fake/path/to/avatar.png'); + }); + + it('should not attempt to download avatar if picture is not provided', async () => { + const profile = { ...baseProfile }; + delete profile.picture; + + await validate(profile); + + expect(fetch).not.toHaveBeenCalled(); + }); +}); diff --git a/api/strategies/socialLogin.js b/api/strategies/socialLogin.js index 925c2de34d..4f9462316a 100644 --- a/api/strategies/socialLogin.js +++ b/api/strategies/socialLogin.js @@ -1,7 +1,7 @@ +const { logger } = require('@librechat/data-schemas'); const { createSocialUser, handleExistingUser } = require('./process'); const { isEnabled } = require('~/server/utils'); const { findUser } = require('~/models'); -const { logger } = require('~/config'); const socialLogin = (provider, getProfileDetails) => async (accessToken, refreshToken, idToken, profile, cb) => { diff --git a/api/test/__mocks__/logger.js b/api/test/__mocks__/logger.js index f9f6d78c87..56fb28cbab 100644 --- a/api/test/__mocks__/logger.js +++ b/api/test/__mocks__/logger.js @@ -41,10 +41,7 @@ jest.mock('winston-daily-rotate-file', () => { }); jest.mock('~/config', () => { - const actualModule = jest.requireActual('~/config'); return { - sendEvent: actualModule.sendEvent, - createAxiosInstance: actualModule.createAxiosInstance, logger: { info: jest.fn(), warn: jest.fn(), diff 
--git a/api/typedefs.js b/api/typedefs.js index 8da5b34809..5bc7ebf664 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -1073,7 +1073,7 @@ /** * @exports MCPServers - * @typedef {import('librechat-mcp').MCPServers} MCPServers + * @typedef {import('@librechat/api').MCPServers} MCPServers * @memberof typedefs */ @@ -1085,31 +1085,31 @@ /** * @exports MCPManager - * @typedef {import('librechat-mcp').MCPManager} MCPManager + * @typedef {import('@librechat/api').MCPManager} MCPManager * @memberof typedefs */ /** * @exports FlowStateManager - * @typedef {import('librechat-mcp').FlowStateManager} FlowStateManager + * @typedef {import('@librechat/api').FlowStateManager} FlowStateManager * @memberof typedefs */ /** * @exports LCAvailableTools - * @typedef {import('librechat-mcp').LCAvailableTools} LCAvailableTools + * @typedef {import('@librechat/api').LCAvailableTools} LCAvailableTools * @memberof typedefs */ /** * @exports LCTool - * @typedef {import('librechat-mcp').LCTool} LCTool + * @typedef {import('@librechat/api').LCTool} LCTool * @memberof typedefs */ /** * @exports FormattedContent - * @typedef {import('librechat-mcp').FormattedContent} FormattedContent + * @typedef {import('@librechat/api').FormattedContent} FormattedContent * @memberof typedefs */ @@ -1232,7 +1232,7 @@ * @typedef {Object} AgentClientOptions * @property {Agent} agent - The agent configuration object * @property {string} endpoint - The endpoint identifier for the agent - * @property {Object} req - The request object + * @property {ServerRequest} req - The request object * @property {string} [name] - The username * @property {string} [modelLabel] - The label for the model being used * @property {number} [maxContextTokens] - Maximum number of tokens allowed in context diff --git a/api/utils/axios.js b/api/utils/axios.js deleted file mode 100644 index 91c1fbb223..0000000000 --- a/api/utils/axios.js +++ /dev/null @@ -1,46 +0,0 @@ -const { logger } = require('~/config'); - -/** - * Logs Axios 
errors based on the error object and a custom message. - * - * @param {Object} options - The options object. - * @param {string} options.message - The custom message to be logged. - * @param {import('axios').AxiosError} options.error - The Axios error object. - * @returns {string} The log message. - */ -const logAxiosError = ({ message, error }) => { - let logMessage = message; - try { - const stack = error.stack || 'No stack trace available'; - - if (error.response?.status) { - const { status, headers, data } = error.response; - logMessage = `${message} The server responded with status ${status}: ${error.message}`; - logger.error(logMessage, { - status, - headers, - data, - stack, - }); - } else if (error.request) { - const { method, url } = error.config || {}; - logMessage = `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`; - logger.error(logMessage, { - requestInfo: { method, url }, - stack, - }); - } else if (error?.message?.includes("Cannot read properties of undefined (reading 'status')")) { - logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`; - logger.error(logMessage, { stack }); - } else { - logMessage = `${message} An error occurred while setting up the request: ${error.message}`; - logger.error(logMessage, { stack }); - } - } catch (err) { - logMessage = `Error in logAxiosError: ${err.message}`; - logger.error(logMessage, { stack: err.stack || 'No stack trace available' }); - } - return logMessage; -}; - -module.exports = { logAxiosError }; diff --git a/api/utils/index.js b/api/utils/index.js index 62d61586bf..50b8c46d99 100644 --- a/api/utils/index.js +++ b/api/utils/index.js @@ -1,7 +1,5 @@ const loadYaml = require('./loadYaml'); -const axiosHelpers = require('./axios'); const tokenHelpers = require('./tokens'); -const azureUtils = require('./azureUtils'); const deriveBaseURL = require('./deriveBaseURL'); const extractBaseURL = 
require('./extractBaseURL'); const findMessageContent = require('./findMessageContent'); @@ -10,8 +8,6 @@ module.exports = { loadYaml, deriveBaseURL, extractBaseURL, - ...azureUtils, - ...axiosHelpers, ...tokenHelpers, findMessageContent, }; diff --git a/client/package.json b/client/package.json index 04069a8807..7cb983d218 100644 --- a/client/package.json +++ b/client/package.json @@ -6,7 +6,7 @@ "scripts": { "data-provider": "cd .. && npm run build:data-provider", "build:file": "cross-env NODE_ENV=production vite build --debug > vite-output.log 2>&1", - "build": "cross-env NODE_ENV=production vite build", + "build": "cross-env NODE_ENV=production vite build && node ./scripts/post-build.cjs", "build:ci": "cross-env NODE_ENV=development vite build --mode ci", "dev": "cross-env NODE_ENV=development vite", "preview-prod": "cross-env NODE_ENV=development vite preview", @@ -139,7 +139,6 @@ "postcss": "^8.4.31", "postcss-loader": "^7.1.0", "postcss-preset-env": "^8.2.0", - "rollup-plugin-visualizer": "^6.0.0", "tailwindcss": "^3.4.1", "ts-jest": "^29.2.5", "typescript": "^5.3.3", diff --git a/client/public/assets/google.svg b/client/public/assets/google.svg new file mode 100644 index 0000000000..bebf169e2b --- /dev/null +++ b/client/public/assets/google.svg @@ -0,0 +1 @@ +Gemini \ No newline at end of file diff --git a/client/public/assets/openai.svg b/client/public/assets/openai.svg new file mode 100644 index 0000000000..895b39d02f --- /dev/null +++ b/client/public/assets/openai.svg @@ -0,0 +1 @@ +OpenAI \ No newline at end of file diff --git a/client/public/assets/qwen.svg b/client/public/assets/qwen.svg new file mode 100644 index 0000000000..ed17f7c072 --- /dev/null +++ b/client/public/assets/qwen.svg @@ -0,0 +1 @@ +Qwen \ No newline at end of file diff --git a/client/scripts/post-build.cjs b/client/scripts/post-build.cjs new file mode 100644 index 0000000000..0c0f00dc14 --- /dev/null +++ b/client/scripts/post-build.cjs @@ -0,0 +1,14 @@ +const fs = 
require('fs-extra'); + +async function postBuild() { + try { + await fs.copy('public/assets', 'dist/assets'); + await fs.copy('public/robots.txt', 'dist/robots.txt'); + console.log('✅ PWA icons and robots.txt copied successfully. Glob pattern warnings resolved.'); + } catch (err) { + console.error('❌ Error copying files:', err); + process.exit(1); + } +} + +postBuild(); diff --git a/client/src/Providers/AgentPanelContext.tsx b/client/src/Providers/AgentPanelContext.tsx new file mode 100644 index 0000000000..628eda00f2 --- /dev/null +++ b/client/src/Providers/AgentPanelContext.tsx @@ -0,0 +1,45 @@ +import React, { createContext, useContext, useState } from 'react'; +import { Action, MCP, EModelEndpoint } from 'librechat-data-provider'; +import type { AgentPanelContextType } from '~/common'; +import { useGetActionsQuery } from '~/data-provider'; +import { Panel } from '~/common'; + +const AgentPanelContext = createContext(undefined); + +export function useAgentPanelContext() { + const context = useContext(AgentPanelContext); + if (context === undefined) { + throw new Error('useAgentPanelContext must be used within an AgentPanelProvider'); + } + return context; +} + +/** Houses relevant state for the Agent Form Panels (formerly 'commonProps') */ +export function AgentPanelProvider({ children }: { children: React.ReactNode }) { + const [mcp, setMcp] = useState(undefined); + const [mcps, setMcps] = useState(undefined); + const [action, setAction] = useState(undefined); + const [activePanel, setActivePanel] = useState(Panel.builder); + const [agent_id, setCurrentAgentId] = useState(undefined); + + const { data: actions } = useGetActionsQuery(EModelEndpoint.agents, { + enabled: !!agent_id, + }); + + const value = { + action, + setAction, + mcp, + setMcp, + mcps, + setMcps, + activePanel, + setActivePanel, + setCurrentAgentId, + agent_id, + /** Query data for actions */ + actions, + }; + + return {children}; +} diff --git a/client/src/Providers/AgentsContext.tsx 
b/client/src/Providers/AgentsContext.tsx index e793a3f087..a90a53ecb5 100644 --- a/client/src/Providers/AgentsContext.tsx +++ b/client/src/Providers/AgentsContext.tsx @@ -1,8 +1,8 @@ import { useForm, FormProvider } from 'react-hook-form'; import { createContext, useContext } from 'react'; -import { defaultAgentFormValues } from 'librechat-data-provider'; import type { UseFormReturn } from 'react-hook-form'; import type { AgentForm } from '~/common'; +import { getDefaultAgentFormValues } from '~/utils'; type AgentsContextType = UseFormReturn; @@ -20,7 +20,7 @@ export function useAgentsContext() { export default function AgentsProvider({ children }) { const methods = useForm({ - defaultValues: defaultAgentFormValues, + defaultValues: getDefaultAgentFormValues(), }); return {children}; diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts index 00191318e0..41c9cdceb3 100644 --- a/client/src/Providers/index.ts +++ b/client/src/Providers/index.ts @@ -1,6 +1,7 @@ -export { default as ToastProvider } from './ToastContext'; export { default as AssistantsProvider } from './AssistantsContext'; export { default as AgentsProvider } from './AgentsContext'; +export { default as ToastProvider } from './ToastContext'; +export * from './AgentPanelContext'; export * from './ChatContext'; export * from './ShareContext'; export * from './ToastContext'; diff --git a/client/src/common/mcp.ts b/client/src/common/mcp.ts new file mode 100644 index 0000000000..b4f44a1f94 --- /dev/null +++ b/client/src/common/mcp.ts @@ -0,0 +1,26 @@ +import { + AuthorizationTypeEnum, + AuthTypeEnum, + TokenExchangeMethodEnum, +} from 'librechat-data-provider'; +import { MCPForm } from '~/common/types'; + +export const defaultMCPFormValues: MCPForm = { + type: AuthTypeEnum.None, + saved_auth_fields: false, + api_key: '', + authorization_type: AuthorizationTypeEnum.Basic, + custom_auth_header: '', + oauth_client_id: '', + oauth_client_secret: '', + authorization_url: '', + client_url: '', 
+ scope: '', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, + name: '', + description: '', + url: '', + tools: [], + icon: '', + trust: false, +}; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 6837869e8e..6fe4784fbc 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -143,6 +143,7 @@ export enum Panel { actions = 'actions', model = 'model', version = 'version', + mcp = 'mcp', } export type FileSetter = @@ -166,6 +167,15 @@ export type ActionAuthForm = { token_exchange_method: t.TokenExchangeMethodEnum; }; +export type MCPForm = ActionAuthForm & { + name?: string; + description?: string; + url?: string; + tools?: string[]; + icon?: string; + trust?: boolean; +}; + export type ActionWithNullableMetadata = Omit & { metadata: t.ActionMetadata | null; }; @@ -188,16 +198,33 @@ export type AgentPanelProps = { index?: number; agent_id?: string; activePanel?: string; + mcp?: t.MCP; + mcps?: t.MCP[]; action?: t.Action; actions?: t.Action[]; createMutation: UseMutationResult; setActivePanel: React.Dispatch>; + setMcp: React.Dispatch>; setAction: React.Dispatch>; endpointsConfig?: t.TEndpointsConfig; setCurrentAgentId: React.Dispatch>; agentsConfig?: t.TAgentsEndpoint | null; }; +export type AgentPanelContextType = { + action?: t.Action; + actions?: t.Action[]; + setAction: React.Dispatch>; + mcp?: t.MCP; + mcps?: t.MCP[]; + setMcp: React.Dispatch>; + setMcps: React.Dispatch>; + activePanel?: string; + setActivePanel: React.Dispatch>; + setCurrentAgentId: React.Dispatch>; + agent_id?: string; +}; + export type AgentModelPanelProps = { agent_id?: string; providers: Option[]; @@ -457,11 +484,20 @@ export type VoiceOption = { }; export type TMessageAudio = { - messageId?: string; - content?: t.TMessageContentParts[] | string; - className?: string; - isLast: boolean; + isLast?: boolean; index: number; + messageId: string; + content: string; + className?: string; + renderButton?: (props: { + onClick: (e?: 
React.MouseEvent) => void; + title: string; + icon: React.ReactNode; + isActive?: boolean; + isVisible?: boolean; + isDisabled?: boolean; + className?: string; + }) => React.ReactNode; }; export type OptionWithIcon = Option & { icon?: React.ReactNode }; diff --git a/client/src/components/Artifacts/useDebounceCodeBlock.ts b/client/src/components/Artifacts/useDebounceCodeBlock.ts deleted file mode 100644 index 27aaf5bc83..0000000000 --- a/client/src/components/Artifacts/useDebounceCodeBlock.ts +++ /dev/null @@ -1,37 +0,0 @@ -// client/src/hooks/useDebounceCodeBlock.ts -import { useCallback, useEffect } from 'react'; -import debounce from 'lodash/debounce'; -import { useSetRecoilState } from 'recoil'; -import { codeBlocksState, codeBlockIdsState } from '~/store/artifacts'; -import type { CodeBlock } from '~/common'; - -export function useDebounceCodeBlock() { - const setCodeBlocks = useSetRecoilState(codeBlocksState); - const setCodeBlockIds = useSetRecoilState(codeBlockIdsState); - - const updateCodeBlock = useCallback((codeBlock: CodeBlock) => { - console.log('Updating code block:', codeBlock); - setCodeBlocks((prev) => ({ - ...prev, - [codeBlock.id]: codeBlock, - })); - setCodeBlockIds((prev) => - prev.includes(codeBlock.id) ? 
prev : [...prev, codeBlock.id], - ); - }, [setCodeBlocks, setCodeBlockIds]); - - const debouncedUpdateCodeBlock = useCallback( - debounce((codeBlock: CodeBlock) => { - updateCodeBlock(codeBlock); - }, 25), - [updateCodeBlock], - ); - - useEffect(() => { - return () => { - debouncedUpdateCodeBlock.cancel(); - }; - }, [debouncedUpdateCodeBlock]); - - return debouncedUpdateCodeBlock; -} diff --git a/client/src/components/Audio/TTS.tsx b/client/src/components/Audio/TTS.tsx index 3ceacb7f8d..9343b483d2 100644 --- a/client/src/components/Audio/TTS.tsx +++ b/client/src/components/Audio/TTS.tsx @@ -1,5 +1,5 @@ /* eslint-disable jsx-a11y/media-has-caption */ -import { useEffect, useMemo } from 'react'; +import { useEffect } from 'react'; import { useRecoilValue } from 'recoil'; import type { TMessageAudio } from '~/common'; import { useLocalize, useTTSBrowser, useTTSExternal } from '~/hooks'; @@ -7,7 +7,14 @@ import { VolumeIcon, VolumeMuteIcon, Spinner } from '~/components'; import { logger } from '~/utils'; import store from '~/store'; -export function BrowserTTS({ isLast, index, messageId, content, className }: TMessageAudio) { +export function BrowserTTS({ + isLast, + index, + messageId, + content, + className, + renderButton, +}: TMessageAudio) { const localize = useLocalize(); const playbackRate = useRecoilValue(store.playbackRate); @@ -18,16 +25,16 @@ export function BrowserTTS({ isLast, index, messageId, content, className }: TMe content, }); - const renderIcon = (size: string) => { + const renderIcon = () => { if (isLoading === true) { - return ; + return ; } if (isSpeaking === true) { - return ; + return ; } - return ; + return ; }; useEffect(() => { @@ -46,21 +53,30 @@ export function BrowserTTS({ isLast, index, messageId, content, className }: TMe audioRef.current, ); + const handleClick = () => { + if (audioRef.current) { + audioRef.current.muted = false; + } + toggleSpeech(); + }; + + const title = isSpeaking === true ? 
localize('com_ui_stop') : localize('com_ui_read_aloud'); + return ( <> - + {renderButton ? ( + renderButton({ + onClick: handleClick, + title: title, + icon: renderIcon(), + isActive: isSpeaking, + className, + }) + ) : ( + + )}