diff --git a/.env.example b/.env.example index f79b89a155..a58a37efb6 100644 --- a/.env.example +++ b/.env.example @@ -58,7 +58,7 @@ DEBUG_CONSOLE=false # Endpoints # #===================================================# -# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic +# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic PROXY= @@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided # GOOGLE_AUTH_HEADER=true # Gemini API (AI Studio) -# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002 +# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite # Vertex AI -# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002 +# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001 # GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001 @@ -453,8 +453,8 @@ OPENID_REUSE_TOKENS= OPENID_JWKS_URL_CACHE_ENABLED= OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching #Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint. 
-OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED= -OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API +OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED= +OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API # Set to true to use the OpenID Connect end session endpoint for logout OPENID_USE_END_SESSION_ENDPOINT= @@ -515,6 +515,18 @@ EMAIL_PASSWORD= EMAIL_FROM_NAME= EMAIL_FROM=noreply@librechat.ai +#========================# +# Mailgun API # +#========================# + +# MAILGUN_API_KEY=your-mailgun-api-key +# MAILGUN_DOMAIN=mg.yourdomain.com +# EMAIL_FROM=noreply@yourdomain.com +# EMAIL_FROM_NAME="LibreChat" + +# # Optional: For EU region +# MAILGUN_HOST=https://api.eu.mailgun.net + #========================# # Firebase CDN # #========================# @@ -645,4 +657,4 @@ OPENWEATHER_API_KEY= # Reranker (Required) # JINA_API_KEY=your_jina_api_key # or -# COHERE_API_KEY=your_cohere_api_key \ No newline at end of file +# COHERE_API_KEY=your_cohere_api_key diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 09444a1b44..207aa17e66 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -30,8 +30,8 @@ Project maintainers have the right and responsibility to remove, edit, or reject 2. Install typescript globally: `npm i -g typescript`. 3. Run `npm ci` to install dependencies. 4. Build the data provider: `npm run build:data-provider`. -5. Build MCP: `npm run build:mcp`. -6. Build data schemas: `npm run build:data-schemas`. +5. Build data schemas: `npm run build:data-schemas`. +6. Build API methods: `npm run build:api`. 7. Setup and run unit tests: - Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`. - Run backend unit tests: `npm run test:api`. 
diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index b7bccecae8..7637b8cdc0 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -7,6 +7,7 @@ on: - release/* paths: - 'api/**' + - 'packages/api/**' jobs: tests_Backend: name: Run Backend unit tests @@ -36,12 +37,12 @@ jobs: - name: Install Data Provider Package run: npm run build:data-provider - - name: Install MCP Package - run: npm run build:mcp - - name: Install Data Schemas Package run: npm run build:data-schemas + - name: Install API Package + run: npm run build:api + - name: Create empty auth.json file run: | mkdir -p api/data @@ -66,5 +67,8 @@ jobs: - name: Run librechat-data-provider unit tests run: cd packages/data-provider && npm run test:ci - - name: Run librechat-mcp unit tests - run: cd packages/mcp && npm run test:ci \ No newline at end of file + - name: Run @librechat/data-schemas unit tests + run: cd packages/data-schemas && npm run test:ci + + - name: Run @librechat/api unit tests + run: cd packages/api && npm run test:ci \ No newline at end of file diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index fc1c02db69..a255932e3e 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -2,7 +2,7 @@ name: Update Test Server on: workflow_run: - workflows: ["Docker Dev Images Build"] + workflows: ["Docker Dev Branch Images Build"] types: - completed workflow_dispatch: @@ -12,7 +12,8 @@ jobs: runs-on: ubuntu-latest if: | github.repository == 'danny-avila/LibreChat' && - (github.event_name == 'workflow_dispatch' || github.event.workflow_run.conclusion == 'success') + (github.event_name == 'workflow_dispatch' || + (github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.head_branch == 'dev')) steps: - name: Checkout repository uses: actions/checkout@v4 @@ -29,13 +30,17 @@ jobs: DO_USER: ${{ secrets.DO_USER }} run: | ssh -o 
StrictHostKeyChecking=no ${DO_USER}@${DO_HOST} << EOF - sudo -i -u danny bash << EEOF + sudo -i -u danny bash << 'EEOF' cd ~/LibreChat && \ git fetch origin main && \ - npm run update:deployed && \ + sudo npm run stop:deployed && \ + sudo docker images --format "{{.Repository}}:{{.ID}}" | grep -E "lc-dev|librechat" | cut -d: -f2 | xargs -r sudo docker rmi -f || true && \ + sudo npm run update:deployed && \ + git checkout dev && \ + git pull origin dev && \ git checkout do-deploy && \ - git rebase main && \ - npm run start:deployed && \ + git rebase dev && \ + sudo npm run start:deployed && \ echo "Update completed. Application should be running now." EEOF EOF diff --git a/.github/workflows/dev-branch-images.yml b/.github/workflows/dev-branch-images.yml new file mode 100644 index 0000000000..b7ad470314 --- /dev/null +++ b/.github/workflows/dev-branch-images.yml @@ -0,0 +1,72 @@ +name: Docker Dev Branch Images Build + +on: + workflow_dispatch: + push: + branches: + - dev + paths: + - 'api/**' + - 'client/**' + - 'packages/**' + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - target: api-build + file: Dockerfile.multi + image_name: lc-dev-api + - target: node + file: Dockerfile + image_name: lc-dev + + steps: + # Check out the repository + - name: Checkout + uses: actions/checkout@v4 + + # Set up QEMU + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + # Set up Docker Buildx + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + # Log in to GitHub Container Registry + - name: Log in to GitHub Container Registry + uses: docker/login-action@v2 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Login to Docker Hub + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + # Prepare the environment + - name: Prepare environment + run: | + cp .env.example .env 
+ + # Build and push Docker images for each target + - name: Build and push Docker images + uses: docker/build-push-action@v5 + with: + context: . + file: ${{ matrix.file }} + push: true + tags: | + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:${{ github.sha }} + ghcr.io/${{ github.repository_owner }}/${{ matrix.image_name }}:latest + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:${{ github.sha }} + ${{ secrets.DOCKERHUB_USERNAME }}/${{ matrix.image_name }}:latest + platforms: linux/amd64,linux/arm64 + target: ${{ matrix.target }} \ No newline at end of file diff --git a/.github/workflows/i18n-unused-keys.yml b/.github/workflows/i18n-unused-keys.yml index 6bcf824946..07cc77a1ae 100644 --- a/.github/workflows/i18n-unused-keys.yml +++ b/.github/workflows/i18n-unused-keys.yml @@ -5,12 +5,13 @@ on: paths: - "client/src/**" - "api/**" + - "packages/data-provider/src/**" jobs: detect-unused-i18n-keys: runs-on: ubuntu-latest permissions: - pull-requests: write # Required for posting PR comments + pull-requests: write steps: - name: Checkout repository uses: actions/checkout@v3 diff --git a/.github/workflows/unused-packages.yml b/.github/workflows/unused-packages.yml index 442e70e52c..dc6ce3ba56 100644 --- a/.github/workflows/unused-packages.yml +++ b/.github/workflows/unused-packages.yml @@ -98,6 +98,8 @@ jobs: cd client UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "") UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt | sort) || echo "") + # Filter out false positives + UNUSED=$(echo "$UNUSED" | grep -v "^micromark-extension-llm-math$" || echo "") echo "CLIENT_UNUSED<> $GITHUB_ENV echo "$UNUSED" >> $GITHUB_ENV echo "EOF" >> $GITHUB_ENV diff --git a/.gitignore b/.gitignore index f49594afdf..c9658f17e6 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,7 @@ bower_components/ # AI .clineignore .cursor +.aider* # Floobits .floo diff --git a/Dockerfile b/Dockerfile index 
393b35354d..02bcb7da1f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# v0.7.8 +# v0.7.9-rc1 # Base node image FROM node:20-alpine AS node diff --git a/Dockerfile.multi b/Dockerfile.multi index 991f805bec..9738f4e1f3 100644 --- a/Dockerfile.multi +++ b/Dockerfile.multi @@ -1,5 +1,5 @@ # Dockerfile.multi -# v0.7.8 +# v0.7.9-rc1 # Base for all builds FROM node:20-alpine AS base-min @@ -14,7 +14,7 @@ RUN npm config set fetch-retry-maxtimeout 600000 && \ npm config set fetch-retry-mintimeout 15000 COPY package*.json ./ COPY packages/data-provider/package*.json ./packages/data-provider/ -COPY packages/mcp/package*.json ./packages/mcp/ +COPY packages/api/package*.json ./packages/api/ COPY packages/data-schemas/package*.json ./packages/data-schemas/ COPY client/package*.json ./client/ COPY api/package*.json ./api/ @@ -24,26 +24,27 @@ FROM base-min AS base WORKDIR /app RUN npm ci -# Build data-provider +# Build `data-provider` package FROM base AS data-provider-build WORKDIR /app/packages/data-provider COPY packages/data-provider ./ RUN npm run build -# Build mcp package -FROM base AS mcp-build -WORKDIR /app/packages/mcp -COPY packages/mcp ./ -COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist -RUN npm run build - -# Build data-schemas +# Build `data-schemas` package FROM base AS data-schemas-build WORKDIR /app/packages/data-schemas COPY packages/data-schemas ./ COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist RUN npm run build +# Build `api` package +FROM base AS api-package-build +WORKDIR /app/packages/api +COPY packages/api ./ +COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist +COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist +RUN npm run build + # Client build FROM base AS client-build WORKDIR /app/client @@ -63,8 +64,8 @@ RUN npm ci --omit=dev COPY api ./api COPY config 
./config COPY --from=data-provider-build /app/packages/data-provider/dist ./packages/data-provider/dist -COPY --from=mcp-build /app/packages/mcp/dist ./packages/mcp/dist COPY --from=data-schemas-build /app/packages/data-schemas/dist ./packages/data-schemas/dist +COPY --from=api-package-build /app/packages/api/dist ./packages/api/dist COPY --from=client-build /app/client/dist ./client/dist WORKDIR /app/api EXPOSE 3080 diff --git a/README.md b/README.md index cc9533b2d2..d6bd19ab43 100644 --- a/README.md +++ b/README.md @@ -150,8 +150,8 @@ Click on the thumbnail to open the video☝️ **Other:** - **Website:** [librechat.ai](https://librechat.ai) - - **Documentation:** [docs.librechat.ai](https://docs.librechat.ai) - - **Blog:** [blog.librechat.ai](https://blog.librechat.ai) + - **Documentation:** [librechat.ai/docs](https://librechat.ai/docs) + - **Blog:** [librechat.ai/blog](https://librechat.ai/blog) --- diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 0da331ced5..a3fba29d5c 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -10,6 +10,7 @@ const { validateVisionModel, } = require('librechat-data-provider'); const { SplitStreamHandler: _Handler } = require('@librechat/agents'); +const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api'); const { truncateText, formatMessage, @@ -26,8 +27,6 @@ const { const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { createFetch, createStreamEventHandlers } = require('./generators'); -const Tokenizer = require('~/server/services/Tokenizer'); const { sleep } = require('~/server/utils'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); @@ -191,10 +190,11 @@ class AnthropicClient extends 
BaseClient { reverseProxyUrl: this.options.reverseProxyUrl, }), apiKey: this.apiKey, + fetchOptions: {}, }; if (this.options.proxy) { - options.httpAgent = new HttpsProxyAgent(this.options.proxy); + options.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy); } if (this.options.reverseProxyUrl) { diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 55b8780180..0598f0da21 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -13,7 +13,6 @@ const { const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models'); const { checkBalance } = require('~/models/balanceMethods'); const { truncateToolCallOutputs } = require('./prompts'); -const { addSpaceIfNeeded } = require('~/server/utils'); const { getFiles } = require('~/models/File'); const TextStream = require('./TextStream'); const { logger } = require('~/config'); @@ -572,7 +571,7 @@ class BaseClient { }); } - const { generation = '' } = opts; + const { editedContent } = opts; // It's not necessary to push to currentMessages // depending on subclass implementation of handling messages @@ -587,11 +586,21 @@ class BaseClient { isCreatedByUser: false, model: this.modelOptions?.model ?? 
this.model, sender: this.sender, - text: generation, }; this.currentMessages.push(userMessage, latestMessage); - } else { - latestMessage.text = generation; + } else if (editedContent != null) { + // Handle editedContent for content parts + if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) { + const { index, text, type } = editedContent; + if (index >= 0 && index < latestMessage.content.length) { + const contentPart = latestMessage.content[index]; + if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) { + contentPart[ContentTypes.THINK] = text; + } else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) { + contentPart[ContentTypes.TEXT] = text; + } + } + } } this.continued = true; } else { @@ -672,16 +681,32 @@ class BaseClient { }; if (typeof completion === 'string') { - responseMessage.text = addSpaceIfNeeded(generation) + completion; + responseMessage.text = completion; } else if ( Array.isArray(completion) && (this.clientName === EModelEndpoint.agents || isParamEndpoint(this.options.endpoint, this.options.endpointType)) ) { responseMessage.text = ''; - responseMessage.content = completion; + + if (!opts.editedContent || this.currentMessages.length === 0) { + responseMessage.content = completion; + } else { + const latestMessage = this.currentMessages[this.currentMessages.length - 1]; + if (!latestMessage?.content) { + responseMessage.content = completion; + } else { + const existingContent = [...latestMessage.content]; + const { type: editedType } = opts.editedContent; + responseMessage.content = this.mergeEditedContent( + existingContent, + completion, + editedType, + ); + } + } } else if (Array.isArray(completion)) { - responseMessage.text = addSpaceIfNeeded(generation) + completion.join(''); + responseMessage.text = completion.join(''); } if ( @@ -792,7 +817,8 @@ class BaseClient { userMessage.tokenCount = userMessageTokenCount; /* - Note: `AskController` saves the user message, 
so we update the count of its `userMessage` reference + Note: `AgentController` saves the user message if not saved here + (noted by `savedMessageIds`), so we update the count of its `userMessage` reference */ if (typeof opts?.getReqData === 'function') { opts.getReqData({ @@ -801,7 +827,8 @@ class BaseClient { } /* Note: we update the user message to be sure it gets the calculated token count; - though `AskController` saves the user message, EditController does not + though `AgentController` saves the user message if not saved here + (noted by `savedMessageIds`), EditController does not */ await userMessagePromise; await this.updateMessageInDatabase({ @@ -1093,6 +1120,50 @@ class BaseClient { return numTokens; } + /** + * Merges completion content with existing content when editing TEXT or THINK types + * @param {Array} existingContent - The existing content array + * @param {Array} newCompletion - The new completion content + * @param {string} editedType - The type of content being edited + * @returns {Array} The merged content array + */ + mergeEditedContent(existingContent, newCompletion, editedType) { + if (!newCompletion.length) { + return existingContent.concat(newCompletion); + } + + if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) { + return existingContent.concat(newCompletion); + } + + const lastIndex = existingContent.length - 1; + const lastExisting = existingContent[lastIndex]; + const firstNew = newCompletion[0]; + + if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) { + return existingContent.concat(newCompletion); + } + + const mergedContent = [...existingContent]; + if (editedType === ContentTypes.TEXT) { + mergedContent[lastIndex] = { + ...mergedContent[lastIndex], + [ContentTypes.TEXT]: + (mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''), + }; + } else { + mergedContent[lastIndex] = { + ...mergedContent[lastIndex], + [ContentTypes.THINK]: + 
(mergedContent[lastIndex][ContentTypes.THINK] || '') + + (firstNew[ContentTypes.THINK] || ''), + }; + } + + // Add remaining completion items + return mergedContent.concat(newCompletion.slice(1)); + } + async sendPayload(payload, opts = {}) { if (opts && typeof opts === 'object') { this.setOptions(opts); diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js deleted file mode 100644 index 07b2fa97bb..0000000000 --- a/api/app/clients/ChatGPTClient.js +++ /dev/null @@ -1,804 +0,0 @@ -const { Keyv } = require('keyv'); -const crypto = require('crypto'); -const { CohereClient } = require('cohere-ai'); -const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); -const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); -const { - ImageDetail, - EModelEndpoint, - resolveHeaders, - CohereConstants, - mapModelToAzureConfig, -} = require('librechat-data-provider'); -const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils'); -const { createContextHandlers } = require('./prompts'); -const { createCoherePayload } = require('./llm'); -const BaseClient = require('./BaseClient'); -const { logger } = require('~/config'); - -const CHATGPT_MODEL = 'gpt-3.5-turbo'; -const tokenizersCache = {}; - -class ChatGPTClient extends BaseClient { - constructor(apiKey, options = {}, cacheOptions = {}) { - super(apiKey, options, cacheOptions); - - cacheOptions.namespace = cacheOptions.namespace || 'chatgpt'; - this.conversationsCache = new Keyv(cacheOptions); - this.setOptions(options); - } - - setOptions(options) { - if (this.options && !this.options.replaceOptions) { - // nested options aren't spread properly, so we need to do this manually - this.options.modelOptions = { - ...this.options.modelOptions, - ...options.modelOptions, - }; - delete options.modelOptions; - // now we can merge options - this.options = { - ...this.options, - ...options, - }; - } else { - this.options = 
options; - } - - if (this.options.openaiApiKey) { - this.apiKey = this.options.openaiApiKey; - } - - const modelOptions = this.options.modelOptions || {}; - this.modelOptions = { - ...modelOptions, - // set some good defaults (check for undefined in some cases because they may be 0) - model: modelOptions.model || CHATGPT_MODEL, - temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature, - top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p, - presence_penalty: - typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty, - stop: modelOptions.stop, - }; - - this.isChatGptModel = this.modelOptions.model.includes('gpt-'); - const { isChatGptModel } = this; - this.isUnofficialChatGptModel = - this.modelOptions.model.startsWith('text-chat') || - this.modelOptions.model.startsWith('text-davinci-002-render'); - const { isUnofficialChatGptModel } = this; - - // Davinci models have a max context length of 4097 tokens. - this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097); - // I decided to reserve 1024 tokens for the response. - // The max prompt tokens is determined by the max context tokens minus the max response tokens. - // Earlier messages will be dropped until the prompt is within the limit. 
- this.maxResponseTokens = this.modelOptions.max_tokens || 1024; - this.maxPromptTokens = - this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens; - - if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) { - throw new Error( - `maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${ - this.maxPromptTokens + this.maxResponseTokens - }) must be less than or equal to maxContextTokens (${this.maxContextTokens})`, - ); - } - - this.userLabel = this.options.userLabel || 'User'; - this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT'; - - if (isChatGptModel) { - // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves. - // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason, - // without tripping the stop sequences, so I'm using "||>" instead. - this.startToken = '||>'; - this.endToken = ''; - this.gptEncoder = this.constructor.getTokenizer('cl100k_base'); - } else if (isUnofficialChatGptModel) { - this.startToken = '<|im_start|>'; - this.endToken = '<|im_end|>'; - this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, { - '<|im_start|>': 100264, - '<|im_end|>': 100265, - }); - } else { - // Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting - // system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated - // as a single token. So we're using this instead. 
- this.startToken = '||>'; - this.endToken = ''; - try { - this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true); - } catch { - this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true); - } - } - - if (!this.modelOptions.stop) { - const stopTokens = [this.startToken]; - if (this.endToken && this.endToken !== this.startToken) { - stopTokens.push(this.endToken); - } - stopTokens.push(`\n${this.userLabel}:`); - stopTokens.push('<|diff_marker|>'); - // I chose not to do one for `chatGptLabel` because I've never seen it happen - this.modelOptions.stop = stopTokens; - } - - if (this.options.reverseProxyUrl) { - this.completionsUrl = this.options.reverseProxyUrl; - } else if (isChatGptModel) { - this.completionsUrl = 'https://api.openai.com/v1/chat/completions'; - } else { - this.completionsUrl = 'https://api.openai.com/v1/completions'; - } - - return this; - } - - static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { - if (tokenizersCache[encoding]) { - return tokenizersCache[encoding]; - } - let tokenizer; - if (isModelName) { - tokenizer = encodingForModel(encoding, extendSpecialTokens); - } else { - tokenizer = getEncoding(encoding, extendSpecialTokens); - } - tokenizersCache[encoding] = tokenizer; - return tokenizer; - } - - /** @type {getCompletion} */ - async getCompletion(input, onProgress, onTokenProgress, abortController = null) { - if (!abortController) { - abortController = new AbortController(); - } - - let modelOptions = { ...this.modelOptions }; - if (typeof onProgress === 'function') { - modelOptions.stream = true; - } - if (this.isChatGptModel) { - modelOptions.messages = input; - } else { - modelOptions.prompt = input; - } - - if (this.useOpenRouter && modelOptions.prompt) { - delete modelOptions.stop; - } - - const { debug } = this.options; - let baseURL = this.completionsUrl; - if (debug) { - console.debug(); - console.debug(baseURL); - console.debug(modelOptions); - console.debug(); - 
} - - const opts = { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - }; - - if (this.isVisionModel) { - modelOptions.max_tokens = 4000; - } - - /** @type {TAzureConfig | undefined} */ - const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; - - const isAzure = this.azure || this.options.azure; - if ( - (isAzure && this.isVisionModel && azureConfig) || - (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI) - ) { - const { modelGroupMap, groupMap } = azureConfig; - const { - azureOptions, - baseURL, - headers = {}, - serverless, - } = mapModelToAzureConfig({ - modelName: modelOptions.model, - modelGroupMap, - groupMap, - }); - opts.headers = resolveHeaders(headers); - this.langchainProxy = extractBaseURL(baseURL); - this.apiKey = azureOptions.azureOpenAIApiKey; - - const groupName = modelGroupMap[modelOptions.model].group; - this.options.addParams = azureConfig.groupMap[groupName].addParams; - this.options.dropParams = azureConfig.groupMap[groupName].dropParams; - // Note: `forcePrompt` not re-assigned as only chat models are vision models - - this.azure = !serverless && azureOptions; - this.azureEndpoint = - !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); - if (serverless === true) { - this.options.defaultQuery = azureOptions.azureOpenAIApiVersion - ? { 'api-version': azureOptions.azureOpenAIApiVersion } - : undefined; - this.options.headers['api-key'] = this.apiKey; - } - } - - if (this.options.defaultQuery) { - opts.defaultQuery = this.options.defaultQuery; - } - - if (this.options.headers) { - opts.headers = { ...opts.headers, ...this.options.headers }; - } - - if (isAzure) { - // Azure does not accept `model` in the body, so we need to remove it. - delete modelOptions.model; - - baseURL = this.langchainProxy - ? constructAzureURL({ - baseURL: this.langchainProxy, - azureOptions: this.azure, - }) - : this.azureEndpoint.split(/(? 
msg.role === 'system'); - - if (systemMessageIndex > 0) { - const [systemMessage] = messages.splice(systemMessageIndex, 1); - messages.unshift(systemMessage); - } - - modelOptions.messages = messages; - - if (messages.length === 1 && messages[0].role === 'system') { - modelOptions.messages[0].role = 'user'; - } - } - - if (this.options.addParams && typeof this.options.addParams === 'object') { - modelOptions = { - ...modelOptions, - ...this.options.addParams, - }; - logger.debug('[ChatGPTClient] chatCompletion: added params', { - addParams: this.options.addParams, - modelOptions, - }); - } - - if (this.options.dropParams && Array.isArray(this.options.dropParams)) { - this.options.dropParams.forEach((param) => { - delete modelOptions[param]; - }); - logger.debug('[ChatGPTClient] chatCompletion: dropped params', { - dropParams: this.options.dropParams, - modelOptions, - }); - } - - if (baseURL.startsWith(CohereConstants.API_URL)) { - const payload = createCoherePayload({ modelOptions }); - return await this.cohereChatCompletion({ payload, onTokenProgress }); - } - - if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) { - baseURL = baseURL.split('v1')[0] + 'v1/completions'; - } else if ( - baseURL.includes('v1') && - !baseURL.includes('/chat/completions') && - this.isChatCompletion - ) { - baseURL = baseURL.split('v1')[0] + 'v1/chat/completions'; - } - - const BASE_URL = new URL(baseURL); - if (opts.defaultQuery) { - Object.entries(opts.defaultQuery).forEach(([key, value]) => { - BASE_URL.searchParams.append(key, value); - }); - delete opts.defaultQuery; - } - - const completionsURL = BASE_URL.toString(); - opts.body = JSON.stringify(modelOptions); - - if (modelOptions.stream) { - - return new Promise(async (resolve, reject) => { - try { - let done = false; - await fetchEventSource(completionsURL, { - ...opts, - signal: abortController.signal, - async onopen(response) { - if (response.status === 200) { - return; - } - if (debug) 
{ - console.debug(response); - } - let error; - try { - const body = await response.text(); - error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`); - error.status = response.status; - error.json = JSON.parse(body); - } catch { - error = error || new Error(`Failed to send message. HTTP ${response.status}`); - } - throw error; - }, - onclose() { - if (debug) { - console.debug('Server closed the connection unexpectedly, returning...'); - } - // workaround for private API not sending [DONE] event - if (!done) { - onProgress('[DONE]'); - resolve(); - } - }, - onerror(err) { - if (debug) { - console.debug(err); - } - // rethrow to stop the operation - throw err; - }, - onmessage(message) { - if (debug) { - console.debug(message); - } - if (!message.data || message.event === 'ping') { - return; - } - if (message.data === '[DONE]') { - onProgress('[DONE]'); - resolve(); - done = true; - return; - } - onProgress(JSON.parse(message.data)); - }, - }); - } catch (err) { - reject(err); - } - }); - } - const response = await fetch(completionsURL, { - ...opts, - signal: abortController.signal, - }); - if (response.status !== 200) { - const body = await response.text(); - const error = new Error(`Failed to send message. 
HTTP ${response.status} - ${body}`); - error.status = response.status; - try { - error.json = JSON.parse(body); - } catch { - error.body = body; - } - throw error; - } - return response.json(); - } - - /** @type {cohereChatCompletion} */ - async cohereChatCompletion({ payload, onTokenProgress }) { - const cohere = new CohereClient({ - token: this.apiKey, - environment: this.completionsUrl, - }); - - if (!payload.stream) { - const chatResponse = await cohere.chat(payload); - return chatResponse.text; - } - - const chatStream = await cohere.chatStream(payload); - let reply = ''; - for await (const message of chatStream) { - if (!message) { - continue; - } - - if (message.eventType === 'text-generation' && message.text) { - onTokenProgress(message.text); - reply += message.text; - } - /* - Cohere API Chinese Unicode character replacement hotfix. - Should be un-commented when the following issue is resolved: - https://github.com/cohere-ai/cohere-typescript/issues/151 - - else if (message.eventType === 'stream-end' && message.response) { - reply = message.response.text; - } - */ - } - - return reply; - } - - async generateTitle(userMessage, botMessage) { - const instructionsPayload = { - role: 'system', - content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation. 
- -||>Message: -${userMessage.message} -||>Response: -${botMessage.message} - -||>Title:`, - }; - - const titleGenClientOptions = JSON.parse(JSON.stringify(this.options)); - titleGenClientOptions.modelOptions = { - model: 'gpt-3.5-turbo', - temperature: 0, - presence_penalty: 0, - frequency_penalty: 0, - }; - const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions); - const result = await titleGenClient.getCompletion([instructionsPayload], null); - // remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim - return result.choices[0].message.content - .replace(/[^a-zA-Z0-9' ]/g, '') - .replace(/\s+/g, ' ') - .trim(); - } - - async sendMessage(message, opts = {}) { - if (opts.clientOptions && typeof opts.clientOptions === 'object') { - this.setOptions(opts.clientOptions); - } - - const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || crypto.randomUUID(); - - let conversation = - typeof opts.conversation === 'object' - ? opts.conversation - : await this.conversationsCache.get(conversationId); - - let isNewConversation = false; - if (!conversation) { - conversation = { - messages: [], - createdAt: Date.now(), - }; - isNewConversation = true; - } - - const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation; - - const userMessage = { - id: crypto.randomUUID(), - parentMessageId, - role: 'User', - message, - }; - conversation.messages.push(userMessage); - - // Doing it this way instead of having each message be a separate element in the array seems to be more reliable, - // especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention. 
- const { prompt: payload, context } = await this.buildPrompt( - conversation.messages, - userMessage.id, - { - isChatGptModel: this.isChatGptModel, - promptPrefix: opts.promptPrefix, - }, - ); - - if (this.options.keepNecessaryMessagesOnly) { - conversation.messages = context; - } - - let reply = ''; - let result = null; - if (typeof opts.onProgress === 'function') { - await this.getCompletion( - payload, - (progressMessage) => { - if (progressMessage === '[DONE]') { - return; - } - const token = this.isChatGptModel - ? progressMessage.choices[0].delta.content - : progressMessage.choices[0].text; - // first event's delta content is always undefined - if (!token) { - return; - } - if (this.options.debug) { - console.debug(token); - } - if (token === this.endToken) { - return; - } - opts.onProgress(token); - reply += token; - }, - opts.abortController || new AbortController(), - ); - } else { - result = await this.getCompletion( - payload, - null, - opts.abortController || new AbortController(), - ); - if (this.options.debug) { - console.debug(JSON.stringify(result)); - } - if (this.isChatGptModel) { - reply = result.choices[0].message.content; - } else { - reply = result.choices[0].text.replace(this.endToken, ''); - } - } - - // avoids some rendering issues when using the CLI app - if (this.options.debug) { - console.debug(); - } - - reply = reply.trim(); - - const replyMessage = { - id: crypto.randomUUID(), - parentMessageId: userMessage.id, - role: 'ChatGPT', - message: reply, - }; - conversation.messages.push(replyMessage); - - const returnData = { - response: replyMessage.message, - conversationId, - parentMessageId: replyMessage.parentMessageId, - messageId: replyMessage.id, - details: result || {}, - }; - - if (shouldGenerateTitle) { - conversation.title = await this.generateTitle(userMessage, replyMessage); - returnData.title = conversation.title; - } - - await this.conversationsCache.set(conversationId, conversation); - - if 
(this.options.returnConversation) { - returnData.conversation = conversation; - } - - return returnData; - } - - async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) { - promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim(); - - // Handle attachments and create augmentedPrompt - if (this.options.attachments) { - const attachments = await this.options.attachments; - const lastMessage = messages[messages.length - 1]; - - if (this.message_file_map) { - this.message_file_map[lastMessage.messageId] = attachments; - } else { - this.message_file_map = { - [lastMessage.messageId]: attachments, - }; - } - - const files = await this.addImageURLs(lastMessage, attachments); - this.options.attachments = files; - - this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text); - } - - if (this.message_file_map) { - this.contextHandlers = createContextHandlers( - this.options.req, - messages[messages.length - 1].text, - ); - } - - // Calculate image token cost and process embedded files - messages.forEach((message, i) => { - if (this.message_file_map && this.message_file_map[message.messageId]) { - const attachments = this.message_file_map[message.messageId]; - for (const file of attachments) { - if (file.embedded) { - this.contextHandlers?.processFile(file); - continue; - } - - messages[i].tokenCount = - (messages[i].tokenCount || 0) + - this.calculateImageTokenCost({ - width: file.width, - height: file.height, - detail: this.options.imageDetail ?? ImageDetail.auto, - }); - } - } - }); - - if (this.contextHandlers) { - this.augmentedPrompt = await this.contextHandlers.createContext(); - promptPrefix = this.augmentedPrompt + promptPrefix; - } - - if (promptPrefix) { - // If the prompt prefix doesn't end with the end token, add it. 
- if (!promptPrefix.endsWith(`${this.endToken}`)) { - promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`; - } - promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`; - } - const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond. - - const instructionsPayload = { - role: 'system', - content: promptPrefix, - }; - - const messagePayload = { - role: 'system', - content: promptSuffix, - }; - - let currentTokenCount; - if (isChatGptModel) { - currentTokenCount = - this.getTokenCountForMessage(instructionsPayload) + - this.getTokenCountForMessage(messagePayload); - } else { - currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`); - } - let promptBody = ''; - const maxTokenCount = this.maxPromptTokens; - - const context = []; - - // Iterate backwards through the messages, adding them to the prompt until we reach the max token count. - // Do this within a recursive async function so that it doesn't block the event loop for too long. - const buildPromptBody = async () => { - if (currentTokenCount < maxTokenCount && messages.length > 0) { - const message = messages.pop(); - const roleLabel = - message?.isCreatedByUser || message?.role?.toLowerCase() === 'user' - ? this.userLabel - : this.chatGptLabel; - const messageString = `${this.startToken}${roleLabel}:\n${ - message?.text ?? message?.message - }${this.endToken}\n`; - let newPromptBody; - if (promptBody || isChatGptModel) { - newPromptBody = `${messageString}${promptBody}`; - } else { - // Always insert prompt prefix before the last user message, if not gpt-3.5-turbo. - // This makes the AI obey the prompt instructions better, which is important for custom instructions. - // After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things - // like "what's the last thing I wrote?". 
- newPromptBody = `${promptPrefix}${messageString}${promptBody}`; - } - - context.unshift(message); - - const tokenCountForMessage = this.getTokenCount(messageString); - const newTokenCount = currentTokenCount + tokenCountForMessage; - if (newTokenCount > maxTokenCount) { - if (promptBody) { - // This message would put us over the token limit, so don't add it. - return false; - } - // This is the first message, so we can't add it. Just throw an error. - throw new Error( - `Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`, - ); - } - promptBody = newPromptBody; - currentTokenCount = newTokenCount; - // wait for next tick to avoid blocking the event loop - await new Promise((resolve) => setImmediate(resolve)); - return buildPromptBody(); - } - return true; - }; - - await buildPromptBody(); - - const prompt = `${promptBody}${promptSuffix}`; - if (isChatGptModel) { - messagePayload.content = prompt; - // Add 3 tokens for Assistant Label priming after all messages have been counted. - currentTokenCount += 3; - } - - // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. - this.modelOptions.max_tokens = Math.min( - this.maxContextTokens - currentTokenCount, - this.maxResponseTokens, - ); - - if (isChatGptModel) { - return { prompt: [instructionsPayload, messagePayload], context }; - } - return { prompt, context, promptTokens: currentTokenCount }; - } - - getTokenCount(text) { - return this.gptEncoder.encode(text, 'all').length; - } - - /** - * Algorithm adapted from "6. Counting tokens for chat API calls" of - * https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb - * - * An additional 3 tokens need to be added for assistant label priming after all messages have been counted. - * - * @param {Object} message - */ - getTokenCountForMessage(message) { - // Note: gpt-3.5-turbo and gpt-4 may update over time. 
Use default for these as well as for unknown models - let tokensPerMessage = 3; - let tokensPerName = 1; - - if (this.modelOptions.model === 'gpt-3.5-turbo-0301') { - tokensPerMessage = 4; - tokensPerName = -1; - } - - let numTokens = tokensPerMessage; - for (let [key, value] of Object.entries(message)) { - numTokens += this.getTokenCount(value); - if (key === 'name') { - numTokens += tokensPerName; - } - } - - return numTokens; - } -} - -module.exports = ChatGPTClient; diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index c9102e9ae2..2ec23a0a06 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -1,6 +1,7 @@ const { google } = require('googleapis'); const { concat } = require('@langchain/core/utils/stream'); const { ChatVertexAI } = require('@langchain/google-vertexai'); +const { Tokenizer, getSafetySettings } = require('@librechat/api'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai'); const { HumanMessage, SystemMessage } = require('@langchain/core/messages'); @@ -11,15 +12,14 @@ const { endpointSettings, parseTextParts, EModelEndpoint, + googleSettings, ContentTypes, VisionModes, ErrorTypes, Constants, AuthKeys, } = require('librechat-data-provider'); -const { getSafetySettings } = require('~/server/services/Endpoints/google/llm'); const { encodeAndFormat } = require('~/server/services/Files/images'); -const Tokenizer = require('~/server/services/Tokenizer'); const { spendTokens } = require('~/models/spendTokens'); const { getModelMaxTokens } = require('~/utils'); const { sleep } = require('~/server/utils'); @@ -34,7 +34,8 @@ const BaseClient = require('./BaseClient'); const loc = process.env.GOOGLE_LOC || 'us-central1'; const publisher = 'google'; -const endpointPrefix = `${loc}-aiplatform.googleapis.com`; +const endpointPrefix = + loc === 'global' ? 
'aiplatform.googleapis.com' : `${loc}-aiplatform.googleapis.com`; const settings = endpointSettings[EModelEndpoint.google]; const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/; @@ -165,6 +166,16 @@ class GoogleClient extends BaseClient { ); } + // Add thinking configuration + this.modelOptions.thinkingConfig = { + thinkingBudget: + (this.modelOptions.thinking ?? googleSettings.thinking.default) + ? this.modelOptions.thinkingBudget + : 0, + }; + delete this.modelOptions.thinking; + delete this.modelOptions.thinkingBudget; + this.sender = this.options.sender ?? getResponseSender({ @@ -236,11 +247,11 @@ class GoogleClient extends BaseClient { msg.content = ( !Array.isArray(msg.content) ? [ - { - type: ContentTypes.TEXT, - [ContentTypes.TEXT]: msg.content, - }, - ] + { + type: ContentTypes.TEXT, + [ContentTypes.TEXT]: msg.content, + }, + ] : msg.content ).concat(message.image_urls); diff --git a/api/app/clients/OllamaClient.js b/api/app/clients/OllamaClient.js index 77d007580c..032781f1f1 100644 --- a/api/app/clients/OllamaClient.js +++ b/api/app/clients/OllamaClient.js @@ -1,10 +1,11 @@ const { z } = require('zod'); const axios = require('axios'); const { Ollama } = require('ollama'); +const { sleep } = require('@librechat/agents'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Constants } = require('librechat-data-provider'); -const { deriveBaseURL, logAxiosError } = require('~/utils'); -const { sleep } = require('~/server/utils'); -const { logger } = require('~/config'); +const { deriveBaseURL } = require('~/utils'); const ollamaPayloadSchema = z.object({ mirostat: z.number().optional(), @@ -67,7 +68,7 @@ class OllamaClient { return models; } catch (error) { const logMessage = - 'Failed to fetch models from Ollama API. 
If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).'; + "Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn't start with `ollama` (case-insensitive)."; logAxiosError({ message: logMessage, error }); return []; } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 280db89284..2eda322640 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,13 +1,21 @@ const { OllamaClient } = require('./OllamaClient'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { SplitStreamHandler, CustomOpenAIClient: OpenAI } = require('@librechat/agents'); +const { + isEnabled, + Tokenizer, + createFetch, + resolveHeaders, + constructAzureURL, + genAzureChatCompletion, + createStreamEventHandlers, +} = require('@librechat/api'); const { Constants, ImageDetail, ContentTypes, parseTextParts, EModelEndpoint, - resolveHeaders, KnownEndpoints, openAISettings, ImageDetailCost, @@ -16,13 +24,6 @@ const { validateVisionModel, mapModelToAzureConfig, } = require('librechat-data-provider'); -const { - extractBaseURL, - constructAzureURL, - getModelMaxTokens, - genAzureChatCompletion, - getModelMaxOutputTokens, -} = require('~/utils'); const { truncateText, formatMessage, @@ -30,14 +31,12 @@ const { titleInstruction, createContextHandlers, } = require('./prompts'); +const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { createFetch, createStreamEventHandlers } = require('./generators'); -const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils'); -const Tokenizer = 
require('~/server/services/Tokenizer'); +const { addSpaceIfNeeded, sleep } = require('~/server/utils'); const { spendTokens } = require('~/models/spendTokens'); const { handleOpenAIErrors } = require('./tools/util'); const { createLLM, RunManager } = require('./llm'); -const ChatGPTClient = require('./ChatGPTClient'); const { summaryBuffer } = require('./memory'); const { runTitleChain } = require('./chains'); const { tokenSplit } = require('./document'); @@ -47,12 +46,6 @@ const { logger } = require('~/config'); class OpenAIClient extends BaseClient { constructor(apiKey, options = {}) { super(apiKey, options); - this.ChatGPTClient = new ChatGPTClient(); - this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this); - /** @type {getCompletion} */ - this.getCompletion = this.ChatGPTClient.getCompletion.bind(this); - /** @type {cohereChatCompletion} */ - this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this); this.contextStrategy = options.contextStrategy ? options.contextStrategy.toLowerCase() : 'discard'; @@ -379,23 +372,12 @@ class OpenAIClient extends BaseClient { return files; } - async buildMessages( - messages, - parentMessageId, - { isChatCompletion = false, promptPrefix = null }, - opts, - ) { + async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) { let orderedMessages = this.constructor.getMessagesForConversation({ messages, parentMessageId, summary: this.shouldSummarize, }); - if (!isChatCompletion) { - return await this.buildPrompt(orderedMessages, { - isChatGptModel: isChatCompletion, - promptPrefix, - }); - } let payload; let instructions; @@ -1159,6 +1141,7 @@ ${convo} logger.debug('[OpenAIClient] chatCompletion', { baseURL, modelOptions }); const opts = { baseURL, + fetchOptions: {}, }; if (this.useOpenRouter) { @@ -1177,7 +1160,7 @@ ${convo} } if (this.options.proxy) { - opts.httpAgent = new HttpsProxyAgent(this.options.proxy); + opts.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy); } /** 
@type {TAzureConfig | undefined} */ @@ -1395,7 +1378,7 @@ ${convo} ...modelOptions, stream: true, }; - const stream = await openai.beta.chat.completions + const stream = await openai.chat.completions .stream(params) .on('abort', () => { /* Do nothing here */ diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js deleted file mode 100644 index d0ffe2ef75..0000000000 --- a/api/app/clients/PluginsClient.js +++ /dev/null @@ -1,542 +0,0 @@ -const OpenAIClient = require('./OpenAIClient'); -const { CallbackManager } = require('@langchain/core/callbacks/manager'); -const { BufferMemory, ChatMessageHistory } = require('langchain/memory'); -const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers'); -const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents'); -const { processFileURL } = require('~/server/services/Files/process'); -const { EModelEndpoint } = require('librechat-data-provider'); -const { checkBalance } = require('~/models/balanceMethods'); -const { formatLangChainMessages } = require('./prompts'); -const { extractBaseURL } = require('~/utils'); -const { loadTools } = require('./tools/util'); -const { logger } = require('~/config'); - -class PluginsClient extends OpenAIClient { - constructor(apiKey, options = {}) { - super(apiKey, options); - this.sender = options.sender ?? 
'Assistant'; - this.tools = []; - this.actions = []; - this.setOptions(options); - this.openAIApiKey = this.apiKey; - this.executor = null; - } - - setOptions(options) { - this.agentOptions = { ...options.agentOptions }; - this.functionsAgent = this.agentOptions?.agent === 'functions'; - this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3'); - - super.setOptions(options); - - this.isGpt3 = this.modelOptions?.model?.includes('gpt-3'); - - if (this.options.reverseProxyUrl) { - this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl); - } - } - - getSaveOptions() { - return { - artifacts: this.options.artifacts, - chatGptLabel: this.options.chatGptLabel, - modelLabel: this.options.modelLabel, - promptPrefix: this.options.promptPrefix, - tools: this.options.tools, - ...this.modelOptions, - agentOptions: this.agentOptions, - iconURL: this.options.iconURL, - greeting: this.options.greeting, - spec: this.options.spec, - }; - } - - saveLatestAction(action) { - this.actions.push(action); - } - - getFunctionModelName(input) { - if (/-(?!0314)\d{4}/.test(input)) { - return input; - } else if (input.includes('gpt-3.5-turbo')) { - return 'gpt-3.5-turbo'; - } else if (input.includes('gpt-4')) { - return 'gpt-4'; - } else { - return 'gpt-3.5-turbo'; - } - } - - getBuildMessagesOptions(opts) { - return { - isChatCompletion: true, - promptPrefix: opts.promptPrefix, - abortController: opts.abortController, - }; - } - - async initialize({ user, message, onAgentAction, onChainEnd, signal }) { - const modelOptions = { - modelName: this.agentOptions.model, - temperature: this.agentOptions.temperature, - }; - - const model = this.initializeLLM({ - ...modelOptions, - context: 'plugins', - initialMessageCount: this.currentMessages.length + 1, - }); - - logger.debug( - `[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`, - ); - - // Map Messages to Langchain format - const pastMessages = 
formatLangChainMessages(this.currentMessages.slice(0, -1), { - userName: this.options?.name, - }); - logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length); - - // TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS) - const memory = new BufferMemory({ - llm: model, - chatHistory: new ChatMessageHistory(pastMessages), - }); - - const { loadedTools } = await loadTools({ - user, - model, - tools: this.options.tools, - functions: this.functionsAgent, - options: { - memory, - signal: this.abortController.signal, - openAIApiKey: this.openAIApiKey, - conversationId: this.conversationId, - fileStrategy: this.options.req.app.locals.fileStrategy, - processFileURL, - message, - }, - useSpecs: true, - }); - - if (loadedTools.length === 0) { - return; - } - - this.tools = loadedTools; - - logger.debug('[PluginsClient] Requested Tools', this.options.tools); - logger.debug( - '[PluginsClient] Loaded Tools', - this.tools.map((tool) => tool.name), - ); - - const handleAction = (action, runId, callback = null) => { - this.saveLatestAction(action); - - logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]); - - if (typeof callback === 'function') { - callback(action, runId); - } - }; - - // initialize agent - const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent; - - let customInstructions = (this.options.promptPrefix ?? '').trim(); - if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { - customInstructions = `${customInstructions ?? 
''}\n${this.options.artifactsPrompt}`.trim(); - } - - this.executor = await initializer({ - model, - signal, - pastMessages, - tools: this.tools, - customInstructions, - verbose: this.options.debug, - returnIntermediateSteps: true, - customName: this.options.chatGptLabel, - currentDateString: this.currentDateString, - callbackManager: CallbackManager.fromHandlers({ - async handleAgentAction(action, runId) { - handleAction(action, runId, onAgentAction); - }, - async handleChainEnd(action) { - if (typeof onChainEnd === 'function') { - onChainEnd(action); - } - }, - }), - }); - - logger.debug('[PluginsClient] Loaded agent.'); - } - - async executorCall(message, { signal, stream, onToolStart, onToolEnd }) { - let errorMessage = ''; - const maxAttempts = 1; - - for (let attempts = 1; attempts <= maxAttempts; attempts++) { - const errorInput = buildErrorInput({ - message, - errorMessage, - actions: this.actions, - functionsAgent: this.functionsAgent, - }); - const input = attempts > 1 ? errorInput : message; - - logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`); - - if (errorMessage.length > 0) { - logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input)); - } - - try { - this.result = await this.executor.call({ input, signal }, [ - { - async handleToolStart(...args) { - await onToolStart(...args); - }, - async handleToolEnd(...args) { - await onToolEnd(...args); - }, - async handleLLMEnd(output) { - const { generations } = output; - const { text } = generations[0][0]; - if (text && typeof stream === 'function') { - await stream(text); - } - }, - }, - ]); - break; // Exit the loop if the function call is successful - } catch (err) { - logger.error('[PluginsClient] executorCall error:', err); - if (attempts === maxAttempts) { - const { run } = this.runManager.getRunByConversationId(this.conversationId); - const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`; - this.result.output = run && run.error 
? run.error : defaultOutput; - this.result.errorMessage = run && run.error ? run.error : err.message; - this.result.intermediateSteps = this.actions; - break; - } - } - } - } - - /** - * - * @param {TMessage} responseMessage - * @param {Partial} saveOptions - * @param {string} user - * @returns - */ - async handleResponseMessage(responseMessage, saveOptions, user) { - const { output, errorMessage, ...result } = this.result; - logger.debug('[PluginsClient][handleResponseMessage] Output:', { - output, - errorMessage, - ...result, - }); - const { error } = responseMessage; - if (!error) { - responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage); - responseMessage.completionTokens = this.getTokenCount(responseMessage.text); - } - - // Record usage only when completion is skipped as it is already recorded in the agent phase. - if (!this.agentOptions.skipCompletion && !error) { - await this.recordTokenUsage(responseMessage); - } - - const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); - delete responseMessage.tokenCount; - return { ...responseMessage, ...result, databasePromise }; - } - - async sendMessage(message, opts = {}) { - /** @type {Promise} */ - let userMessagePromise; - /** @type {{ filteredTools: string[], includedTools: string[] }} */ - const { filteredTools = [], includedTools = [] } = this.options.req.app.locals; - - if (includedTools.length > 0) { - const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin)); - this.options.tools = tools; - } else { - const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin)); - this.options.tools = tools; - } - - // If a message is edited, no tools can be used. 
- const completionMode = this.options.tools.length === 0 || opts.isEdited; - if (completionMode) { - this.setOptions(opts); - return super.sendMessage(message, opts); - } - - logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts }); - const { - user, - conversationId, - responseMessageId, - saveOptions, - userMessage, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - } = await this.handleStartMethods(message, opts); - - if (opts.progressCallback) { - opts.onProgress = opts.progressCallback.call(null, { - ...(opts.progressOptions ?? {}), - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - } - - this.currentMessages.push(userMessage); - - let { - prompt: payload, - tokenCountMap, - promptTokens, - } = await this.buildMessages( - this.currentMessages, - userMessage.messageId, - this.getBuildMessagesOptions({ - promptPrefix: null, - abortController: this.abortController, - }), - ); - - if (tokenCountMap) { - logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap }); - if (tokenCountMap[userMessage.messageId]) { - userMessage.tokenCount = tokenCountMap[userMessage.messageId]; - logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount); - } - this.handleTokenCountMap(tokenCountMap); - } - - this.result = {}; - if (payload) { - this.currentMessages = payload; - } - - if (!this.skipSaveUserMessage) { - userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); - if (typeof opts?.getReqData === 'function') { - opts.getReqData({ - userMessagePromise, - }); - } - } - - const balance = this.options.req?.app?.locals?.balance; - if (balance?.enabled) { - await checkBalance({ - req: this.options.req, - res: this.options.res, - txData: { - user: this.user, - tokenType: 'prompt', - amount: promptTokens, - debug: this.options.debug, - model: this.modelOptions.model, - endpoint: EModelEndpoint.openAI, - }, - }); - } - - const responseMessage = { - endpoint: 
EModelEndpoint.gptPlugins, - iconURL: this.options.iconURL, - messageId: responseMessageId, - conversationId, - parentMessageId: userMessage.messageId, - isCreatedByUser: false, - model: this.modelOptions.model, - sender: this.sender, - promptTokens, - }; - - await this.initialize({ - user, - message, - onAgentAction, - onChainEnd, - signal: this.abortController.signal, - onProgress: opts.onProgress, - }); - - // const stream = async (text) => { - // await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 }); - // }; - await this.executorCall(message, { - signal: this.abortController.signal, - // stream, - onToolStart, - onToolEnd, - }); - - // If message was aborted mid-generation - if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) { - responseMessage.text = 'Cancelled.'; - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - // If error occurred during generation (likely token_balance) - if (this.result?.errorMessage?.length > 0) { - responseMessage.error = true; - responseMessage.text = this.result.output; - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) { - const partialText = opts.getPartialText(); - const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', ''); - responseMessage.text = - trimmedPartial.length === 0 ? 
`${partialText}${this.result.output}` : partialText; - addImages(this.result.intermediateSteps, responseMessage); - await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 }); - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - if (this.agentOptions.skipCompletion && this.result.output) { - responseMessage.text = this.result.output; - addImages(this.result.intermediateSteps, responseMessage); - await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 }); - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - logger.debug('[PluginsClient] Completion phase: this.result', this.result); - - const promptPrefix = buildPromptPrefix({ - result: this.result, - message, - functionsAgent: this.functionsAgent, - }); - - logger.debug('[PluginsClient]', { promptPrefix }); - - payload = await this.buildCompletionPrompt({ - messages: this.currentMessages, - promptPrefix, - }); - - logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload); - responseMessage.text = await this.sendCompletion(payload, opts); - return await this.handleResponseMessage(responseMessage, saveOptions, user); - } - - async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) { - logger.debug('[PluginsClient] buildCompletionPrompt messages', messages); - - const orderedMessages = messages; - let promptPrefix = _promptPrefix.trim(); - // If the prompt prefix doesn't end with the end token, add it. - if (!promptPrefix.endsWith(`${this.endToken}`)) { - promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`; - } - promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`; - const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 
'Assistant'}:\n`; - - const instructionsPayload = { - role: 'system', - content: promptPrefix, - }; - - const messagePayload = { - role: 'system', - content: promptSuffix, - }; - - if (this.isGpt3) { - instructionsPayload.role = 'user'; - messagePayload.role = 'user'; - instructionsPayload.content += `\n${promptSuffix}`; - } - - // testing if this works with browser endpoint - if (!this.isGpt3 && this.options.reverseProxyUrl) { - instructionsPayload.role = 'user'; - } - - let currentTokenCount = - this.getTokenCountForMessage(instructionsPayload) + - this.getTokenCountForMessage(messagePayload); - - let promptBody = ''; - const maxTokenCount = this.maxPromptTokens; - // Iterate backwards through the messages, adding them to the prompt until we reach the max token count. - // Do this within a recursive async function so that it doesn't block the event loop for too long. - const buildPromptBody = async () => { - if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) { - const message = orderedMessages.pop(); - const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user'; - const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel; - let messageString = `${this.startToken}${roleLabel}:\n${ - message.text ?? message.content ?? '' - }${this.endToken}\n`; - let newPromptBody = `${messageString}${promptBody}`; - - const tokenCountForMessage = this.getTokenCount(messageString); - const newTokenCount = currentTokenCount + tokenCountForMessage; - if (newTokenCount > maxTokenCount) { - if (promptBody) { - // This message would put us over the token limit, so don't add it. - return false; - } - // This is the first message, so we can't add it. Just throw an error. - throw new Error( - `Prompt is too long. 
Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`, - ); - } - promptBody = newPromptBody; - currentTokenCount = newTokenCount; - // wait for next tick to avoid blocking the event loop - await new Promise((resolve) => setTimeout(resolve, 0)); - return buildPromptBody(); - } - return true; - }; - - await buildPromptBody(); - const prompt = promptBody; - messagePayload.content = prompt; - // Add 2 tokens for metadata after all messages have been counted. - currentTokenCount += 2; - - if (this.isGpt3 && messagePayload.content.length > 0) { - const context = 'Chat History:\n'; - messagePayload.content = `${context}${prompt}`; - currentTokenCount += this.getTokenCount(context); - } - - // Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response. - this.modelOptions.max_tokens = Math.min( - this.maxContextTokens - currentTokenCount, - this.maxResponseTokens, - ); - - if (this.isGpt3) { - messagePayload.content += promptSuffix; - return [instructionsPayload, messagePayload]; - } - - const result = [messagePayload, instructionsPayload]; - - if (this.functionsAgent && !this.isGpt3) { - result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`; - } - - return result.filter((message) => message.content.length > 0); - } -} - -module.exports = PluginsClient; diff --git a/api/app/clients/generators.js b/api/app/clients/generators.js deleted file mode 100644 index 9814cac7a5..0000000000 --- a/api/app/clients/generators.js +++ /dev/null @@ -1,71 +0,0 @@ -const fetch = require('node-fetch'); -const { GraphEvents } = require('@librechat/agents'); -const { logger, sendEvent } = require('~/config'); -const { sleep } = require('~/server/utils'); - -/** - * Makes a function to make HTTP request and logs the process. 
- * @param {Object} params - * @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint. - * @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request. - * @returns {Promise} - A promise that resolves to the response of the fetch request. - */ -function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) { - /** - * Makes an HTTP request and logs the process. - * @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object. - * @param {RequestInit} [init] - Optional init options for the request. - * @returns {Promise} - A promise that resolves to the response of the fetch request. - */ - return async (_url, init) => { - let url = _url; - if (directEndpoint) { - url = reverseProxyUrl; - } - logger.debug(`Making request to ${url}`); - if (typeof Bun !== 'undefined') { - return await fetch(url, init); - } - return await fetch(url, init); - }; -} - -// Add this at the module level outside the class -/** - * Creates event handlers for stream events that don't capture client references - * @param {Object} res - The response object to send events to - * @returns {Object} Object containing handler functions - */ -function createStreamEventHandlers(res) { - return { - [GraphEvents.ON_RUN_STEP]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - [GraphEvents.ON_MESSAGE_DELTA]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - [GraphEvents.ON_REASONING_DELTA]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - }; -} - -function createHandleLLMNewToken(streamRate) { - return async () => { - if (streamRate) { - await sleep(streamRate); - } - }; -} - -module.exports = { - createFetch, - createHandleLLMNewToken, - createStreamEventHandlers, -}; diff --git a/api/app/clients/index.js b/api/app/clients/index.js index a5e8eee504..d8b2bae27b 100644 --- a/api/app/clients/index.js +++ b/api/app/clients/index.js @@ -1,15 +1,11 @@ -const ChatGPTClient = 
require('./ChatGPTClient'); const OpenAIClient = require('./OpenAIClient'); -const PluginsClient = require('./PluginsClient'); const GoogleClient = require('./GoogleClient'); const TextStream = require('./TextStream'); const AnthropicClient = require('./AnthropicClient'); const toolUtils = require('./tools/util'); module.exports = { - ChatGPTClient, OpenAIClient, - PluginsClient, GoogleClient, TextStream, AnthropicClient, diff --git a/api/app/clients/llm/createLLM.js b/api/app/clients/llm/createLLM.js index c8d6666bce..846c4d8e9c 100644 --- a/api/app/clients/llm/createLLM.js +++ b/api/app/clients/llm/createLLM.js @@ -1,6 +1,5 @@ const { ChatOpenAI } = require('@langchain/openai'); -const { sanitizeModelName, constructAzureURL } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); +const { isEnabled, sanitizeModelName, constructAzureURL } = require('@librechat/api'); /** * Creates a new instance of a language model (LLM) for chat interactions. diff --git a/api/app/clients/prompts/createContextHandlers.js b/api/app/clients/prompts/createContextHandlers.js index 4dcfaf68e4..b3ea9164e7 100644 --- a/api/app/clients/prompts/createContextHandlers.js +++ b/api/app/clients/prompts/createContextHandlers.js @@ -1,6 +1,7 @@ const axios = require('axios'); -const { isEnabled } = require('~/server/utils'); -const { logger } = require('~/config'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); const footer = `Use the context as your learned knowledge to better answer the user. 
@@ -18,7 +19,7 @@ function createContextHandlers(req, userMessageContent) { const queryPromises = []; const processedFiles = []; const processedIds = new Set(); - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT); const query = async (file) => { @@ -96,35 +97,35 @@ function createContextHandlers(req, userMessageContent) { resolvedQueries.length === 0 ? '\n\tThe semantic search did not return any results.' : resolvedQueries - .map((queryResult, index) => { - const file = processedFiles[index]; - let contextItems = queryResult.data; + .map((queryResult, index) => { + const file = processedFiles[index]; + let contextItems = queryResult.data; - const generateContext = (currentContext) => - ` + const generateContext = (currentContext) => + ` ${file.filename} ${currentContext} `; - if (useFullContext) { - return generateContext(`\n${contextItems}`); - } + if (useFullContext) { + return generateContext(`\n${contextItems}`); + } - contextItems = queryResult.data - .map((item) => { - const pageContent = item[0].page_content; - return ` + contextItems = queryResult.data + .map((item) => { + const pageContent = item[0].page_content; + return ` `; - }) - .join(''); + }) + .join(''); - return generateContext(contextItems); - }) - .join(''); + return generateContext(contextItems); + }) + .join(''); if (useFullContext) { const prompt = `${header} diff --git a/api/app/clients/specs/AnthropicClient.test.js b/api/app/clients/specs/AnthropicClient.test.js index 9867859087..fbcd2b75e4 100644 --- a/api/app/clients/specs/AnthropicClient.test.js +++ b/api/app/clients/specs/AnthropicClient.test.js @@ -309,7 +309,7 @@ describe('AnthropicClient', () => { }; client.setOptions({ modelOptions, promptCache: true }); const anthropicClient = client.getClient(modelOptions); - expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta'); + 
expect(anthropicClient._options.defaultHeaders).toBeUndefined(); }); it('should not add beta header for other models', () => { @@ -320,7 +320,7 @@ describe('AnthropicClient', () => { }, }); const anthropicClient = client.getClient(); - expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta'); + expect(anthropicClient._options.defaultHeaders).toBeUndefined(); }); }); diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index d620d5f647..6d44915804 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -1,7 +1,7 @@ const { Constants } = require('librechat-data-provider'); const { initializeFakeClient } = require('./FakeClient'); -jest.mock('~/lib/db/connectDb'); +jest.mock('~/db/connect'); jest.mock('~/models', () => ({ User: jest.fn(), Key: jest.fn(), @@ -33,7 +33,9 @@ jest.mock('~/models', () => ({ const { getConvo, saveConvo } = require('~/models'); jest.mock('@librechat/agents', () => { + const { Providers } = jest.requireActual('@librechat/agents'); return { + Providers, ChatOpenAI: jest.fn().mockImplementation(() => { return {}; }), @@ -52,7 +54,7 @@ const messageHistory = [ { role: 'user', isCreatedByUser: true, - text: 'What\'s up', + text: "What's up", messageId: '3', parentMessageId: '2', }, @@ -456,7 +458,7 @@ describe('BaseClient', () => { const chatMessages2 = await TestClient.loadHistory(conversationId, '3'); expect(TestClient.currentMessages).toHaveLength(3); - expect(chatMessages2[chatMessages2.length - 1].text).toEqual('What\'s up'); + expect(chatMessages2[chatMessages2.length - 1].text).toEqual("What's up"); }); /* Most of the new sendMessage logic revolving around edited/continued AI messages diff --git a/api/app/clients/specs/OpenAIClient.test.js b/api/app/clients/specs/OpenAIClient.test.js index 579f636eef..efca66a867 100644 --- a/api/app/clients/specs/OpenAIClient.test.js +++ b/api/app/clients/specs/OpenAIClient.test.js @@ -5,7 +5,7 @@ 
const getLogStores = require('~/cache/getLogStores'); const OpenAIClient = require('../OpenAIClient'); jest.mock('meilisearch'); -jest.mock('~/lib/db/connectDb'); +jest.mock('~/db/connect'); jest.mock('~/models', () => ({ User: jest.fn(), Key: jest.fn(), @@ -462,17 +462,17 @@ describe('OpenAIClient', () => { role: 'system', name: 'example_user', content: - 'Let\'s circle back when we have more bandwidth to touch base on opportunities for increased leverage.', + "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.", }, { role: 'system', name: 'example_assistant', - content: 'Let\'s talk later when we\'re less busy about how to do better.', + content: "Let's talk later when we're less busy about how to do better.", }, { role: 'user', content: - 'This late pivot means we don\'t have time to boil the ocean for the client deliverable.', + "This late pivot means we don't have time to boil the ocean for the client deliverable.", }, ]; @@ -531,44 +531,6 @@ describe('OpenAIClient', () => { }); }); - describe('sendMessage/getCompletion/chatCompletion', () => { - afterEach(() => { - delete process.env.AZURE_OPENAI_DEFAULT_MODEL; - delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME; - }); - - it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => { - const model = 'text-davinci-003'; - const onProgress = jest.fn().mockImplementation(() => ({})); - - const testClient = new OpenAIClient('test-api-key', { - ...defaultOptions, - modelOptions: { model }, - }); - - const getCompletion = jest.spyOn(testClient, 'getCompletion'); - await testClient.sendMessage('Hi mom!', { onProgress }); - - expect(getCompletion).toHaveBeenCalled(); - expect(getCompletion.mock.calls.length).toBe(1); - - expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n'); - - expect(fetchEventSource).toHaveBeenCalled(); - expect(fetchEventSource.mock.calls.length).toBe(1); - - // Check if the 
first argument (url) is correct - const firstCallArgs = fetchEventSource.mock.calls[0]; - - const expectedURL = 'https://api.openai.com/v1/completions'; - expect(firstCallArgs[0]).toBe(expectedURL); - - const requestBody = JSON.parse(firstCallArgs[1].body); - expect(requestBody).toHaveProperty('model'); - expect(requestBody.model).toBe(model); - }); - }); - describe('checkVisionRequest functionality', () => { let client; const attachments = [{ type: 'image/png' }]; diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js deleted file mode 100644 index fd7bee5043..0000000000 --- a/api/app/clients/specs/PluginsClient.test.js +++ /dev/null @@ -1,314 +0,0 @@ -const crypto = require('crypto'); -const { Constants } = require('librechat-data-provider'); -const { HumanMessage, AIMessage } = require('@langchain/core/messages'); -const PluginsClient = require('../PluginsClient'); - -jest.mock('~/lib/db/connectDb'); -jest.mock('~/models/Conversation', () => { - return function () { - return { - save: jest.fn(), - deleteConvos: jest.fn(), - }; - }; -}); - -const defaultAzureOptions = { - azureOpenAIApiInstanceName: 'your-instance-name', - azureOpenAIApiDeploymentName: 'your-deployment-name', - azureOpenAIApiVersion: '2020-07-01-preview', -}; - -describe('PluginsClient', () => { - let TestAgent; - let options = { - tools: [], - modelOptions: { - model: 'gpt-3.5-turbo', - temperature: 0, - max_tokens: 2, - }, - agentOptions: { - model: 'gpt-3.5-turbo', - }, - }; - let parentMessageId; - let conversationId; - const fakeMessages = []; - const userMessage = 'Hello, ChatGPT!'; - const apiKey = 'fake-api-key'; - - beforeEach(() => { - TestAgent = new PluginsClient(apiKey, options); - TestAgent.loadHistory = jest - .fn() - .mockImplementation((conversationId, parentMessageId = null) => { - if (!conversationId) { - TestAgent.currentMessages = []; - return Promise.resolve([]); - } - - const orderedMessages = 
TestAgent.constructor.getMessagesForConversation({ - messages: fakeMessages, - parentMessageId, - }); - - const chatMessages = orderedMessages.map((msg) => - msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user' - ? new HumanMessage(msg.text) - : new AIMessage(msg.text), - ); - - TestAgent.currentMessages = orderedMessages; - return Promise.resolve(chatMessages); - }); - TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => { - if (opts && typeof opts === 'object') { - TestAgent.setOptions(opts); - } - const conversationId = opts.conversationId || crypto.randomUUID(); - const parentMessageId = opts.parentMessageId || Constants.NO_PARENT; - const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); - this.pastMessages = await TestAgent.loadHistory( - conversationId, - TestAgent.options?.parentMessageId, - ); - - const userMessage = { - text: message, - sender: 'ChatGPT', - isCreatedByUser: true, - messageId: userMessageId, - parentMessageId, - conversationId, - }; - - const response = { - sender: 'ChatGPT', - text: 'Hello, User!', - isCreatedByUser: false, - messageId: crypto.randomUUID(), - parentMessageId: userMessage.messageId, - conversationId, - }; - - fakeMessages.push(userMessage); - fakeMessages.push(response); - return response; - }); - }); - - test('initializes PluginsClient without crashing', () => { - expect(TestAgent).toBeInstanceOf(PluginsClient); - }); - - test('check setOptions function', () => { - expect(TestAgent.agentIsGpt3).toBe(true); - }); - - describe('sendMessage', () => { - test('sendMessage should return a response message', async () => { - const expectedResult = expect.objectContaining({ - sender: 'ChatGPT', - text: expect.any(String), - isCreatedByUser: false, - messageId: expect.any(String), - parentMessageId: expect.any(String), - conversationId: expect.any(String), - }); - - const response = await TestAgent.sendMessage(userMessage); - parentMessageId = response.messageId; - 
conversationId = response.conversationId; - expect(response).toEqual(expectedResult); - }); - - test('sendMessage should work with provided conversationId and parentMessageId', async () => { - const userMessage = 'Second message in the conversation'; - const opts = { - conversationId, - parentMessageId, - }; - - const expectedResult = expect.objectContaining({ - sender: 'ChatGPT', - text: expect.any(String), - isCreatedByUser: false, - messageId: expect.any(String), - parentMessageId: expect.any(String), - conversationId: opts.conversationId, - }); - - const response = await TestAgent.sendMessage(userMessage, opts); - parentMessageId = response.messageId; - expect(response.conversationId).toEqual(conversationId); - expect(response).toEqual(expectedResult); - }); - - test('should return chat history', async () => { - const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId); - expect(TestAgent.currentMessages).toHaveLength(4); - expect(chatMessages[0].text).toEqual(userMessage); - }); - }); - - describe('getFunctionModelName', () => { - let client; - - beforeEach(() => { - client = new PluginsClient('dummy_api_key'); - }); - - test('should return the input when it includes a dash followed by four digits', () => { - expect(client.getFunctionModelName('-1234')).toBe('-1234'); - expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview'); - }); - - test('should return the input for all function-capable models (`0613` models and above)', () => { - expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613'); - expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613'); - expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613'); - expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613'); - expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106'); - 
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview'); - expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106'); - }); - - test('should return the corresponding model if input is non-function capable (`0314` models)', () => { - expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4'); - expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4'); - expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo'); - expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo'); - }); - - test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => { - expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo'); - }); - - test('should return "gpt-4" when the input includes "gpt-4"', () => { - expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4'); - }); - - test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => { - expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo'); - expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo'); - }); - }); - - describe('Azure OpenAI tests specific to Plugins', () => { - // TODO: add more tests for Azure OpenAI integration with Plugins - // let client; - // beforeEach(() => { - // client = new PluginsClient('dummy_api_key'); - // }); - - test('should not call getFunctionModelName when azure options are set', () => { - const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName'); - const model = 'gpt-4-turbo'; - - // note, without the azure change in PR #1766, `getFunctionModelName` is called twice - const testClient = new PluginsClient('dummy_api_key', { - agentOptions: { - model, - agent: 'functions', - }, - azure: defaultAzureOptions, - }); - - expect(spy).not.toHaveBeenCalled(); - expect(testClient.agentOptions.model).toBe(model); - - spy.mockRestore(); - }); - }); - - describe('sendMessage with 
filtered tools', () => { - let TestAgent; - const apiKey = 'fake-api-key'; - const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }]; - - beforeEach(() => { - TestAgent = new PluginsClient(apiKey, { - tools: mockTools, - modelOptions: { - model: 'gpt-3.5-turbo', - temperature: 0, - max_tokens: 2, - }, - agentOptions: { - model: 'gpt-3.5-turbo', - }, - }); - - TestAgent.options.req = { - app: { - locals: {}, - }, - }; - - TestAgent.sendMessage = jest.fn().mockImplementation(async () => { - const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals; - - if (includedTools.length > 0) { - const tools = TestAgent.options.tools.filter((plugin) => - includedTools.includes(plugin.name), - ); - TestAgent.options.tools = tools; - } else { - const tools = TestAgent.options.tools.filter( - (plugin) => !filteredTools.includes(plugin.name), - ); - TestAgent.options.tools = tools; - } - - return { - text: 'Mocked response', - tools: TestAgent.options.tools, - }; - }); - }); - - test('should filter out tools when filteredTools is provided', async () => { - TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3']; - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(2); - expect(response.tools).toEqual( - expect.arrayContaining([ - expect.objectContaining({ name: 'tool2' }), - expect.objectContaining({ name: 'tool4' }), - ]), - ); - }); - - test('should only include specified tools when includedTools is provided', async () => { - TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4']; - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(2); - expect(response.tools).toEqual( - expect.arrayContaining([ - expect.objectContaining({ name: 'tool2' }), - expect.objectContaining({ name: 'tool4' }), - ]), - ); - }); - - test('should prioritize includedTools over filteredTools', async () => { - 
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3']; - TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2']; - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(2); - expect(response.tools).toEqual( - expect.arrayContaining([ - expect.objectContaining({ name: 'tool1' }), - expect.objectContaining({ name: 'tool2' }), - ]), - ); - }); - - test('should not modify tools when no filters are provided', async () => { - const response = await TestAgent.sendMessage('Test message'); - expect(response.tools).toHaveLength(4); - expect(response.tools).toEqual(expect.arrayContaining(mockTools)); - }); - }); -}); diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.js deleted file mode 100644 index acc3a64d32..0000000000 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.js +++ /dev/null @@ -1,184 +0,0 @@ -require('dotenv').config(); -const fs = require('fs'); -const { z } = require('zod'); -const path = require('path'); -const yaml = require('js-yaml'); -const { createOpenAPIChain } = require('langchain/chains'); -const { DynamicStructuredTool } = require('@langchain/core/tools'); -const { ChatPromptTemplate, HumanMessagePromptTemplate } = require('@langchain/core/prompts'); -const { logger } = require('~/config'); - -function addLinePrefix(text, prefix = '// ') { - return text - .split('\n') - .map((line) => prefix + line) - .join('\n'); -} - -function createPrompt(name, functions) { - const prefix = `// The ${name} tool has the following functions. Determine the desired or most optimal function for the user's query:`; - const functionDescriptions = functions - .map((func) => `// - ${func.name}: ${func.description}`) - .join('\n'); - return `${prefix}\n${functionDescriptions} -// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function. 
-// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}`; -} - -const AuthBearer = z - .object({ - type: z.string().includes('service_http'), - authorization_type: z.string().includes('bearer'), - verification_tokens: z.object({ - openai: z.string(), - }), - }) - .catch(() => false); - -const AuthDefinition = z - .object({ - type: z.string(), - authorization_type: z.string(), - verification_tokens: z.object({ - openai: z.string(), - }), - }) - .catch(() => false); - -async function readSpecFile(filePath) { - try { - const fileContents = await fs.promises.readFile(filePath, 'utf8'); - if (path.extname(filePath) === '.json') { - return JSON.parse(fileContents); - } - return yaml.load(fileContents); - } catch (e) { - logger.error('[readSpecFile] error', e); - return false; - } -} - -async function getSpec(url) { - const RegularUrl = z - .string() - .url() - .catch(() => false); - - if (RegularUrl.parse(url) && path.extname(url) === '.json') { - const response = await fetch(url); - return await response.json(); - } - - const ValidSpecPath = z - .string() - .url() - .catch(async () => { - const spec = path.join(__dirname, '..', '.well-known', 'openapi', url); - if (!fs.existsSync(spec)) { - return false; - } - - return await readSpecFile(spec); - }); - - return ValidSpecPath.parse(url); -} - -async function createOpenAPIPlugin({ data, llm, user, message, memory, signal }) { - let spec; - try { - spec = await getSpec(data.api.url); - } catch (error) { - logger.error('[createOpenAPIPlugin] getSpec error', error); - return null; - } - - if (!spec) { - logger.warn('[createOpenAPIPlugin] No spec found'); - return null; - } - - const headers = {}; - const { auth, name_for_model, description_for_model, description_for_human } = data; - if (auth && AuthDefinition.parse(auth)) { - logger.debug('[createOpenAPIPlugin] auth detected', auth); - const { openai } = auth.verification_tokens; - if (AuthBearer.parse(auth)) { - headers.authorization 
= `Bearer ${openai}`; - logger.debug('[createOpenAPIPlugin] added auth bearer', headers); - } - } - - const chainOptions = { llm }; - - if (data.headers && data.headers['librechat_user_id']) { - logger.debug('[createOpenAPIPlugin] id detected', headers); - headers[data.headers['librechat_user_id']] = user; - } - - if (Object.keys(headers).length > 0) { - logger.debug('[createOpenAPIPlugin] headers detected', headers); - chainOptions.headers = headers; - } - - if (data.params) { - logger.debug('[createOpenAPIPlugin] params detected', data.params); - chainOptions.params = data.params; - } - - let history = ''; - if (memory) { - logger.debug('[createOpenAPIPlugin] openAPI chain: memory detected', memory); - const { history: chat_history } = await memory.loadMemoryVariables({}); - history = chat_history?.length > 0 ? `\n\n## Chat History:\n${chat_history}\n` : ''; - } - - chainOptions.prompt = ChatPromptTemplate.fromMessages([ - HumanMessagePromptTemplate.fromTemplate( - `# Use the provided API's to respond to this query:\n\n{query}\n\n## Instructions:\n${addLinePrefix( - description_for_model, - )}${history}`, - ), - ]); - - const chain = await createOpenAPIChain(spec, chainOptions); - - const { functions } = chain.chains[0].lc_kwargs.llmKwargs; - - return new DynamicStructuredTool({ - name: name_for_model, - description_for_model: `${addLinePrefix(description_for_human)}${createPrompt( - name_for_model, - functions, - )}`, - description: `${description_for_human}`, - schema: z.object({ - func: z - .string() - .describe( - `The function to invoke. 
The functions available are: ${functions - .map((func) => func.name) - .join(', ')}`, - ), - intent: z - .string() - .describe('Describe your intent with the function and your expected result'), - }), - func: async ({ func = '', intent = '' }) => { - const filteredFunctions = functions.filter((f) => f.name === func); - chain.chains[0].lc_kwargs.llmKwargs.functions = filteredFunctions; - const query = `${message}${func?.length > 0 ? `\n// Intent: ${intent}` : ''}`; - const result = await chain.call({ - query, - signal, - }); - return result.response; - }, - }); -} - -module.exports = { - getSpec, - readSpecFile, - createOpenAPIPlugin, -}; diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js deleted file mode 100644 index 83bc5e9397..0000000000 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js +++ /dev/null @@ -1,72 +0,0 @@ -const fs = require('fs'); -const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin'); - -global.fetch = jest.fn().mockImplementationOnce(() => { - return new Promise((resolve) => { - resolve({ - ok: true, - json: () => Promise.resolve({ key: 'value' }), - }); - }); -}); -jest.mock('fs', () => ({ - promises: { - readFile: jest.fn(), - }, - existsSync: jest.fn(), -})); - -describe('readSpecFile', () => { - it('reads JSON file correctly', async () => { - fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' })); - const result = await readSpecFile('test.json'); - expect(result).toEqual({ test: 'value' }); - }); - - it('reads YAML file correctly', async () => { - fs.promises.readFile.mockResolvedValue('test: value'); - const result = await readSpecFile('test.yaml'); - expect(result).toEqual({ test: 'value' }); - }); - - it('handles error correctly', async () => { - fs.promises.readFile.mockRejectedValue(new Error('test error')); - const result = await readSpecFile('test.json'); - expect(result).toBe(false); - }); -}); - 
-describe('getSpec', () => { - it('fetches spec from url correctly', async () => { - const parsedJson = await getSpec('https://www.instacart.com/.well-known/ai-plugin.json'); - const isObject = typeof parsedJson === 'object'; - expect(isObject).toEqual(true); - }); - - it('reads spec from file correctly', async () => { - fs.existsSync.mockReturnValue(true); - fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' })); - const result = await getSpec('test.json'); - expect(result).toEqual({ test: 'value' }); - }); - - it('returns false when file does not exist', async () => { - fs.existsSync.mockReturnValue(false); - const result = await getSpec('test.json'); - expect(result).toBe(false); - }); -}); - -describe('createOpenAPIPlugin', () => { - it('returns null when getSpec throws an error', async () => { - const result = await createOpenAPIPlugin({ data: { api: { url: 'invalid' } } }); - expect(result).toBe(null); - }); - - it('returns null when no spec is found', async () => { - const result = await createOpenAPIPlugin({}); - expect(result).toBe(null); - }); - - // Add more tests here for different scenarios -}); diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index fc0f1851f6..7c2a56fe71 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -8,10 +8,10 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); const extractBaseURL = require('~/utils/extractBaseURL'); -const { logger } = require('~/config'); +const logger = require('~/config/winston'); const displayMessage = - 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. 
The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + "DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user."; class DALLE3 extends Tool { constructor(fields = {}) { super(); diff --git a/api/app/clients/tools/structured/OpenAIImageTools.js b/api/app/clients/tools/structured/OpenAIImageTools.js index afea9dfd55..411db1edf9 100644 --- a/api/app/clients/tools/structured/OpenAIImageTools.js +++ b/api/app/clients/tools/structured/OpenAIImageTools.js @@ -4,12 +4,13 @@ const { v4 } = require('uuid'); const OpenAI = require('openai'); const FormData = require('form-data'); const { tool } = require('@langchain/core/tools'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { ContentTypes, EImageOutputType } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { logAxiosError, extractBaseURL } = require('~/utils'); +const { extractBaseURL } = require('~/utils'); const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); /** Default descriptions for image generation tool */ const DEFAULT_IMAGE_GEN_DESCRIPTION = ` @@ -64,7 +65,7 @@ const DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION = `Describe the changes, enhancement Always base this prompt on the most recently uploaded reference images.`; const displayMessage = - 'The tool displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. 
The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + "The tool displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user."; /** * Replaces unwanted characters from the input string @@ -106,6 +107,12 @@ const getImageEditPromptDescription = () => { return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION; }; +function createAbortHandler() { + return function () { + logger.debug('[ImageGenOAI] Image generation aborted'); + }; +} + /** * Creates OpenAI Image tools (generation and editing) * @param {Object} fields - Configuration fields @@ -200,10 +207,18 @@ function createOpenAIImageTools(fields = {}) { } let resp; + /** @type {AbortSignal} */ + let derivedSignal = null; + /** @type {() => void} */ + let abortHandler = null; + try { - const derivedSignal = runnableConfig?.signal - ? AbortSignal.any([runnableConfig.signal]) - : undefined; + if (runnableConfig?.signal) { + derivedSignal = AbortSignal.any([runnableConfig.signal]); + abortHandler = createAbortHandler(); + derivedSignal.addEventListener('abort', abortHandler, { once: true }); + } + resp = await openai.images.generate( { model: 'gpt-image-1', @@ -227,6 +242,10 @@ function createOpenAIImageTools(fields = {}) { logAxiosError({ error, message }); return returnValue(`Something went wrong when trying to generate the image. 
The OpenAI API may be unavailable: Error Message: ${error.message}`); + } finally { + if (abortHandler && derivedSignal) { + derivedSignal.removeEventListener('abort', abortHandler); + } } if (!resp) { @@ -408,10 +427,17 @@ Error Message: ${error.message}`); headers['Authorization'] = `Bearer ${apiKey}`; } + /** @type {AbortSignal} */ + let derivedSignal = null; + /** @type {() => void} */ + let abortHandler = null; + try { - const derivedSignal = runnableConfig?.signal - ? AbortSignal.any([runnableConfig.signal]) - : undefined; + if (runnableConfig?.signal) { + derivedSignal = AbortSignal.any([runnableConfig.signal]); + abortHandler = createAbortHandler(); + derivedSignal.addEventListener('abort', abortHandler, { once: true }); + } /** @type {import('axios').AxiosRequestConfig} */ const axiosConfig = { @@ -466,6 +492,10 @@ Error Message: ${error.message}`); logAxiosError({ error, message }); return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable: Error Message: ${error.message || 'Unknown error'}`); + } finally { + if (abortHandler && derivedSignal) { + derivedSignal.removeEventListener('abort', abortHandler); + } } }, { diff --git a/api/app/clients/tools/structured/specs/DALLE3.spec.js b/api/app/clients/tools/structured/specs/DALLE3.spec.js index 1b28de2faf..2def575fb3 100644 --- a/api/app/clients/tools/structured/specs/DALLE3.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3.spec.js @@ -1,10 +1,29 @@ const OpenAI = require('openai'); const DALLE3 = require('../DALLE3'); - -const { logger } = require('~/config'); +const logger = require('~/config/winston'); jest.mock('openai'); +jest.mock('@librechat/data-schemas', () => { + return { + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + }; +}); + +jest.mock('tiktoken', () => { + return { + encoding_for_model: jest.fn().mockReturnValue({ + encode: jest.fn(), + decode: jest.fn(), + }), + }; +}); + const 
processFileURL = jest.fn(); jest.mock('~/server/services/Files/images', () => ({ @@ -37,6 +56,11 @@ jest.mock('fs', () => { return { existsSync: jest.fn(), mkdirSync: jest.fn(), + promises: { + writeFile: jest.fn(), + readFile: jest.fn(), + unlink: jest.fn(), + }, }; }); diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 54da483362..050a0fd896 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -1,9 +1,10 @@ const { z } = require('zod'); const axios = require('axios'); const { tool } = require('@langchain/core/tools'); +const { logger } = require('@librechat/data-schemas'); const { Tools, EToolResources } = require('librechat-data-provider'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); /** * @@ -59,7 +60,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => { if (files.length === 0) { return 'No files to search. Instruct the user to add files for the search.'; } - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); if (!jwtToken) { return 'There was an error authenticating the file search request.'; } @@ -135,7 +136,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => { query: z .string() .describe( - 'A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you\'re looking for. The query will be used for semantic similarity matching against the file contents.', + "A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. 
The query will be used for semantic similarity matching against the file contents.", ), }), }, diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 51f0c87ef9..c233c0f762 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -1,14 +1,14 @@ +const { mcpToolPattern } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { SerpAPI } = require('@langchain/community/tools/serpapi'); const { Calculator } = require('@langchain/community/tools/calculator'); const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents'); const { Tools, - Constants, EToolResources, loadWebSearchAuth, replaceSpecialVars, } = require('librechat-data-provider'); -const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { availableTools, manifestToolMap, @@ -28,11 +28,10 @@ const { } = require('../'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { getCachedTools } = require('~/server/services/Config'); const { createMCPTool } = require('~/server/services/MCP'); -const { logger } = require('~/config'); - -const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`); /** * Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values. 
@@ -93,7 +92,7 @@ const validateTools = async (user, tools = []) => { return Array.from(validToolsSet.values()); } catch (err) { logger.error('[validateTools] There was a problem validating tools', err); - throw new Error('There was a problem validating tools'); + throw new Error(err); } }; @@ -236,7 +235,7 @@ const loadTools = async ({ /** @type {Record} */ const toolContextMap = {}; - const appTools = options.req?.app?.locals?.availableTools ?? {}; + const appTools = (await getCachedTools({ includeGlobal: true })) ?? {}; for (const tool of tools) { if (tool === Tools.execute_code) { @@ -299,6 +298,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} requestedTools[tool] = async () => createMCPTool({ req: options.req, + res: options.res, toolKey: tool, model: agent?.model ?? model, provider: agent?.provider ?? endpoint, diff --git a/api/app/clients/tools/util/handleTools.test.js b/api/app/clients/tools/util/handleTools.test.js index 6538ce9aa4..1cacda8159 100644 --- a/api/app/clients/tools/util/handleTools.test.js +++ b/api/app/clients/tools/util/handleTools.test.js @@ -1,8 +1,5 @@ -const mockUser = { - _id: 'fakeId', - save: jest.fn(), - findByIdAndDelete: jest.fn(), -}; +const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const mockPluginService = { updateUserPluginAuth: jest.fn(), @@ -10,23 +7,18 @@ const mockPluginService = { getUserPluginAuthValue: jest.fn(), }; -jest.mock('~/models/User', () => { - return function () { - return mockUser; - }; -}); - jest.mock('~/server/services/PluginService', () => mockPluginService); const { BaseLLM } = require('@langchain/openai'); const { Calculator } = require('@langchain/community/tools/calculator'); -const User = require('~/models/User'); +const { User } = require('~/db/models'); const PluginService = require('~/server/services/PluginService'); const { validateTools, loadTools, loadToolWithAuth } = require('./handleTools'); const { 
StructuredSD, availableTools, DALLE3 } = require('../'); describe('Tool Handlers', () => { + let mongoServer; let fakeUser; const pluginKey = 'dalle'; const pluginKey2 = 'wolfram'; @@ -37,7 +29,9 @@ describe('Tool Handlers', () => { const authConfigs = mainPlugin.authConfig; beforeAll(async () => { - mockUser.save.mockResolvedValue(undefined); + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); const userAuthValues = {}; mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => { @@ -78,9 +72,36 @@ describe('Tool Handlers', () => { }); afterAll(async () => { - await mockUser.findByIdAndDelete(fakeUser._id); + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + // Clear mocks but not the database since we need the user to persist + jest.clearAllMocks(); + + // Reset the mock implementations + const userAuthValues = {}; + mockPluginService.getUserPluginAuthValue.mockImplementation((userId, authField) => { + return userAuthValues[`${userId}-${authField}`]; + }); + mockPluginService.updateUserPluginAuth.mockImplementation( + (userId, authField, _pluginKey, credential) => { + const fields = authField.split('||'); + fields.forEach((field) => { + userAuthValues[`${userId}-${field}`] = credential; + }); + }, + ); + + // Re-add the auth configs for the user for (const authConfig of authConfigs) { - await PluginService.deleteUserPluginAuth(fakeUser._id, authConfig.authField); + await PluginService.updateUserPluginAuth( + fakeUser._id, + authConfig.authField, + pluginKey, + mockCredential, + ); } }); @@ -218,7 +239,6 @@ describe('Tool Handlers', () => { try { await loadTool2(); } catch (error) { - // eslint-disable-next-line jest/no-conditional-expect expect(error).toBeDefined(); } }); diff --git a/api/cache/banViolation.js b/api/cache/banViolation.js index cdbff85c54..3a2d9791b4 100644 --- a/api/cache/banViolation.js +++ 
b/api/cache/banViolation.js @@ -1,8 +1,9 @@ +const { logger } = require('@librechat/data-schemas'); +const { isEnabled, math } = require('@librechat/api'); const { ViolationTypes } = require('librechat-data-provider'); -const { isEnabled, math, removePorts } = require('~/server/utils'); const { deleteAllUserSessions } = require('~/models'); +const { removePorts } = require('~/server/utils'); const getLogStores = require('./getLogStores'); -const { logger } = require('~/config'); const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {}; const interval = math(BAN_INTERVAL, 20); @@ -32,7 +33,6 @@ const banViolation = async (req, res, errorMessage) => { if (!isEnabled(BAN_VIOLATIONS)) { return; } - if (!errorMessage) { return; } @@ -51,7 +51,6 @@ const banViolation = async (req, res, errorMessage) => { const banLogs = getLogStores(ViolationTypes.BAN); const duration = errorMessage.duration || banLogs.opts.ttl; - if (duration <= 0) { return; } diff --git a/api/cache/banViolation.spec.js b/api/cache/banViolation.spec.js index 8fef16920f..df98753498 100644 --- a/api/cache/banViolation.spec.js +++ b/api/cache/banViolation.spec.js @@ -1,48 +1,28 @@ +const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const banViolation = require('./banViolation'); -jest.mock('keyv'); -jest.mock('../models/Session'); -// Mocking the getLogStores function -jest.mock('./getLogStores', () => { - return jest.fn().mockImplementation(() => { - const EventEmitter = require('events'); - const { CacheKeys } = require('librechat-data-provider'); - const math = require('../server/utils/math'); - const mockGet = jest.fn(); - const mockSet = jest.fn(); - class KeyvMongo extends EventEmitter { - constructor(url = 'mongodb://127.0.0.1:27017', options) { - super(); - this.ttlSupport = false; - url = url ?? 
{}; - if (typeof url === 'string') { - url = { url }; - } - if (url.uri) { - url = { url: url.uri, ...url }; - } - this.opts = { - url, - collection: 'keyv', - ...url, - ...options, - }; - } - - get = mockGet; - set = mockSet; - } - - return new KeyvMongo('', { - namespace: CacheKeys.BANS, - ttl: math(process.env.BAN_DURATION, 7200000), - }); - }); -}); +// Mock deleteAllUserSessions since we're testing ban logic, not session deletion +jest.mock('~/models', () => ({ + ...jest.requireActual('~/models'), + deleteAllUserSessions: jest.fn().mockResolvedValue(true), +})); describe('banViolation', () => { + let mongoServer; let req, res, errorMessage; + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + beforeEach(() => { req = { ip: '127.0.0.1', @@ -55,7 +35,7 @@ describe('banViolation', () => { }; errorMessage = { type: 'someViolation', - user_id: '12345', + user_id: new mongoose.Types.ObjectId().toString(), // Use valid ObjectId prev_count: 0, violation_count: 0, }; diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index b0a6a822ac..0eef7d3fb4 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -1,7 +1,7 @@ const { Keyv } = require('keyv'); +const { isEnabled, math } = require('@librechat/api'); const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider'); const { logFile, violationFile } = require('./keyvFiles'); -const { math, isEnabled } = require('~/server/utils'); const keyvRedis = require('./keyvRedis'); const keyvMongo = require('./keyvMongo'); @@ -29,6 +29,10 @@ const roles = isRedisEnabled ? new Keyv({ store: keyvRedis }) : new Keyv({ namespace: CacheKeys.ROLES }); +const mcpTools = isRedisEnabled + ? 
new Keyv({ store: keyvRedis }) + : new Keyv({ namespace: CacheKeys.MCP_TOOLS }); + const audioRuns = isRedisEnabled ? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES }) : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES }); @@ -67,6 +71,7 @@ const openIdExchangedTokensCache = isRedisEnabled const namespaces = { [CacheKeys.ROLES]: roles, + [CacheKeys.MCP_TOOLS]: mcpTools, [CacheKeys.CONFIG_STORE]: config, [CacheKeys.PENDING_REQ]: pending_req, [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }), diff --git a/api/config/index.js b/api/config/index.js index e238f700be..2e69e87118 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,7 +1,6 @@ -const axios = require('axios'); const { EventSource } = require('eventsource'); -const { Time, CacheKeys } = require('librechat-data-provider'); -const { MCPManager, FlowStateManager } = require('librechat-mcp'); +const { Time } = require('librechat-data-provider'); +const { MCPManager, FlowStateManager } = require('@librechat/api'); const logger = require('./winston'); global.EventSource = EventSource; @@ -16,7 +15,7 @@ let flowManager = null; */ function getMCPManager(userId) { if (!mcpManager) { - mcpManager = MCPManager.getInstance(logger); + mcpManager = MCPManager.getInstance(); } else { mcpManager.checkIdleConnections(userId); } @@ -31,66 +30,13 @@ function getFlowStateManager(flowsCache) { if (!flowManager) { flowManager = new FlowStateManager(flowsCache, { ttl: Time.ONE_MINUTE * 3, - logger, }); } return flowManager; } -/** - * Sends message data in Server Sent Events format. - * @param {ServerResponse} res - The server response. - * @param {{ data: string | Record, event?: string }} event - The message event. - * @param {string} event.event - The type of event. - * @param {string} event.data - The message to be sent. 
- */ -const sendEvent = (res, event) => { - if (typeof event.data === 'string' && event.data.length === 0) { - return; - } - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); -}; - -/** - * Creates and configures an Axios instance with optional proxy settings. - * - * @typedef {import('axios').AxiosInstance} AxiosInstance - * @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig - * - * @returns {AxiosInstance} A configured Axios instance - * @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL - */ -function createAxiosInstance() { - const instance = axios.create(); - - if (process.env.proxy) { - try { - const url = new URL(process.env.proxy); - - /** @type {AxiosProxyConfig} */ - const proxyConfig = { - host: url.hostname.replace(/^\[|\]$/g, ''), - protocol: url.protocol.replace(':', ''), - }; - - if (url.port) { - proxyConfig.port = parseInt(url.port, 10); - } - - instance.defaults.proxy = proxyConfig; - } catch (error) { - console.error('Error parsing proxy URL:', error); - throw new Error(`Invalid proxy URL: ${process.env.proxy}`); - } - } - - return instance; -} - module.exports = { logger, - sendEvent, getMCPManager, - createAxiosInstance, getFlowStateManager, }; diff --git a/api/lib/db/connectDb.js b/api/db/connect.js similarity index 96% rename from api/lib/db/connectDb.js rename to api/db/connect.js index b8cbeb2adb..e88ffa51ed 100644 --- a/api/lib/db/connectDb.js +++ b/api/db/connect.js @@ -39,7 +39,10 @@ async function connectDb() { }); } cached.conn = await cached.promise; + return cached.conn; } -module.exports = connectDb; +module.exports = { + connectDb, +}; diff --git a/api/db/index.js b/api/db/index.js new file mode 100644 index 0000000000..5c29902f69 --- /dev/null +++ b/api/db/index.js @@ -0,0 +1,8 @@ +const mongoose = require('mongoose'); +const { createModels } = require('@librechat/data-schemas'); +const { connectDb } = require('./connect'); +const indexSync = require('./indexSync'); + 
+createModels(mongoose); + +module.exports = { connectDb, indexSync }; diff --git a/api/db/indexSync.js b/api/db/indexSync.js new file mode 100644 index 0000000000..945346a906 --- /dev/null +++ b/api/db/indexSync.js @@ -0,0 +1,174 @@ +const mongoose = require('mongoose'); +const { MeiliSearch } = require('meilisearch'); +const { logger } = require('@librechat/data-schemas'); +const { FlowStateManager } = require('@librechat/api'); +const { CacheKeys } = require('librechat-data-provider'); + +const { isEnabled } = require('~/server/utils'); +const { getLogStores } = require('~/cache'); + +const Conversation = mongoose.models.Conversation; +const Message = mongoose.models.Message; + +const searchEnabled = isEnabled(process.env.SEARCH); +const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); +let currentTimeout = null; + +class MeiliSearchClient { + static instance = null; + + static getInstance() { + if (!MeiliSearchClient.instance) { + if (!process.env.MEILI_HOST || !process.env.MEILI_MASTER_KEY) { + throw new Error('Meilisearch configuration is missing.'); + } + MeiliSearchClient.instance = new MeiliSearch({ + host: process.env.MEILI_HOST, + apiKey: process.env.MEILI_MASTER_KEY, + }); + } + return MeiliSearchClient.instance; + } +} + +/** + * Performs the actual sync operations for messages and conversations + */ +async function performSync() { + const client = MeiliSearchClient.getInstance(); + + const { status } = await client.health(); + if (status !== 'available') { + throw new Error('Meilisearch not available'); + } + + if (indexingDisabled === true) { + logger.info('[indexSync] Indexing is disabled, skipping...'); + return { messagesSync: false, convosSync: false }; + } + + let messagesSync = false; + let convosSync = false; + + // Check if we need to sync messages + const messageProgress = await Message.getSyncProgress(); + if (!messageProgress.isComplete) { + logger.info( + `[indexSync] Messages need syncing: 
${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`, + ); + + // Check if we should do a full sync or incremental + const messageCount = await Message.countDocuments(); + const messagesIndexed = messageProgress.totalProcessed; + const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10); + + if (messageCount - messagesIndexed > syncThreshold) { + logger.info('[indexSync] Starting full message sync due to large difference'); + await Message.syncWithMeili(); + messagesSync = true; + } else if (messageCount !== messagesIndexed) { + logger.warn('[indexSync] Messages out of sync, performing incremental sync'); + await Message.syncWithMeili(); + messagesSync = true; + } + } else { + logger.info( + `[indexSync] Messages are fully synced: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments}`, + ); + } + + // Check if we need to sync conversations + const convoProgress = await Conversation.getSyncProgress(); + if (!convoProgress.isComplete) { + logger.info( + `[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`, + ); + + const convoCount = await Conversation.countDocuments(); + const convosIndexed = convoProgress.totalProcessed; + const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10); + + if (convoCount - convosIndexed > syncThreshold) { + logger.info('[indexSync] Starting full conversation sync due to large difference'); + await Conversation.syncWithMeili(); + convosSync = true; + } else if (convoCount !== convosIndexed) { + logger.warn('[indexSync] Convos out of sync, performing incremental sync'); + await Conversation.syncWithMeili(); + convosSync = true; + } + } else { + logger.info( + `[indexSync] Conversations are fully synced: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments}`, + ); + } + + return { messagesSync, convosSync }; +} + +/** + * Main index sync function that uses FlowStateManager to prevent 
concurrent execution + */ +async function indexSync() { + if (!searchEnabled) { + return; + } + + logger.info('[indexSync] Starting index synchronization check...'); + + try { + // Get or create FlowStateManager instance + const flowsCache = getLogStores(CacheKeys.FLOWS); + if (!flowsCache) { + logger.warn('[indexSync] Flows cache not available, falling back to direct sync'); + return await performSync(); + } + + const flowManager = new FlowStateManager(flowsCache, { + ttl: 60000 * 10, // 10 minutes TTL for sync operations + }); + + // Use a unique flow ID for the sync operation + const flowId = 'meili-index-sync'; + const flowType = 'MEILI_SYNC'; + + // This will only execute the handler if no other instance is running the sync + const result = await flowManager.createFlowWithHandler(flowId, flowType, performSync); + + if (result.messagesSync || result.convosSync) { + logger.info('[indexSync] Sync completed successfully'); + } else { + logger.debug('[indexSync] No sync was needed'); + } + + return result; + } catch (err) { + if (err.message.includes('flow already exists')) { + logger.info('[indexSync] Sync already running on another instance'); + return; + } + + if (err.message.includes('not found')) { + logger.debug('[indexSync] Creating indices...'); + currentTimeout = setTimeout(async () => { + try { + await Message.syncWithMeili(); + await Conversation.syncWithMeili(); + } catch (err) { + logger.error('[indexSync] Trouble creating indices, try restarting the server.', err); + } + }, 750); + } else if (err.message.includes('Meilisearch not configured')) { + logger.info('[indexSync] Meilisearch not configured, search will be disabled.'); + } else { + logger.error('[indexSync] error', err); + } + } +} + +process.on('exit', () => { + logger.debug('[indexSync] Clearing sync timeouts before exiting...'); + clearTimeout(currentTimeout); +}); + +module.exports = indexSync; diff --git a/api/db/models.js b/api/db/models.js new file mode 100644 index 
0000000000..fca1327446 --- /dev/null +++ b/api/db/models.js @@ -0,0 +1,5 @@ +const mongoose = require('mongoose'); +const { createModels } = require('@librechat/data-schemas'); +const models = createModels(mongoose); + +module.exports = { ...models }; diff --git a/api/lib/db/index.js b/api/lib/db/index.js deleted file mode 100644 index fa7a460d05..0000000000 --- a/api/lib/db/index.js +++ /dev/null @@ -1,4 +0,0 @@ -const connectDb = require('./connectDb'); -const indexSync = require('./indexSync'); - -module.exports = { connectDb, indexSync }; diff --git a/api/lib/db/indexSync.js b/api/lib/db/indexSync.js deleted file mode 100644 index 75acd9d231..0000000000 --- a/api/lib/db/indexSync.js +++ /dev/null @@ -1,89 +0,0 @@ -const { MeiliSearch } = require('meilisearch'); -const { Conversation } = require('~/models/Conversation'); -const { Message } = require('~/models/Message'); -const { isEnabled } = require('~/server/utils'); -const { logger } = require('~/config'); - -const searchEnabled = isEnabled(process.env.SEARCH); -const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); -let currentTimeout = null; - -class MeiliSearchClient { - static instance = null; - - static getInstance() { - if (!MeiliSearchClient.instance) { - if (!process.env.MEILI_HOST || !process.env.MEILI_MASTER_KEY) { - throw new Error('Meilisearch configuration is missing.'); - } - MeiliSearchClient.instance = new MeiliSearch({ - host: process.env.MEILI_HOST, - apiKey: process.env.MEILI_MASTER_KEY, - }); - } - return MeiliSearchClient.instance; - } -} - -async function indexSync() { - if (!searchEnabled) { - return; - } - - try { - const client = MeiliSearchClient.getInstance(); - - const { status } = await client.health(); - if (status !== 'available') { - throw new Error('Meilisearch not available'); - } - - if (indexingDisabled === true) { - logger.info('[indexSync] Indexing is disabled, skipping...'); - return; - } - - const messageCount = await Message.countDocuments(); - const convoCount 
= await Conversation.countDocuments(); - const messages = await client.index('messages').getStats(); - const convos = await client.index('convos').getStats(); - const messagesIndexed = messages.numberOfDocuments; - const convosIndexed = convos.numberOfDocuments; - - logger.debug(`[indexSync] There are ${messageCount} messages and ${messagesIndexed} indexed`); - logger.debug(`[indexSync] There are ${convoCount} convos and ${convosIndexed} indexed`); - - if (messageCount !== messagesIndexed) { - logger.debug('[indexSync] Messages out of sync, indexing'); - Message.syncWithMeili(); - } - - if (convoCount !== convosIndexed) { - logger.debug('[indexSync] Convos out of sync, indexing'); - Conversation.syncWithMeili(); - } - } catch (err) { - if (err.message.includes('not found')) { - logger.debug('[indexSync] Creating indices...'); - currentTimeout = setTimeout(async () => { - try { - await Message.syncWithMeili(); - await Conversation.syncWithMeili(); - } catch (err) { - logger.error('[indexSync] Trouble creating indices, try restarting the server.', err); - } - }, 750); - } else if (err.message.includes('Meilisearch not configured')) { - logger.info('[indexSync] Meilisearch not configured, search will be disabled.'); - } else { - logger.error('[indexSync] error', err); - } - } -} - -process.on('exit', () => { - logger.debug('[indexSync] Clearing sync timeouts before exiting...'); - clearTimeout(currentTimeout); -}); - -module.exports = indexSync; diff --git a/api/models/Action.js b/api/models/Action.js index 677b4d78df..20aa20a7e4 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -1,7 +1,4 @@ -const mongoose = require('mongoose'); -const { actionSchema } = require('@librechat/data-schemas'); - -const Action = mongoose.model('action', actionSchema); +const { Action } = require('~/db/models'); /** * Update an action with new data without overwriting existing properties, diff --git a/api/models/Agent.js b/api/models/Agent.js index a2b325b5bf..04ba8b020e 
100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -1,6 +1,6 @@ const mongoose = require('mongoose'); const crypto = require('node:crypto'); -const { agentSchema } = require('@librechat/data-schemas'); +const { logger } = require('@librechat/data-schemas'); const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider'); const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } = require('librechat-data-provider').Constants; @@ -11,11 +11,10 @@ const { removeAgentIdsFromProject, removeAgentFromAllProjects, } = require('./Project'); +const { getCachedTools } = require('~/server/services/Config'); const getLogStores = require('~/cache/getLogStores'); const { getActions } = require('./Action'); -const { logger } = require('~/config'); - -const Agent = mongoose.model('agent', agentSchema); +const { Agent } = require('~/db/models'); /** * Create an agent with the provided data. @@ -57,12 +56,12 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter) * @param {string} params.agent_id * @param {string} params.endpoint * @param {import('@librechat/agents').ClientOptions} [params.model_parameters] - * @returns {Agent|null} The agent document as a plain object, or null if not found. + * @returns {Promise} The agent document as a plain object, or null if not found. 
*/ -const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) => { +const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => { const { model, ...model_parameters } = _m; /** @type {Record} */ - const availableTools = req.app.locals.availableTools; + const availableTools = await getCachedTools({ includeGlobal: true }); /** @type {TEphemeralAgent | null} */ const ephemeralAgent = req.body.ephemeralAgent; const mcpServers = new Set(ephemeralAgent?.mcp); @@ -71,6 +70,9 @@ const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) = if (ephemeralAgent?.execute_code === true) { tools.push(Tools.execute_code); } + if (ephemeralAgent?.file_search === true) { + tools.push(Tools.file_search); + } if (ephemeralAgent?.web_search === true) { tools.push(Tools.web_search); } @@ -113,7 +115,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => { return null; } if (agent_id === EPHEMERAL_AGENT_ID) { - return loadEphemeralAgent({ req, agent_id, endpoint, model_parameters }); + return await loadEphemeralAgent({ req, agent_id, endpoint, model_parameters }); } const agent = await getAgent({ id: agent_id, @@ -172,7 +174,6 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul 'created_at', 'updated_at', '__v', - 'agent_ids', 'versions', 'actionsHash', // Exclude actionsHash from direct comparison ]; @@ -262,11 +263,12 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul * @param {Object} [options] - Optional configuration object. * @param {string} [options.updatingUserId] - The ID of the user performing the update (used for tracking non-author updates). * @param {boolean} [options.forceVersion] - Force creation of a new version even if no fields changed. + * @param {boolean} [options.skipVersioning] - Skip version creation entirely (useful for isolated operations like sharing). 
* @returns {Promise} The updated or newly created agent document as a plain object. * @throws {Error} If the update would create a duplicate version */ const updateAgent = async (searchParameter, updateData, options = {}) => { - const { updatingUserId = null, forceVersion = false } = options; + const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options; const mongoOptions = { new: true, upsert: false }; const currentAgent = await Agent.findOne(searchParameter); @@ -303,10 +305,8 @@ const updateAgent = async (searchParameter, updateData, options = {}) => { } const shouldCreateVersion = - forceVersion || - (versions && - versions.length > 0 && - (Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet)); + !skipVersioning && + (forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet); if (shouldCreateVersion) { const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash); @@ -341,7 +341,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => { versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId); } - if (shouldCreateVersion || forceVersion) { + if (shouldCreateVersion) { updateData.$push = { ...($push || {}), versions: versionEntry, @@ -481,7 +481,6 @@ const getListAgents = async (searchParameter) => { delete globalQuery.author; query = { $or: [globalQuery, query] }; } - const agents = ( await Agent.find(query, { id: 1, @@ -553,7 +552,10 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds delete updateQuery.author; } - const updatedAgent = await updateAgent(updateQuery, updateOps, { updatingUserId: user.id }); + const updatedAgent = await updateAgent(updateQuery, updateOps, { + updatingUserId: user.id, + skipVersioning: true, + }); if (updatedAgent) { return updatedAgent; } @@ -662,7 +664,6 @@ const generateActionMetadataHash = async (actionIds, actions) => { */ module.exports = { - Agent, 
getAgent, loadAgent, createAgent, diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js index 41b32ffa92..8953ae0482 100644 --- a/api/models/Agent.spec.js +++ b/api/models/Agent.spec.js @@ -6,1071 +6,2655 @@ const originalEnv = { process.env.CREDS_KEY = '0123456789abcdef0123456789abcdef'; process.env.CREDS_IV = '0123456789abcdef'; +jest.mock('~/server/services/Config', () => ({ + getCachedTools: jest.fn(), +})); + const mongoose = require('mongoose'); const { v4: uuidv4 } = require('uuid'); +const { agentSchema } = require('@librechat/data-schemas'); const { MongoMemoryServer } = require('mongodb-memory-server'); const { - Agent, - addAgentResourceFile, - removeAgentResourceFiles, + getAgent, + loadAgent, createAgent, updateAgent, - getAgent, deleteAgent, getListAgents, updateAgentProjects, + addAgentResourceFile, + removeAgentResourceFiles, + generateActionMetadataHash, + revertAgentVersion, } = require('./Agent'); +const { getCachedTools } = require('~/server/services/Config'); -describe('Agent Resource File Operations', () => { - let mongoServer; +/** + * @type {import('mongoose').Model} + */ +let Agent; - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); +describe('models/Agent', () => { + describe('Agent Resource File Operations', () => { + let mongoServer; - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - process.env.CREDS_KEY = originalEnv.CREDS_KEY; - process.env.CREDS_IV = originalEnv.CREDS_IV; - }); + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }, 20000); - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - const createBasicAgent = async () => { - const agentId = `agent_${uuidv4()}`; - const agent = await 
Agent.create({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), - }); - return agent; - }; - - test('should add tool_resource to tools if missing', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - const toolResource = 'file_search'; - - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId, + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + process.env.CREDS_KEY = originalEnv.CREDS_KEY; + process.env.CREDS_IV = originalEnv.CREDS_IV; }); - expect(updatedAgent.tools).toContain(toolResource); - expect(Array.isArray(updatedAgent.tools)).toBe(true); - // Should not duplicate - const count = updatedAgent.tools.filter((t) => t === toolResource).length; - expect(count).toBe(1); - }); - - test('should not duplicate tool_resource in tools if already present', async () => { - const agent = await createBasicAgent(); - const fileId1 = uuidv4(); - const fileId2 = uuidv4(); - const toolResource = 'file_search'; - - // First add - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId1, + beforeEach(async () => { + await Agent.deleteMany({}); }); - // Second add (should not duplicate) - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId2, + test('should add tool_resource to tools if missing', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + const toolResource = 'file_search'; + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId, + }); + + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + // Should not duplicate + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + 
expect(count).toBe(1); }); - expect(updatedAgent.tools).toContain(toolResource); - expect(Array.isArray(updatedAgent.tools)).toBe(true); - const count = updatedAgent.tools.filter((t) => t === toolResource).length; - expect(count).toBe(1); - }); + test('should not duplicate tool_resource in tools if already present', async () => { + const agent = await createBasicAgent(); + const fileId1 = uuidv4(); + const fileId2 = uuidv4(); + const toolResource = 'file_search'; - test('should handle concurrent file additions', async () => { - const agent = await createBasicAgent(); - const fileIds = Array.from({ length: 10 }, () => uuidv4()); + // First add + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId1, + }); - // Concurrent additions - const additionPromises = fileIds.map((fileId) => - addAgentResourceFile({ + // Second add (should not duplicate) + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId2, + }); + + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + expect(count).toBe(1); + }); + + test('should handle concurrent file additions', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); + + // Concurrent additions + const additionPromises = createFileOperations(agent.id, fileIds, 'add'); + + await Promise.all(additionPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); + expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); + }); + + test('should handle concurrent additions and removals', async () => { + const agent = await createBasicAgent(); + const 
initialFileIds = Array.from({ length: 5 }, () => uuidv4()); + + await Promise.all(createFileOperations(agent.id, initialFileIds, 'add')); + + const newFileIds = Array.from({ length: 5 }, () => uuidv4()); + const operations = [ + ...newFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ...initialFileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ), + ]; + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); + }); + + test('should initialize array when adding to non-existent tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'new_tool', + file_id: fileId, + }); + + expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); + expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); + }); + + test('should handle rapid sequential modifications to same tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + for (let i = 0; i < 10; i++) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: `${fileId}_${i}`, + }); + + if (i % 2 === 0) { + await removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], + }); + } + } + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); + 
}); + + test('should handle multiple tool resources concurrently', async () => { + const agent = await createBasicAgent(); + const toolResources = ['tool1', 'tool2', 'tool3']; + const operations = []; + + toolResources.forEach((tool) => { + const fileIds = Array.from({ length: 5 }, () => uuidv4()); + fileIds.forEach((fileId) => { + operations.push( + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: tool, + file_id: fileId, + }), + ); + }); + }); + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + toolResources.forEach((tool) => { + expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); + expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); + }); + }); + + test.each([ + { + name: 'duplicate additions', + operation: 'add', + duplicateCount: 5, + expectedLength: 1, + expectedContains: true, + }, + { + name: 'duplicate removals', + operation: 'remove', + duplicateCount: 5, + expectedLength: 0, + expectedContains: false, + setupFile: true, + }, + ])( + 'should handle concurrent $name', + async ({ operation, duplicateCount, expectedLength, expectedContains, setupFile }) => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + if (setupFile) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }); + } + + const promises = Array.from({ length: duplicateCount }).map(() => + operation === 'add' + ? addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }) + : removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(promises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? 
[]; + + expect(fileIds).toHaveLength(expectedLength); + if (expectedContains) { + expect(fileIds[0]).toBe(fileId); + } else { + expect(fileIds).not.toContain(fileId); + } + }, + ); + + test('should handle concurrent add and remove of the same file', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + await addAgentResourceFile({ agent_id: agent.id, tool_resource: 'test_tool', file_id: fileId, - }), - ); + }); - await Promise.all(additionPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); - expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); - }); - - test('should handle concurrent additions and removals', async () => { - const agent = await createBasicAgent(); - const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); - - await Promise.all( - initialFileIds.map((fileId) => + const operations = [ addAgentResourceFile({ agent_id: agent.id, tool_resource: 'test_tool', file_id: fileId, }), - ), - ); - - const newFileIds = Array.from({ length: 5 }, () => uuidv4()); - const operations = [ - ...newFileIds.map((fileId) => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - ), - ...initialFileIds.map((fileId) => removeAgentResourceFiles({ agent_id: agent.id, files: [{ tool_resource: 'test_tool', file_id: fileId }], }), - ), - ]; + ]; - await Promise.all(operations); + await Promise.all(operations); - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); - }); + const updatedAgent = await Agent.findOne({ id: agent.id }); + const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids; + const count = finalFileIds.filter((id) => id 
=== fileId).length; - test('should initialize array when adding to non-existent tool resource', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'new_tool', - file_id: fileId, + expect(count).toBeLessThanOrEqual(1); + if (count === 0) { + expect(finalFileIds).toHaveLength(0); + } else { + expect(finalFileIds).toHaveLength(1); + expect(finalFileIds[0]).toBe(fileId); + } }); - expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); - expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); - }); + test('should handle concurrent removals of different files', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); - test('should handle rapid sequential modifications to same tool resource', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - for (let i = 0; i < 10; i++) { - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: `${fileId}_${i}`, - }); - - if (i % 2 === 0) { - await removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], - }); - } - } - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); - }); - - test('should handle multiple tool resources concurrently', async () => { - const agent = await createBasicAgent(); - const toolResources = ['tool1', 'tool2', 'tool3']; - const operations = []; - - toolResources.forEach((tool) => { - const fileIds = Array.from({ length: 5 }, () => uuidv4()); - fileIds.forEach((fileId) => { - operations.push( + // Add all files first + await 
Promise.all( + fileIds.map((fileId) => addAgentResourceFile({ agent_id: agent.id, - tool_resource: tool, + tool_resource: 'test_tool', file_id: fileId, }), - ); + ), + ); + + // Concurrently remove all files + const removalPromises = fileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(removalPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + // Check if the array is empty or the tool resource itself is removed + const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; + expect(finalFileIds).toHaveLength(0); + }); + + describe('Edge Cases', () => { + describe.each([ + { + operation: 'add', + name: 'empty file_id', + needsAgent: true, + params: { tool_resource: 'file_search', file_id: '' }, + shouldResolve: true, + }, + { + operation: 'add', + name: 'non-existent agent', + needsAgent: false, + params: { tool_resource: 'file_search', file_id: 'file123' }, + shouldResolve: false, + error: 'Agent not found for adding resource file', + }, + ])('addAgentResourceFile with $name', ({ needsAgent, params, shouldResolve, error }) => { + test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => { + const agent = needsAgent ? await createBasicAgent() : null; + const agent_id = needsAgent ? 
agent.id : `agent_${uuidv4()}`; + + if (shouldResolve) { + await expect(addAgentResourceFile({ agent_id, ...params })).resolves.toBeDefined(); + } else { + await expect(addAgentResourceFile({ agent_id, ...params })).rejects.toThrow(error); + } + }); + }); + + describe.each([ + { + name: 'empty files array', + files: [], + needsAgent: true, + shouldResolve: true, + }, + { + name: 'non-existent tool_resource', + files: [{ tool_resource: 'non_existent_tool', file_id: 'file123' }], + needsAgent: true, + shouldResolve: true, + }, + { + name: 'non-existent agent', + files: [{ tool_resource: 'file_search', file_id: 'file123' }], + needsAgent: false, + shouldResolve: false, + error: 'Agent not found for removing resource files', + }, + ])('removeAgentResourceFiles with $name', ({ files, needsAgent, shouldResolve, error }) => { + test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => { + const agent = needsAgent ? await createBasicAgent() : null; + const agent_id = needsAgent ? agent.id : `agent_${uuidv4()}`; + + if (shouldResolve) { + const result = await removeAgentResourceFiles({ agent_id, files }); + expect(result).toBeDefined(); + if (agent) { + expect(result.id).toBe(agent.id); + } + } else { + await expect(removeAgentResourceFiles({ agent_id, files })).rejects.toThrow(error); + } + }); }); }); - - await Promise.all(operations); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - toolResources.forEach((tool) => { - expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); - expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); - }); }); - test('should handle concurrent duplicate additions', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); + describe('Agent CRUD Operations', () => { + let mongoServer; - // Concurrent additions of the same file - const additionPromises = Array.from({ length: 5 }).map(() => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - 
file_id: fileId, - }), - ); + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }, 20000); - await Promise.all(additionPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - // Should only contain one instance of the fileId - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(1); - expect(updatedAgent.tool_resources.test_tool.file_ids[0]).toBe(fileId); - }); - - test('should handle concurrent add and remove of the same file', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - // First, ensure the file exists (or test might be trivial if remove runs first) - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); }); - // Concurrent add (which should be ignored) and remove - const operations = [ - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ]; - - await Promise.all(operations); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // The final state should ideally be that the file is removed, - // but the key point is consistency (not duplicated or error state). - // Depending on execution order, the file might remain if the add operation's - // findOneAndUpdate runs after the remove operation completes. - // A more robust check might be that the length is <= 1. - // Given the remove uses an update pipeline, it might be more likely to win. - // The final state depends on race condition timing (add or remove might "win"). 
- // The critical part is that the state is consistent (no duplicates, no errors). - // Assert that the fileId is either present exactly once or not present at all. - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids; - const count = finalFileIds.filter((id) => id === fileId).length; - expect(count).toBeLessThanOrEqual(1); // Should be 0 or 1, never more - // Optional: Check overall length is consistent with the count - if (count === 0) { - expect(finalFileIds).toHaveLength(0); - } else { - expect(finalFileIds).toHaveLength(1); - expect(finalFileIds[0]).toBe(fileId); - } - }); - - test('should handle concurrent duplicate removals', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - // Add the file first - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, + beforeEach(async () => { + await Agent.deleteMany({}); }); - // Concurrent removals of the same file - const removalPromises = Array.from({ length: 5 }).map(() => - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ); + test('should create and get an agent', async () => { + const { agentId, authorId } = createTestIds(); - await Promise.all(removalPromises); + const newAgent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Test description', + }); - const updatedAgent = await Agent.findOne({ id: agent.id }); - // Check if the array is empty or the tool resource itself is removed - const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? 
[]; - expect(fileIds).toHaveLength(0); - expect(fileIds).not.toContain(fileId); - }); + expect(newAgent).toBeDefined(); + expect(newAgent.id).toBe(agentId); + expect(newAgent.name).toBe('Test Agent'); - test('should handle concurrent removals of different files', async () => { - const agent = await createBasicAgent(); - const fileIds = Array.from({ length: 10 }, () => uuidv4()); - - // Add all files first - await Promise.all( - fileIds.map((fileId) => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - ), - ); - - // Concurrently remove all files - const removalPromises = fileIds.map((fileId) => - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ); - - await Promise.all(removalPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // Check if the array is empty or the tool resource itself is removed - const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? 
[]; - expect(finalFileIds).toHaveLength(0); - }); -}); - -describe('Agent CRUD Operations', () => { - let mongoServer; - - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); - - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - test('should create and get an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - - const newAgent = await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: authorId, - description: 'Test description', + const retrievedAgent = await getAgent({ id: agentId }); + expect(retrievedAgent).toBeDefined(); + expect(retrievedAgent.id).toBe(agentId); + expect(retrievedAgent.name).toBe('Test Agent'); + expect(retrievedAgent.description).toBe('Test description'); }); - expect(newAgent).toBeDefined(); - expect(newAgent.id).toBe(agentId); - expect(newAgent.name).toBe('Test Agent'); + test('should delete an agent', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - const retrievedAgent = await getAgent({ id: agentId }); - expect(retrievedAgent).toBeDefined(); - expect(retrievedAgent.id).toBe(agentId); - expect(retrievedAgent.name).toBe('Test Agent'); - expect(retrievedAgent.description).toBe('Test description'); - }); - - test('should delete an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Agent To Delete', - provider: 'test', - model: 'test-model', - author: authorId, - }); - - const agentBeforeDelete = await getAgent({ id: agentId }); - expect(agentBeforeDelete).toBeDefined(); - - await deleteAgent({ id: agentId }); - - const agentAfterDelete = await getAgent({ id: 
agentId }); - expect(agentAfterDelete).toBeNull(); - }); - - test('should list agents by author', async () => { - const authorId = new mongoose.Types.ObjectId(); - const otherAuthorId = new mongoose.Types.ObjectId(); - - const agentIds = []; - for (let i = 0; i < 5; i++) { - const id = `agent_${uuidv4()}`; - agentIds.push(id); await createAgent({ - id, - name: `Agent ${i}`, + id: agentId, + name: 'Agent To Delete', provider: 'test', model: 'test-model', author: authorId, }); - } - for (let i = 0; i < 3; i++) { + const agentBeforeDelete = await getAgent({ id: agentId }); + expect(agentBeforeDelete).toBeDefined(); + + await deleteAgent({ id: agentId }); + + const agentAfterDelete = await getAgent({ id: agentId }); + expect(agentAfterDelete).toBeNull(); + }); + + test('should list agents by author', async () => { + const authorId = new mongoose.Types.ObjectId(); + const otherAuthorId = new mongoose.Types.ObjectId(); + + const agentIds = []; + for (let i = 0; i < 5; i++) { + const id = `agent_${uuidv4()}`; + agentIds.push(id); + await createAgent({ + id, + name: `Agent ${i}`, + provider: 'test', + model: 'test-model', + author: authorId, + }); + } + + for (let i = 0; i < 3; i++) { + await createAgent({ + id: `other_agent_${uuidv4()}`, + name: `Other Agent ${i}`, + provider: 'test', + model: 'test-model', + author: otherAuthorId, + }); + } + + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toBeDefined(); + expect(result.data).toHaveLength(5); + expect(result.has_more).toBe(true); + + for (const agent of result.data) { + expect(agent.author).toBe(authorId.toString()); + } + }); + + test('should update agent projects', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + const projectId3 = new mongoose.Types.ObjectId(); + await createAgent({ 
- id: `other_agent_${uuidv4()}`, - name: `Other Agent ${i}`, + id: agentId, + name: 'Project Test Agent', provider: 'test', model: 'test-model', - author: otherAuthorId, + author: authorId, + projectIds: [projectId1], }); - } - const result = await getListAgents({ author: authorId.toString() }); + await updateAgent( + { id: agentId }, + { $addToSet: { projectIds: { $each: [projectId2, projectId3] } } }, + ); - expect(result).toBeDefined(); - expect(result.data).toBeDefined(); - expect(result.data).toHaveLength(5); - expect(result.has_more).toBe(true); + await updateAgent({ id: agentId }, { $pull: { projectIds: projectId1 } }); - for (const agent of result.data) { - expect(agent.author).toBe(authorId.toString()); - } - }); + await updateAgent({ id: agentId }, { projectIds: [projectId2, projectId3] }); - test('should update agent projects', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const projectId1 = new mongoose.Types.ObjectId(); - const projectId2 = new mongoose.Types.ObjectId(); - const projectId3 = new mongoose.Types.ObjectId(); + const updatedAgent = await getAgent({ id: agentId }); + expect(updatedAgent.projectIds).toHaveLength(2); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId3.toString()); + expect(updatedAgent.projectIds.map((id) => id.toString())).not.toContain( + projectId1.toString(), + ); - await createAgent({ - id: agentId, - name: 'Project Test Agent', - provider: 'test', - model: 'test-model', - author: authorId, - projectIds: [projectId1], + await updateAgent({ id: agentId }, { projectIds: [] }); + + const emptyProjectsAgent = await getAgent({ id: agentId }); + expect(emptyProjectsAgent.projectIds).toHaveLength(0); + + const nonExistentId = `agent_${uuidv4()}`; + await expect( + updateAgentProjects({ + id: nonExistentId, + projectIds: [projectId1], + }), + 
).rejects.toThrow(); }); - await updateAgent( - { id: agentId }, - { $addToSet: { projectIds: { $each: [projectId2, projectId3] } } }, - ); + test('should handle ephemeral agent loading', async () => { + const agentId = 'ephemeral_test'; + const endpoint = 'openai'; - await updateAgent({ id: agentId }, { $pull: { projectIds: projectId1 } }); + const originalModule = jest.requireActual('librechat-data-provider'); - await updateAgent({ id: agentId }, { projectIds: [projectId2, projectId3] }); + const mockDataProvider = { + ...originalModule, + Constants: { + ...originalModule.Constants, + EPHEMERAL_AGENT_ID: 'ephemeral_test', + }, + }; - const updatedAgent = await getAgent({ id: agentId }); - expect(updatedAgent.projectIds).toHaveLength(2); - expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); - expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId3.toString()); - expect(updatedAgent.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + jest.doMock('librechat-data-provider', () => mockDataProvider); - await updateAgent({ id: agentId }, { projectIds: [] }); + expect(agentId).toBeDefined(); + expect(endpoint).toBeDefined(); - const emptyProjectsAgent = await getAgent({ id: agentId }); - expect(emptyProjectsAgent.projectIds).toHaveLength(0); + jest.dontMock('librechat-data-provider'); + }); - const nonExistentId = `agent_${uuidv4()}`; - await expect( - updateAgentProjects({ - id: nonExistentId, - projectIds: [projectId1], - }), - ).rejects.toThrow(); + test('should handle loadAgent functionality and errors', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Load Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1', 'tool2'], + }); + + const agent = await getAgent({ id: agentId }); + + expect(agent).toBeDefined(); + 
expect(agent.id).toBe(agentId); + expect(agent.name).toBe('Test Load Agent'); + expect(agent.tools).toEqual(expect.arrayContaining(['tool1', 'tool2'])); + + const mockLoadAgent = jest.fn().mockResolvedValue(agent); + const loadedAgent = await mockLoadAgent(); + expect(loadedAgent).toBeDefined(); + expect(loadedAgent.id).toBe(agentId); + + const nonExistentId = `agent_${uuidv4()}`; + const nonExistentAgent = await getAgent({ id: nonExistentId }); + expect(nonExistentAgent).toBeNull(); + + const mockLoadAgentError = jest.fn().mockRejectedValue(new Error('No agent found with ID')); + await expect(mockLoadAgentError()).rejects.toThrow('No agent found with ID'); + }); + + describe('Edge Cases', () => { + test.each([ + { + name: 'getAgent with undefined search parameters', + fn: () => getAgent(undefined), + expected: null, + }, + { + name: 'deleteAgent with non-existent agent', + fn: () => deleteAgent({ id: 'non-existent' }), + expected: null, + }, + ])('$name should return null', async ({ fn, expected }) => { + const result = await fn(); + expect(result).toBe(expected); + }); + + test('should handle getListAgents with invalid author format', async () => { + try { + const result = await getListAgents({ author: 'invalid-object-id' }); + expect(result.data).toEqual([]); + } catch (error) { + expect(error).toBeDefined(); + } + }); + + test('should handle getListAgents with no agents', async () => { + const authorId = new mongoose.Types.ObjectId(); + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toEqual([]); + expect(result.has_more).toBe(false); + expect(result.first_id).toBeNull(); + expect(result.last_id).toBeNull(); + }); + + test('should handle updateAgentProjects with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const userId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + const result = await updateAgentProjects({ + 
user: { id: userId.toString() }, + agentId: nonExistentId, + projectIds: [projectId.toString()], + }); + + expect(result).toBeNull(); + }); + }); }); - test('should handle ephemeral agent loading', async () => { - const agentId = 'ephemeral_test'; - const endpoint = 'openai'; + describe('Agent Version History', () => { + let mongoServer; - const originalModule = jest.requireActual('librechat-data-provider'); + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }, 20000); - const mockDataProvider = { - ...originalModule, - Constants: { - ...originalModule.Constants, - EPHEMERAL_AGENT_ID: 'ephemeral_test', - }, - }; + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); - jest.doMock('librechat-data-provider', () => mockDataProvider); + beforeEach(async () => { + await Agent.deleteMany({}); + }); - const mockReq = { - user: { id: 'user123' }, - body: { - promptPrefix: 'This is a test instruction', - ephemeralAgent: { - execute_code: true, - mcp: ['server1', 'server2'], + test('should create an agent with a single entry in versions array', async () => { + const agent = await createBasicAgent(); + + expect(agent.versions).toBeDefined(); + expect(Array.isArray(agent.versions)).toBe(true); + expect(agent.versions).toHaveLength(1); + expect(agent.versions[0].name).toBe('Test Agent'); + expect(agent.versions[0].provider).toBe('test'); + expect(agent.versions[0].model).toBe('test-model'); + }); + + test('should accumulate version history across multiple updates', async () => { + const agentId = `agent_${uuidv4()}`; + const author = new mongoose.Types.ObjectId(); + await createAgent({ + id: agentId, + name: 'First Name', + provider: 'test', + model: 'test-model', + author, + description: 'First description', + }); + + await updateAgent( + { id: agentId }, + { name: 
'Second Name', description: 'Second description' }, + ); + await updateAgent({ id: agentId }, { name: 'Third Name', model: 'new-model' }); + const finalAgent = await updateAgent({ id: agentId }, { description: 'Final description' }); + + expect(finalAgent.versions).toBeDefined(); + expect(Array.isArray(finalAgent.versions)).toBe(true); + expect(finalAgent.versions).toHaveLength(4); + + expect(finalAgent.versions[0].name).toBe('First Name'); + expect(finalAgent.versions[0].description).toBe('First description'); + expect(finalAgent.versions[0].model).toBe('test-model'); + + expect(finalAgent.versions[1].name).toBe('Second Name'); + expect(finalAgent.versions[1].description).toBe('Second description'); + expect(finalAgent.versions[1].model).toBe('test-model'); + + expect(finalAgent.versions[2].name).toBe('Third Name'); + expect(finalAgent.versions[2].description).toBe('Second description'); + expect(finalAgent.versions[2].model).toBe('new-model'); + + expect(finalAgent.versions[3].name).toBe('Third Name'); + expect(finalAgent.versions[3].description).toBe('Final description'); + expect(finalAgent.versions[3].model).toBe('new-model'); + + expect(finalAgent.name).toBe('Third Name'); + expect(finalAgent.description).toBe('Final description'); + expect(finalAgent.model).toBe('new-model'); + }); + + test('should not include metadata fields in version history', async () => { + const agentId = `agent_${uuidv4()}`; + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + + const updatedAgent = await updateAgent({ id: agentId }, { description: 'New description' }); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[0]._id).toBeUndefined(); + expect(updatedAgent.versions[0].__v).toBeUndefined(); + expect(updatedAgent.versions[0].name).toBe('Test Agent'); + expect(updatedAgent.versions[0].author).toBeUndefined(); + + 
expect(updatedAgent.versions[1]._id).toBeUndefined(); + expect(updatedAgent.versions[1].__v).toBeUndefined(); + }); + + test('should not recursively include previous versions', async () => { + const agentId = `agent_${uuidv4()}`; + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + + await updateAgent({ id: agentId }, { name: 'Updated Name 1' }); + await updateAgent({ id: agentId }, { name: 'Updated Name 2' }); + const finalAgent = await updateAgent({ id: agentId }, { name: 'Updated Name 3' }); + + expect(finalAgent.versions).toHaveLength(4); + + finalAgent.versions.forEach((version) => { + expect(version.versions).toBeUndefined(); + }); + }); + + test('should handle MongoDB operators and field updates correctly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'MongoDB Operator Test', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + }); + + await updateAgent( + { id: agentId }, + { + description: 'Updated description', + $push: { tools: 'tool2' }, + $addToSet: { projectIds: projectId }, }, - }, - app: { - locals: { - availableTools: { - tool__server1: {}, - tool__server2: {}, - another_tool: {}, + ); + + const firstUpdate = await getAgent({ id: agentId }); + expect(firstUpdate.description).toBe('Updated description'); + expect(firstUpdate.tools).toContain('tool1'); + expect(firstUpdate.tools).toContain('tool2'); + expect(firstUpdate.projectIds.map((id) => id.toString())).toContain(projectId.toString()); + expect(firstUpdate.versions).toHaveLength(2); + + await updateAgent( + { id: agentId }, + { + tools: ['tool2', 'tool3'], + }, + ); + + const secondUpdate = await getAgent({ id: agentId }); + expect(secondUpdate.tools).toHaveLength(2); + 
expect(secondUpdate.tools).toContain('tool2'); + expect(secondUpdate.tools).toContain('tool3'); + expect(secondUpdate.tools).not.toContain('tool1'); + expect(secondUpdate.versions).toHaveLength(3); + + await updateAgent( + { id: agentId }, + { + $push: { tools: 'tool3' }, + }, + ); + + const thirdUpdate = await getAgent({ id: agentId }); + const toolCount = thirdUpdate.tools.filter((t) => t === 'tool3').length; + expect(toolCount).toBe(2); + expect(thirdUpdate.versions).toHaveLength(4); + }); + + test('should handle parameter objects correctly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Parameters Test', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: { temperature: 0.7 }, + }); + + const updatedAgent = await updateAgent( + { id: agentId }, + { model_parameters: { temperature: 0.8 } }, + ); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.model_parameters.temperature).toBe(0.8); + + await updateAgent( + { id: agentId }, + { + model_parameters: { + temperature: 0.8, + max_tokens: 1000, }, }, - }, - }; + ); - const params = { - req: mockReq, - agent_id: agentId, - endpoint, - model_parameters: { - model: 'gpt-4', - temperature: 0.7, - }, - }; + const complexAgent = await getAgent({ id: agentId }); + expect(complexAgent.versions).toHaveLength(3); + expect(complexAgent.model_parameters.temperature).toBe(0.8); + expect(complexAgent.model_parameters.max_tokens).toBe(1000); - expect(agentId).toBeDefined(); - expect(endpoint).toBeDefined(); + await updateAgent({ id: agentId }, { model_parameters: {} }); - jest.dontMock('librechat-data-provider'); - }); - - test('should handle loadAgent functionality and errors', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Test Load Agent', - provider: 'test', - model: 
'test-model', - author: authorId, - tools: ['tool1', 'tool2'], + const emptyParamsAgent = await getAgent({ id: agentId }); + expect(emptyParamsAgent.versions).toHaveLength(4); + expect(emptyParamsAgent.model_parameters).toEqual({}); }); - const agent = await getAgent({ id: agentId }); + test('should detect duplicate versions and reject updates', async () => { + const originalConsoleError = console.error; + console.error = jest.fn(); - expect(agent).toBeDefined(); - expect(agent.id).toBe(agentId); - expect(agent.name).toBe('Test Load Agent'); - expect(agent.tools).toEqual(expect.arrayContaining(['tool1', 'tool2'])); + try { + const authorId = new mongoose.Types.ObjectId(); + const testCases = generateVersionTestCases(); - const mockLoadAgent = jest.fn().mockResolvedValue(agent); - const loadedAgent = await mockLoadAgent(); - expect(loadedAgent).toBeDefined(); - expect(loadedAgent.id).toBe(agentId); + for (const testCase of testCases) { + const testAgentId = `agent_${uuidv4()}`; - const nonExistentId = `agent_${uuidv4()}`; - const nonExistentAgent = await getAgent({ id: nonExistentId }); - expect(nonExistentAgent).toBeNull(); + await createAgent({ + id: testAgentId, + provider: 'test', + model: 'test-model', + author: authorId, + ...testCase.initial, + }); - const mockLoadAgentError = jest.fn().mockRejectedValue(new Error('No agent found with ID')); - await expect(mockLoadAgentError()).rejects.toThrow('No agent found with ID'); - }); -}); + await updateAgent({ id: testAgentId }, testCase.update); -describe('Agent Version History', () => { - let mongoServer; + let error; + try { + await updateAgent({ id: testAgentId }, testCase.duplicate); + } catch (e) { + error = e; + } - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); + expect(error).toBeDefined(); + expect(error.message).toContain('Duplicate version'); + expect(error.statusCode).toBe(409); + 
expect(error.details).toBeDefined(); + expect(error.details.duplicateVersion).toBeDefined(); - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - test('should create an agent with a single entry in versions array', async () => { - const agentId = `agent_${uuidv4()}`; - const agent = await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + const agent = await getAgent({ id: testAgentId }); + expect(agent.versions).toHaveLength(2); + } + } finally { + console.error = originalConsoleError; + } }); - expect(agent.versions).toBeDefined(); - expect(Array.isArray(agent.versions)).toBe(true); - expect(agent.versions).toHaveLength(1); - expect(agent.versions[0].name).toBe('Test Agent'); - expect(agent.versions[0].provider).toBe('test'); - expect(agent.versions[0].model).toBe('test-model'); - }); + test('should track updatedBy when a different user updates an agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const updatingUser = new mongoose.Types.ObjectId(); - test('should accumulate version history across multiple updates', async () => { - const agentId = `agent_${uuidv4()}`; - const author = new mongoose.Types.ObjectId(); - await createAgent({ - id: agentId, - name: 'First Name', - provider: 'test', - model: 'test-model', - author, - description: 'First description', + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); + + const updatedAgent = await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: updatingUser.toString() }, + ); + + expect(updatedAgent.versions).toHaveLength(2); + 
expect(updatedAgent.versions[1].updatedBy.toString()).toBe(updatingUser.toString()); + expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); }); - await updateAgent({ id: agentId }, { name: 'Second Name', description: 'Second description' }); - await updateAgent({ id: agentId }, { name: 'Third Name', model: 'new-model' }); - const finalAgent = await updateAgent({ id: agentId }, { description: 'Final description' }); + test('should include updatedBy even when the original author updates the agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); - expect(finalAgent.versions).toBeDefined(); - expect(Array.isArray(finalAgent.versions)).toBe(true); - expect(finalAgent.versions).toHaveLength(4); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - expect(finalAgent.versions[0].name).toBe('First Name'); - expect(finalAgent.versions[0].description).toBe('First description'); - expect(finalAgent.versions[0].model).toBe('test-model'); + const updatedAgent = await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: originalAuthor.toString() }, + ); - expect(finalAgent.versions[1].name).toBe('Second Name'); - expect(finalAgent.versions[1].description).toBe('Second description'); - expect(finalAgent.versions[1].model).toBe('test-model'); - - expect(finalAgent.versions[2].name).toBe('Third Name'); - expect(finalAgent.versions[2].description).toBe('Second description'); - expect(finalAgent.versions[2].model).toBe('new-model'); - - expect(finalAgent.versions[3].name).toBe('Third Name'); - expect(finalAgent.versions[3].description).toBe('Final description'); - expect(finalAgent.versions[3].model).toBe('new-model'); - - expect(finalAgent.name).toBe('Third Name'); - expect(finalAgent.description).toBe('Final 
description'); - expect(finalAgent.model).toBe('new-model'); - }); - - test('should not include metadata fields in version history', async () => { - const agentId = `agent_${uuidv4()}`; - await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[1].updatedBy.toString()).toBe(originalAuthor.toString()); + expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); }); - const updatedAgent = await updateAgent({ id: agentId }, { description: 'New description' }); + test('should track multiple different users updating the same agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const user1 = new mongoose.Types.ObjectId(); + const user2 = new mongoose.Types.ObjectId(); + const user3 = new mongoose.Types.ObjectId(); - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[0]._id).toBeUndefined(); - expect(updatedAgent.versions[0].__v).toBeUndefined(); - expect(updatedAgent.versions[0].name).toBe('Test Agent'); - expect(updatedAgent.versions[0].author).toBeUndefined(); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - expect(updatedAgent.versions[1]._id).toBeUndefined(); - expect(updatedAgent.versions[1].__v).toBeUndefined(); - }); + // User 1 makes an update + await updateAgent( + { id: agentId }, + { name: 'Updated by User 1', description: 'First update' }, + { updatingUserId: user1.toString() }, + ); - test('should not recursively include previous versions', async () => { - const agentId = `agent_${uuidv4()}`; - await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + // Original author makes an 
update + await updateAgent( + { id: agentId }, + { description: 'Updated by original author' }, + { updatingUserId: originalAuthor.toString() }, + ); + + // User 2 makes an update + await updateAgent( + { id: agentId }, + { name: 'Updated by User 2', model: 'new-model' }, + { updatingUserId: user2.toString() }, + ); + + // User 3 makes an update + const finalAgent = await updateAgent( + { id: agentId }, + { description: 'Final update by User 3' }, + { updatingUserId: user3.toString() }, + ); + + expect(finalAgent.versions).toHaveLength(5); + expect(finalAgent.author.toString()).toBe(originalAuthor.toString()); + + // Check that each version has the correct updatedBy + expect(finalAgent.versions[0].updatedBy).toBeUndefined(); // Initial creation has no updatedBy + expect(finalAgent.versions[1].updatedBy.toString()).toBe(user1.toString()); + expect(finalAgent.versions[2].updatedBy.toString()).toBe(originalAuthor.toString()); + expect(finalAgent.versions[3].updatedBy.toString()).toBe(user2.toString()); + expect(finalAgent.versions[4].updatedBy.toString()).toBe(user3.toString()); + + // Verify the final state + expect(finalAgent.name).toBe('Updated by User 2'); + expect(finalAgent.description).toBe('Final update by User 3'); + expect(finalAgent.model).toBe('new-model'); }); - await updateAgent({ id: agentId }, { name: 'Updated Name 1' }); - await updateAgent({ id: agentId }, { name: 'Updated Name 2' }); - const finalAgent = await updateAgent({ id: agentId }, { name: 'Updated Name 3' }); + test('should preserve original author during agent restoration', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const updatingUser = new mongoose.Types.ObjectId(); - expect(finalAgent.versions).toHaveLength(4); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - 
finalAgent.versions.forEach((version) => { - expect(version.versions).toBeUndefined(); - }); - }); + await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: updatingUser.toString() }, + ); - test('should handle MongoDB operators and field updates correctly', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const projectId = new mongoose.Types.ObjectId(); + const { revertAgentVersion } = require('./Agent'); + const revertedAgent = await revertAgentVersion({ id: agentId }, 0); - await createAgent({ - id: agentId, - name: 'MongoDB Operator Test', - provider: 'test', - model: 'test-model', - author: authorId, - tools: ['tool1'], + expect(revertedAgent.author.toString()).toBe(originalAuthor.toString()); + expect(revertedAgent.name).toBe('Original Agent'); + expect(revertedAgent.description).toBe('Original description'); }); - await updateAgent( - { id: agentId }, - { - description: 'Updated description', - $push: { tools: 'tool2' }, - $addToSet: { projectIds: projectId }, - }, - ); + test('should detect action metadata changes and force version update', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const actionId = 'testActionId123'; - const firstUpdate = await getAgent({ id: agentId }); - expect(firstUpdate.description).toBe('Updated description'); - expect(firstUpdate.tools).toContain('tool1'); - expect(firstUpdate.tools).toContain('tool2'); - expect(firstUpdate.projectIds.map((id) => id.toString())).toContain(projectId.toString()); - expect(firstUpdate.versions).toHaveLength(2); + // Create agent with actions + await createAgent({ + id: agentId, + name: 'Agent with Actions', + provider: 'test', + model: 'test-model', + author: authorId, + actions: [`test.com_action_${actionId}`], + tools: ['listEvents_action_test.com', 'createEvent_action_test.com'], + }); - await updateAgent( - { id: 
agentId }, - { - tools: ['tool2', 'tool3'], - }, - ); + // First update with forceVersion should create a version + const firstUpdate = await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: true }, + ); - const secondUpdate = await getAgent({ id: agentId }); - expect(secondUpdate.tools).toHaveLength(2); - expect(secondUpdate.tools).toContain('tool2'); - expect(secondUpdate.tools).toContain('tool3'); - expect(secondUpdate.tools).not.toContain('tool1'); - expect(secondUpdate.versions).toHaveLength(3); + expect(firstUpdate.versions).toHaveLength(2); - await updateAgent( - { id: agentId }, - { - $push: { tools: 'tool3' }, - }, - ); + // Second update with same data but forceVersion should still create a version + const secondUpdate = await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: true }, + ); - const thirdUpdate = await getAgent({ id: agentId }); - const toolCount = thirdUpdate.tools.filter((t) => t === 'tool3').length; - expect(toolCount).toBe(2); - expect(thirdUpdate.versions).toHaveLength(4); - }); + expect(secondUpdate.versions).toHaveLength(3); - test('should handle parameter objects correctly', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); + // Update without forceVersion and no changes should not create a version + let error; + try { + await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: false }, + ); + } catch (e) { + error = e; + } - await createAgent({ - id: agentId, - name: 'Parameters Test', - provider: 'test', - model: 'test-model', - author: authorId, - model_parameters: { temperature: 0.7 }, + expect(error).toBeDefined(); + 
expect(error.message).toContain('Duplicate version'); + expect(error.statusCode).toBe(409); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { model_parameters: { temperature: 0.8 } }, - ); + test('should handle isDuplicateVersion with arrays containing null/undefined values', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.model_parameters.temperature).toBe(0.8); + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1', null, 'tool2', undefined], + }); - await updateAgent( - { id: agentId }, - { - model_parameters: { - temperature: 0.8, - max_tokens: 1000, + // Update with same array but different null/undefined arrangement + const updatedAgent = await updateAgent({ id: agentId }, { tools: ['tool1', 'tool2'] }); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.tools).toEqual(['tool1', 'tool2']); + }); + + test('should handle isDuplicateVersion with empty objects in tool_kwargs', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tool_kwargs: [ + { tool: 'tool1', config: { setting: 'value' } }, + {}, + { tool: 'tool2', config: {} }, + ], + }); + + // Try to update with reordered but equivalent tool_kwargs + const updatedAgent = await updateAgent( + { id: agentId }, + { + tool_kwargs: [ + { tool: 'tool2', config: {} }, + { tool: 'tool1', config: { setting: 'value' } }, + {}, + ], }, - }, - ); + ); - const complexAgent = await getAgent({ id: agentId }); - expect(complexAgent.versions).toHaveLength(3); - expect(complexAgent.model_parameters.temperature).toBe(0.8); - expect(complexAgent.model_parameters.max_tokens).toBe(1000); + // Should create 
new version as order matters for arrays + expect(updatedAgent.versions).toHaveLength(2); + }); - await updateAgent({ id: agentId }, { model_parameters: {} }); + test('should handle isDuplicateVersion with mixed primitive and object arrays', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - const emptyParamsAgent = await getAgent({ id: agentId }); - expect(emptyParamsAgent.versions).toHaveLength(4); - expect(emptyParamsAgent.model_parameters).toEqual({}); + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + mixed_array: [1, 'string', { key: 'value' }, true, null], + }); + + // Update with same values but different types + const updatedAgent = await updateAgent( + { id: agentId }, + { mixed_array: ['1', 'string', { key: 'value' }, 'true', null] }, + ); + + // Should create new version as types differ + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle isDuplicateVersion with deeply nested objects', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const deepObject = { + level1: { + level2: { + level3: { + level4: { + value: 'deep', + array: [1, 2, { nested: true }], + }, + }, + }, + }, + }; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: deepObject, + }); + + // First create a version with changes + await updateAgent({ id: agentId }, { description: 'Updated' }); + + // Then try to create duplicate of the original version + await updateAgent( + { id: agentId }, + { + model_parameters: deepObject, + description: undefined, + }, + ); + + // Since we're updating back to the same model_parameters but with a different description, + // it should create a new version + const agent = await getAgent({ id: agentId }); + expect(agent.versions).toHaveLength(3); + }); + + 
test('should handle version comparison with special field types', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + projectIds: [projectId], + model_parameters: { temperature: 0.7 }, + }); + + // Update with a real field change first + const firstUpdate = await updateAgent({ id: agentId }, { description: 'New description' }); + + expect(firstUpdate.versions).toHaveLength(2); + + // Update with model parameters change + const secondUpdate = await updateAgent( + { id: agentId }, + { model_parameters: { temperature: 0.8 } }, + ); + + expect(secondUpdate.versions).toHaveLength(3); + }); + + describe('Edge Cases', () => { + test('should handle extremely large version history', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Version Test', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + for (let i = 0; i < 20; i++) { + await updateAgent({ id: agentId }, { description: `Version ${i}` }); + } + + const agent = await getAgent({ id: agentId }); + expect(agent.versions).toHaveLength(21); + expect(agent.description).toBe('Version 19'); + }); + + test('should handle revertAgentVersion with invalid version index', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await expect(revertAgentVersion({ id: agentId }, 5)).rejects.toThrow('Version 5 not found'); + }); + + test('should handle revertAgentVersion with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect(revertAgentVersion({ id: nonExistentId }, 
0)).rejects.toThrow( + 'Agent not found', + ); + }); + + test('should handle updateAgent with empty update object', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + const updatedAgent = await updateAgent({ id: agentId }, {}); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Test Agent'); + expect(updatedAgent.versions).toHaveLength(1); + }); + }); }); - test('should detect duplicate versions and reject updates', async () => { - const originalConsoleError = console.error; - console.error = jest.fn(); + describe('Action Metadata and Hash Generation', () => { + let mongoServer; - try { + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }, 20000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should generate consistent hash for same action metadata', async () => { + const actionIds = ['test.com_action_123', 'example.com_action_456']; + const actions = [ + { + action_id: '123', + metadata: { version: '1.0', endpoints: ['GET /api/test'], schema: { type: 'object' } }, + }, + { + action_id: '456', + metadata: { + version: '2.0', + endpoints: ['POST /api/example'], + schema: { type: 'string' }, + }, + }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds, actions); + const hash2 = await generateActionMetadataHash(actionIds, actions); + + expect(hash1).toBe(hash2); + expect(typeof hash1).toBe('string'); + expect(hash1.length).toBe(64); // SHA-256 produces 64 character hex string + }); + + test('should generate different hashes for different action metadata', 
async () => { + const actionIds = ['test.com_action_123']; + const actions1 = [ + { action_id: '123', metadata: { version: '1.0', endpoints: ['GET /api/test'] } }, + ]; + const actions2 = [ + { action_id: '123', metadata: { version: '2.0', endpoints: ['GET /api/test'] } }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds, actions1); + const hash2 = await generateActionMetadataHash(actionIds, actions2); + + expect(hash1).not.toBe(hash2); + }); + + test('should handle empty action arrays', async () => { + const hash = await generateActionMetadataHash([], []); + expect(hash).toBe(''); + }); + + test('should handle null or undefined action arrays', async () => { + const hash1 = await generateActionMetadataHash(null, []); + const hash2 = await generateActionMetadataHash(undefined, []); + + expect(hash1).toBe(''); + expect(hash2).toBe(''); + }); + + test('should handle missing action metadata gracefully', async () => { + const actionIds = ['test.com_action_123', 'missing.com_action_999']; + const actions = [ + { action_id: '123', metadata: { version: '1.0' } }, + // missing action with id '999' + ]; + + const hash = await generateActionMetadataHash(actionIds, actions); + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + test('should sort action IDs for consistent hashing', async () => { + const actionIds1 = ['b.com_action_2', 'a.com_action_1']; + const actionIds2 = ['a.com_action_1', 'b.com_action_2']; + const actions = [ + { action_id: '1', metadata: { version: '1.0' } }, + { action_id: '2', metadata: { version: '2.0' } }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds1, actions); + const hash2 = await generateActionMetadataHash(actionIds2, actions); + + expect(hash1).toBe(hash2); + }); + + test('should handle complex nested metadata objects', async () => { + const actionIds = ['complex.com_action_1']; + const actions = [ + { + action_id: '1', + metadata: { + version: '1.0', + schema: { + type: 'object', + 
properties: { + name: { type: 'string' }, + nested: { + type: 'object', + properties: { + id: { type: 'number' }, + tags: { type: 'array', items: { type: 'string' } }, + }, + }, + }, + }, + endpoints: [ + { path: '/api/test', method: 'GET', params: ['id'] }, + { path: '/api/create', method: 'POST', body: true }, + ], + }, + }, + ]; + + const hash = await generateActionMetadataHash(actionIds, actions); + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + describe('Edge Cases', () => { + test('should handle generateActionMetadataHash with null metadata', async () => { + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: null }], + ); + expect(typeof hash).toBe('string'); + }); + + test('should handle generateActionMetadataHash with deeply nested metadata', async () => { + const deepMetadata = { + level1: { + level2: { + level3: { + level4: { + level5: 'deep value', + array: [1, 2, { nested: true }], + }, + }, + }, + }, + }; + + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: deepMetadata }], + ); + + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + test('should handle generateActionMetadataHash with special characters', async () => { + const specialMetadata = { + unicode: '🚀🎉👍', + symbols: '!@#$%^&*()_+-=[]{}|;:,.<>?', + quotes: 'single\'s and "doubles"', + newlines: 'line1\nline2\r\nline3', + }; + + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: specialMetadata }], + ); + + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + }); + }); + + describe('Load Agent Functionality', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + 
}, 20000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should return null when agent_id is not provided', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: null, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should return null when agent_id is empty string', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: '', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should test ephemeral agent loading logic', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({ + tool1_mcp_server1: {}, + tool2_mcp_server2: {}, + another_tool: {}, + }); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test instructions', + ephemeralAgent: { + execute_code: true, + web_search: true, + mcp: ['server1', 'server2'], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4', temperature: 0.7 }, + }); + + if (result) { + expect(result.id).toBe(EPHEMERAL_AGENT_ID); + expect(result.instructions).toBe('Test instructions'); + expect(result.provider).toBe('openai'); + expect(result.model).toBe('gpt-4'); + expect(result.model_parameters.temperature).toBe(0.7); + expect(result.tools).toContain('execute_code'); + expect(result.tools).toContain('web_search'); + expect(result.tools).toContain('tool1_mcp_server1'); + expect(result.tools).toContain('tool2_mcp_server2'); + } else { + expect(result).toBeNull(); + } + }); + + test('should return null for non-existent agent', async () => { + const mockReq = { 
user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: 'non_existent_agent', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should load agent when user is the author', async () => { + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: userId, + description: 'Test description', + tools: ['web_search'], + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeDefined(); + expect(result.id).toBe(agentId); + expect(result.name).toBe('Test Agent'); + expect(result.author.toString()).toBe(userId.toString()); + expect(result.version).toBe(1); + }); + + test('should return null when user is not author and agent has no projectIds', async () => { + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeFalsy(); + }); + + test('should handle ephemeral agent with no MCP servers', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({}); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Simple instructions', + ephemeralAgent: { + execute_code: false, + web_search: false, + mcp: [], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, 
+ agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-3.5-turbo' }, + }); + + if (result) { + expect(result.tools).toEqual([]); + expect(result.instructions).toBe('Simple instructions'); + } else { + expect(result).toBeFalsy(); + } + }); + + test('should handle ephemeral agent with undefined ephemeralAgent in body', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({}); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Basic instructions', + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools).toEqual([]); + } else { + expect(result).toBeFalsy(); + } + }); + + describe('Edge Cases', () => { + test('should handle loadAgent with malformed req object', async () => { + const result = await loadAgent({ + req: null, + agent_id: 'test', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should handle ephemeral agent with extremely large tool list', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + const largeToolList = Array.from({ length: 100 }, (_, i) => `tool_${i}_mcp_server1`); + const availableTools = largeToolList.reduce((acc, tool) => { + acc[tool] = {}; + return acc; + }, {}); + + getCachedTools.mockResolvedValue(availableTools); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test', + ephemeralAgent: { + execute_code: true, + web_search: true, + mcp: ['server1'], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools.length).toBeGreaterThan(100); + } + }); + + test('should handle loadAgent 
with agent from different project', async () => { + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Project Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + projectIds: [projectId], + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeFalsy(); + }); + }); + }); + + describe('Agent Edge Cases and Error Handling', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }, 20000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should handle agent creation with minimal required fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + provider: 'test', + model: 'test-model', + author: authorId, + }); + + expect(agent).toBeDefined(); + expect(agent.id).toBe(agentId); + expect(agent.versions).toHaveLength(1); + expect(agent.versions[0].provider).toBe('test'); + expect(agent.versions[0].model).toBe('test-model'); + }); + + test('should handle agent creation with all optional fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + name: 'Complex Agent', + provider: 'test', + model: 'test-model', + author: authorId, + 
description: 'Complex description', + instructions: 'Complex instructions', + tools: ['tool1', 'tool2'], + actions: ['action1', 'action2'], + model_parameters: { temperature: 0.8, max_tokens: 1000 }, + projectIds: [projectId], + avatar: 'https://example.com/avatar.png', + isCollaborative: true, + tool_resources: { + file_search: { file_ids: ['file1', 'file2'] }, + }, + }); + + expect(agent).toBeDefined(); + expect(agent.name).toBe('Complex Agent'); + expect(agent.description).toBe('Complex description'); + expect(agent.instructions).toBe('Complex instructions'); + expect(agent.tools).toEqual(['tool1', 'tool2']); + expect(agent.actions).toEqual(['action1', 'action2']); + expect(agent.model_parameters.temperature).toBe(0.8); + expect(agent.model_parameters.max_tokens).toBe(1000); + expect(agent.projectIds.map((id) => id.toString())).toContain(projectId.toString()); + expect(agent.avatar).toBe('https://example.com/avatar.png'); + expect(agent.isCollaborative).toBe(true); + expect(agent.tool_resources.file_search.file_ids).toEqual(['file1', 'file2']); + }); + + test('should handle updateAgent with empty update object', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + const updatedAgent = await updateAgent({ id: agentId }, {}); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Test Agent'); + expect(updatedAgent.versions).toHaveLength(1); // No new version should be created + }); + + test('should handle concurrent updates to different agents', async () => { + const agent1Id = `agent_${uuidv4()}`; + const agent2Id = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agent1Id, + name: 'Agent 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agent2Id, + name: 
'Agent 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Concurrent updates to different agents + const [updated1, updated2] = await Promise.all([ + updateAgent({ id: agent1Id }, { description: 'Updated Agent 1' }), + updateAgent({ id: agent2Id }, { description: 'Updated Agent 2' }), + ]); + + expect(updated1.description).toBe('Updated Agent 1'); + expect(updated2.description).toBe('Updated Agent 2'); + expect(updated1.versions).toHaveLength(2); + expect(updated2.versions).toHaveLength(2); + }); + + test('should handle agent deletion with non-existent ID', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const result = await deleteAgent({ id: nonExistentId }); + + expect(result).toBeNull(); + }); + + test('should handle getListAgents with no agents', async () => { + const authorId = new mongoose.Types.ObjectId(); + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toEqual([]); + expect(result.has_more).toBe(false); + expect(result.first_id).toBeNull(); + expect(result.last_id).toBeNull(); + }); + + test('should handle updateAgent with MongoDB operators mixed with direct updates', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + }); + + // Test with $push and direct field update + const updatedAgent = await updateAgent( + { id: agentId }, + { + name: 'Updated Name', + $push: { tools: 'tool2' }, + }, + ); + + expect(updatedAgent.name).toBe('Updated Name'); + expect(updatedAgent.tools).toContain('tool1'); + expect(updatedAgent.tools).toContain('tool2'); + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle revertAgentVersion with invalid version index', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new 
mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Try to revert to non-existent version + await expect(revertAgentVersion({ id: agentId }, 5)).rejects.toThrow('Version 5 not found'); + }); + + test('should handle revertAgentVersion with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect(revertAgentVersion({ id: nonExistentId }, 0)).rejects.toThrow('Agent not found'); + }); + + test('should handle addAgentResourceFile with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const mockReq = { user: { id: 'user123' } }; + + await expect( + addAgentResourceFile({ + req: mockReq, + agent_id: nonExistentId, + tool_resource: 'file_search', + file_id: 'file123', + }), + ).rejects.toThrow('Agent not found for adding resource file'); + }); + + test('should handle removeAgentResourceFiles with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect( + removeAgentResourceFiles({ + agent_id: nonExistentId, + files: [{ tool_resource: 'file_search', file_id: 'file123' }], + }), + ).rejects.toThrow('Agent not found for removing resource files'); + }); + + test('should handle updateAgent with complex nested updates', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: { temperature: 0.5 }, + tools: ['tool1'], + }); + + // First update with $push operation + const firstUpdate = await updateAgent( + { id: agentId }, + { + $push: { tools: 'tool2' }, + }, + ); + + expect(firstUpdate.tools).toContain('tool1'); + expect(firstUpdate.tools).toContain('tool2'); + + // Second update with direct field update and $addToSet + const secondUpdate = await updateAgent( + { id: 
agentId }, + { + name: 'Updated Agent', + model_parameters: { temperature: 0.8, max_tokens: 500 }, + $addToSet: { tools: 'tool3' }, + }, + ); + + expect(secondUpdate.name).toBe('Updated Agent'); + expect(secondUpdate.model_parameters.temperature).toBe(0.8); + expect(secondUpdate.model_parameters.max_tokens).toBe(500); + expect(secondUpdate.tools).toContain('tool1'); + expect(secondUpdate.tools).toContain('tool2'); + expect(secondUpdate.tools).toContain('tool3'); + expect(secondUpdate.versions).toHaveLength(3); + }); + + test('should preserve version order in versions array', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Version 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await updateAgent({ id: agentId }, { name: 'Version 2' }); + await updateAgent({ id: agentId }, { name: 'Version 3' }); + const finalAgent = await updateAgent({ id: agentId }, { name: 'Version 4' }); + + expect(finalAgent.versions).toHaveLength(4); + expect(finalAgent.versions[0].name).toBe('Version 1'); + expect(finalAgent.versions[1].name).toBe('Version 2'); + expect(finalAgent.versions[2].name).toBe('Version 3'); + expect(finalAgent.versions[3].name).toBe('Version 4'); + expect(finalAgent.name).toBe('Version 4'); + }); + + test('should handle updateAgentProjects error scenarios', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const userId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + // Test with non-existent agent + const result = await updateAgentProjects({ + user: { id: userId.toString() }, + agentId: nonExistentId, + projectIds: [projectId.toString()], + }); + + expect(result).toBeNull(); + }); + + test('should handle revertAgentVersion properly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 
'Original Name', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Original description', + }); + + await updateAgent( + { id: agentId }, + { name: 'Updated Name', description: 'Updated description' }, + ); + + const revertedAgent = await revertAgentVersion({ id: agentId }, 0); + + expect(revertedAgent.name).toBe('Original Name'); + expect(revertedAgent.description).toBe('Original description'); + expect(revertedAgent.author.toString()).toBe(authorId.toString()); + }); + + test('should handle action-related updates with getActions error', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + // Create agent with actions that might cause getActions to fail + await createAgent({ + id: agentId, + name: 'Agent with Actions', + provider: 'test', + model: 'test-model', + author: authorId, + actions: ['test.com_action_invalid_id'], + }); + + // Update should still work even if getActions fails + const updatedAgent = await updateAgent( + { id: agentId }, + { description: 'Updated description' }, + ); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.description).toBe('Updated description'); + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle updateAgent with combined MongoDB operators', async () => { const agentId = `agent_${uuidv4()}`; const authorId = new mongoose.Types.ObjectId(); const projectId1 = new mongoose.Types.ObjectId(); const projectId2 = new mongoose.Types.ObjectId(); - const testCases = [ + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + projectIds: [projectId1], + }); + + // Use multiple operators in single update - but avoid conflicting operations on same field + const updatedAgent = await updateAgent( + { id: agentId }, { - name: 'simple field update', - initial: { - name: 'Test Agent', - description: 'Initial description', - }, - update: { 
name: 'Updated Name' }, - duplicate: { name: 'Updated Name' }, + name: 'Updated Name', + $push: { tools: 'tool2' }, + $addToSet: { projectIds: projectId2 }, }, + ); + + const finalAgent = await updateAgent( + { id: agentId }, { - name: 'object field update', - initial: { - model_parameters: { temperature: 0.7 }, - }, - update: { model_parameters: { temperature: 0.8 } }, - duplicate: { model_parameters: { temperature: 0.8 } }, - }, - { - name: 'array field update', - initial: { - tools: ['tool1', 'tool2'], - }, - update: { tools: ['tool2', 'tool3'] }, - duplicate: { tools: ['tool2', 'tool3'] }, - }, - { - name: 'projectIds update', - initial: { - projectIds: [projectId1], - }, - update: { projectIds: [projectId1, projectId2] }, - duplicate: { projectIds: [projectId2, projectId1] }, + $pull: { projectIds: projectId1 }, }, + ); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Updated Name'); + expect(updatedAgent.tools).toContain('tool1'); + expect(updatedAgent.tools).toContain('tool2'); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + expect(finalAgent).toBeDefined(); + expect(finalAgent.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + expect(finalAgent.versions).toHaveLength(3); + }); + + test('should handle updateAgent when agent does not exist', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + const result = await updateAgent({ id: nonExistentId }, { name: 'New Name' }); + + expect(result).toBeNull(); + }); + + test('should handle concurrent updates with database errors', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Mock findOneAndUpdate to simulate database error + const cleanup = mockFindOneAndUpdateError(2); + + // Concurrent updates where one fails + 
const promises = [ + updateAgent({ id: agentId }, { name: 'Update 1' }), + updateAgent({ id: agentId }, { name: 'Update 2' }), + updateAgent({ id: agentId }, { name: 'Update 3' }), ]; - for (const testCase of testCases) { - const testAgentId = `agent_${uuidv4()}`; + const results = await Promise.allSettled(promises); - await createAgent({ - id: testAgentId, - provider: 'test', - model: 'test-model', - author: authorId, - ...testCase.initial, + cleanup(); + + const succeeded = results.filter((r) => r.status === 'fulfilled').length; + const failed = results.filter((r) => r.status === 'rejected').length; + + expect(succeeded).toBe(2); + expect(failed).toBe(1); + }); + + test('should handle removeAgentResourceFiles when agent is deleted during operation', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tool_resources: { + file_search: { + file_ids: ['file1', 'file2', 'file3'], + }, + }, + }); + + // Mock findOneAndUpdate to return null (simulating deletion) + const originalFindOneAndUpdate = Agent.findOneAndUpdate; + Agent.findOneAndUpdate = jest.fn().mockImplementation(() => ({ + lean: jest.fn().mockResolvedValue(null), + })); + + // Try to remove files from deleted agent + await expect( + removeAgentResourceFiles({ + agent_id: agentId, + files: [ + { tool_resource: 'file_search', file_id: 'file1' }, + { tool_resource: 'file_search', file_id: 'file2' }, + ], + }), + ).rejects.toThrow('Failed to update agent during file removal (pull step)'); + + Agent.findOneAndUpdate = originalFindOneAndUpdate; + }); + + test('should handle loadEphemeralAgent with malformed MCP tool names', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({ + malformed_tool_name: {}, // No mcp delimiter + tool__server1: {}, // Wrong delimiter 
+ tool_mcp_server1: {}, // Correct format + tool_mcp_server2: {}, // Different server + }); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test instructions', + ephemeralAgent: { + execute_code: false, + web_search: false, + mcp: ['server1'], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools).toEqual(['tool_mcp_server1']); + expect(result.tools).not.toContain('malformed_tool_name'); + expect(result.tools).not.toContain('tool__server1'); + expect(result.tools).not.toContain('tool_mcp_server2'); + } + }); + + test('should handle addAgentResourceFile when array initialization fails', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Mock the updateOne operation to fail but let updateAgent succeed + const originalUpdateOne = Agent.updateOne; + let updateOneCalled = false; + Agent.updateOne = jest.fn().mockImplementation((...args) => { + if (!updateOneCalled) { + updateOneCalled = true; + return Promise.reject(new Error('Database error')); + } + return originalUpdateOne.apply(Agent, args); + }); + + try { + const result = await addAgentResourceFile({ + agent_id: agentId, + tool_resource: 'new_tool', + file_id: 'file123', }); - await updateAgent({ id: testAgentId }, testCase.update); - - let error; - try { - await updateAgent({ id: testAgentId }, testCase.duplicate); - } catch (e) { - error = e; - } - - expect(error).toBeDefined(); - expect(error.message).toContain('Duplicate version'); - expect(error.statusCode).toBe(409); - expect(error.details).toBeDefined(); - expect(error.details.duplicateVersion).toBeDefined(); - - const agent = await getAgent({ id: testAgentId }); - 
expect(agent.versions).toHaveLength(2); + expect(result).toBeDefined(); + expect(result.tools).toContain('new_tool'); + } catch (error) { + expect(error.message).toBe('Database error'); } - } finally { - console.error = originalConsoleError; - } + + Agent.updateOne = originalUpdateOne; + }); }); - test('should track updatedBy when a different user updates an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const updatingUser = new mongoose.Types.ObjectId(); + describe('Agent IDs Field in Version Detection', () => { + let mongoServer; - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }, 20000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - { updatingUserId: updatingUser.toString() }, - ); - - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[1].updatedBy.toString()).toBe(updatingUser.toString()); - expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); - }); - - test('should include updatedBy even when the original author updates the agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + beforeEach(async () => { + await Agent.deleteMany({}); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { 
name: 'Updated Agent', description: 'Updated description' }, - { updatingUserId: originalAuthor.toString() }, - ); + test('should now create new version when agent_ids field changes', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[1].updatedBy.toString()).toBe(originalAuthor.toString()); - expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); - }); + const agent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); - test('should track multiple different users updating the same agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const user1 = new mongoose.Types.ObjectId(); - const user2 = new mongoose.Types.ObjectId(); - const user3 = new mongoose.Types.ObjectId(); + expect(agent).toBeDefined(); + expect(agent.versions).toHaveLength(1); - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', - }); - - // User 1 makes an update - await updateAgent( - { id: agentId }, - { name: 'Updated by User 1', description: 'First update' }, - { updatingUserId: user1.toString() }, - ); - - // Original author makes an update - await updateAgent( - { id: agentId }, - { description: 'Updated by original author' }, - { updatingUserId: originalAuthor.toString() }, - ); - - // User 2 makes an update - await updateAgent( - { id: agentId }, - { name: 'Updated by User 2', model: 'new-model' }, - { updatingUserId: user2.toString() }, - ); - - // User 3 makes an update - const finalAgent = await updateAgent( - { id: agentId }, - { description: 'Final update by User 3' }, - { updatingUserId: user3.toString() }, - ); - - 
expect(finalAgent.versions).toHaveLength(5); - expect(finalAgent.author.toString()).toBe(originalAuthor.toString()); - - // Check that each version has the correct updatedBy - expect(finalAgent.versions[0].updatedBy).toBeUndefined(); // Initial creation has no updatedBy - expect(finalAgent.versions[1].updatedBy.toString()).toBe(user1.toString()); - expect(finalAgent.versions[2].updatedBy.toString()).toBe(originalAuthor.toString()); - expect(finalAgent.versions[3].updatedBy.toString()).toBe(user2.toString()); - expect(finalAgent.versions[4].updatedBy.toString()).toBe(user3.toString()); - - // Verify the final state - expect(finalAgent.name).toBe('Updated by User 2'); - expect(finalAgent.description).toBe('Final update by User 3'); - expect(finalAgent.model).toBe('new-model'); - }); - - test('should preserve original author during agent restoration', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const updatingUser = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', - }); - - await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - { updatingUserId: updatingUser.toString() }, - ); - - const { revertAgentVersion } = require('./Agent'); - const revertedAgent = await revertAgentVersion({ id: agentId }, 0); - - expect(revertedAgent.author.toString()).toBe(originalAuthor.toString()); - expect(revertedAgent.name).toBe('Original Agent'); - expect(revertedAgent.description).toBe('Original description'); - }); - - test('should detect action metadata changes and force version update', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const actionId = 'testActionId123'; - - // Create agent with actions - await createAgent({ - id: agentId, - name: 'Agent with 
Actions', - provider: 'test', - model: 'test-model', - author: authorId, - actions: [`test.com_action_${actionId}`], - tools: ['listEvents_action_test.com', 'createEvent_action_test.com'], - }); - - // First update with forceVersion should create a version - const firstUpdate = await updateAgent( - { id: agentId }, - { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, - { updatingUserId: authorId.toString(), forceVersion: true }, - ); - - expect(firstUpdate.versions).toHaveLength(2); - - // Second update with same data but forceVersion should still create a version - const secondUpdate = await updateAgent( - { id: agentId }, - { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, - { updatingUserId: authorId.toString(), forceVersion: true }, - ); - - expect(secondUpdate.versions).toHaveLength(3); - - // Update without forceVersion and no changes should not create a version - let error; - try { - await updateAgent( + const updated = await updateAgent( { id: agentId }, - { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, - { updatingUserId: authorId.toString(), forceVersion: false }, + { agent_ids: ['agent1', 'agent2', 'agent3'] }, ); - } catch (e) { - error = e; - } - expect(error).toBeDefined(); - expect(error.message).toContain('Duplicate version'); - expect(error.statusCode).toBe(409); + // Since agent_ids is no longer excluded, this should create a new version + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1', 'agent2', 'agent3']); + }); + + test('should detect duplicate version if agent_ids is updated to same value', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 
'agent3'] }); + + await expect( + updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] }), + ).rejects.toThrow('Duplicate version'); + }); + + test('should handle agent_ids field alongside other fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Initial description', + agent_ids: ['agent1'], + }); + + const updated = await updateAgent( + { id: agentId }, + { + agent_ids: ['agent1', 'agent2'], + description: 'Updated description', + }, + ); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1', 'agent2']); + expect(updated.description).toBe('Updated description'); + + const updated2 = await updateAgent({ id: agentId }, { description: 'Another description' }); + + expect(updated2.versions).toHaveLength(3); + expect(updated2.agent_ids).toEqual(['agent1', 'agent2']); + expect(updated2.description).toBe('Another description'); + }); + + test('should skip version creation when skipVersioning option is used', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + + // Create agent with initial projectIds + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + projectIds: [projectId1], + }); + + // Share agent using updateAgentProjects (which uses skipVersioning) + const shared = await updateAgentProjects({ + user: { id: authorId.toString() }, // Use the same author ID + agentId: agentId, + projectIds: [projectId2.toString()], + }); + + // Should NOT create a new version due to skipVersioning + expect(shared.versions).toHaveLength(1); + expect(shared.projectIds.map((id) => 
id.toString())).toContain(projectId1.toString()); + expect(shared.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + // Unshare agent using updateAgentProjects + const unshared = await updateAgentProjects({ + user: { id: authorId.toString() }, + agentId: agentId, + removeProjectIds: [projectId1.toString()], + }); + + // Still should NOT create a new version + expect(unshared.versions).toHaveLength(1); + expect(unshared.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + expect(unshared.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + // Regular update without skipVersioning should create a version + const regularUpdate = await updateAgent( + { id: agentId }, + { description: 'Updated description' }, + ); + + expect(regularUpdate.versions).toHaveLength(2); + expect(regularUpdate.description).toBe('Updated description'); + + // Direct updateAgent with MongoDB operators should still create versions + const directUpdate = await updateAgent( + { id: agentId }, + { $addToSet: { projectIds: { $each: [projectId1] } } }, + ); + + expect(directUpdate.versions).toHaveLength(3); + expect(directUpdate.projectIds.length).toBe(2); + }); + + test('should preserve agent_ids in version history', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1'], + }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2'] }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent3'] }); + + const finalAgent = await getAgent({ id: agentId }); + + expect(finalAgent.versions).toHaveLength(3); + expect(finalAgent.versions[0].agent_ids).toEqual(['agent1']); + expect(finalAgent.versions[1].agent_ids).toEqual(['agent1', 'agent2']); + expect(finalAgent.versions[2].agent_ids).toEqual(['agent3']); + 
expect(finalAgent.agent_ids).toEqual(['agent3']); + }); + + test('should handle empty agent_ids arrays', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); + + const updated = await updateAgent({ id: agentId }, { agent_ids: [] }); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual([]); + + await expect(updateAgent({ id: agentId }, { agent_ids: [] })).rejects.toThrow( + 'Duplicate version', + ); + }); + + test('should handle agent without agent_ids field', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + expect(agent.agent_ids).toEqual([]); + + const updated = await updateAgent({ id: agentId }, { agent_ids: ['agent1'] }); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1']); + }); }); }); + +function createBasicAgent(overrides = {}) { + const defaults = { + id: `agent_${uuidv4()}`, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }; + return createAgent({ ...defaults, ...overrides }); +} + +function createTestIds() { + return { + agentId: `agent_${uuidv4()}`, + authorId: new mongoose.Types.ObjectId(), + projectId: new mongoose.Types.ObjectId(), + fileId: uuidv4(), + }; +} + +function createFileOperations(agentId, fileIds, operation = 'add') { + return fileIds.map((fileId) => + operation === 'add' + ? 
addAgentResourceFile({ agent_id: agentId, tool_resource: 'test_tool', file_id: fileId }) + : removeAgentResourceFiles({ + agent_id: agentId, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); +} + +function mockFindOneAndUpdateError(errorOnCall = 1) { + const original = Agent.findOneAndUpdate; + let callCount = 0; + + Agent.findOneAndUpdate = jest.fn().mockImplementation((...args) => { + callCount++; + if (callCount === errorOnCall) { + throw new Error('Database connection lost'); + } + return original.apply(Agent, args); + }); + + return () => { + Agent.findOneAndUpdate = original; + }; +} + +function generateVersionTestCases() { + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + + return [ + { + name: 'simple field update', + initial: { + name: 'Test Agent', + description: 'Initial description', + }, + update: { name: 'Updated Name' }, + duplicate: { name: 'Updated Name' }, + }, + { + name: 'object field update', + initial: { + model_parameters: { temperature: 0.7 }, + }, + update: { model_parameters: { temperature: 0.8 } }, + duplicate: { model_parameters: { temperature: 0.8 } }, + }, + { + name: 'array field update', + initial: { + tools: ['tool1', 'tool2'], + }, + update: { tools: ['tool2', 'tool3'] }, + duplicate: { tools: ['tool2', 'tool3'] }, + }, + { + name: 'projectIds update', + initial: { + projectIds: [projectId1], + }, + update: { projectIds: [projectId1, projectId2] }, + duplicate: { projectIds: [projectId2, projectId1] }, + }, + ]; +} diff --git a/api/models/Assistant.js b/api/models/Assistant.js index a8a5b98157..be94d35d7d 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -1,7 +1,4 @@ -const mongoose = require('mongoose'); -const { assistantSchema } = require('@librechat/data-schemas'); - -const Assistant = mongoose.model('assistant', assistantSchema); +const { Assistant } = require('~/db/models'); /** * Update an assistant with new data without 
overwriting existing properties, diff --git a/api/models/Balance.js b/api/models/Balance.js deleted file mode 100644 index 226f6ef508..0000000000 --- a/api/models/Balance.js +++ /dev/null @@ -1,4 +0,0 @@ -const mongoose = require('mongoose'); -const { balanceSchema } = require('@librechat/data-schemas'); - -module.exports = mongoose.model('Balance', balanceSchema); diff --git a/api/models/Banner.js b/api/models/Banner.js index 399a8e72ee..42ad1599ed 100644 --- a/api/models/Banner.js +++ b/api/models/Banner.js @@ -1,8 +1,5 @@ -const mongoose = require('mongoose'); -const logger = require('~/config/winston'); -const { bannerSchema } = require('@librechat/data-schemas'); - -const Banner = mongoose.model('Banner', bannerSchema); +const { logger } = require('@librechat/data-schemas'); +const { Banner } = require('~/db/models'); /** * Retrieves the current active banner. @@ -28,4 +25,4 @@ const getBanner = async (user) => { } }; -module.exports = { Banner, getBanner }; +module.exports = { getBanner }; diff --git a/api/models/Config.js b/api/models/Config.js deleted file mode 100644 index fefb84b8f9..0000000000 --- a/api/models/Config.js +++ /dev/null @@ -1,86 +0,0 @@ -const mongoose = require('mongoose'); -const { logger } = require('~/config'); - -const major = [0, 0]; -const minor = [0, 0]; -const patch = [0, 5]; - -const configSchema = mongoose.Schema( - { - tag: { - type: String, - required: true, - validate: { - validator: function (tag) { - const [part1, part2, part3] = tag.replace('v', '').split('.').map(Number); - - // Check if all parts are numbers - if (isNaN(part1) || isNaN(part2) || isNaN(part3)) { - return false; - } - - // Check if all parts are within their respective ranges - if (part1 < major[0] || part1 > major[1]) { - return false; - } - if (part2 < minor[0] || part2 > minor[1]) { - return false; - } - if (part3 < patch[0] || part3 > patch[1]) { - return false; - } - return true; - }, - message: 'Invalid tag value', - }, - }, - searchEnabled: { - type: 
Boolean, - default: false, - }, - usersEnabled: { - type: Boolean, - default: false, - }, - startupCounts: { - type: Number, - default: 0, - }, - }, - { timestamps: true }, -); - -// Instance method -configSchema.methods.incrementCount = function () { - this.startupCounts += 1; -}; - -// Static methods -configSchema.statics.findByTag = async function (tag) { - return await this.findOne({ tag }).lean(); -}; - -configSchema.statics.updateByTag = async function (tag, update) { - return await this.findOneAndUpdate({ tag }, update, { new: true }); -}; - -const Config = mongoose.models.Config || mongoose.model('Config', configSchema); - -module.exports = { - getConfigs: async (filter) => { - try { - return await Config.find(filter).lean(); - } catch (error) { - logger.error('Error getting configs', error); - return { config: 'Error getting configs' }; - } - }, - deleteConfigs: async (filter) => { - try { - return await Config.deleteMany(filter); - } catch (error) { - logger.error('Error deleting configs', error); - return { config: 'Error deleting configs' }; - } - }, -}; diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 51081a6491..698762d43d 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -1,6 +1,8 @@ -const Conversation = require('./schema/convoSchema'); +const { logger } = require('@librechat/data-schemas'); +const { createTempChatExpirationDate } = require('@librechat/api'); +const getCustomConfig = require('~/server/services/Config/loadCustomConfig'); const { getMessages, deleteMessages } = require('./Message'); -const logger = require('~/config/winston'); +const { Conversation } = require('~/db/models'); /** * Searches for a conversation by conversationId and returns a lean document with only conversationId and user. 
@@ -75,7 +77,6 @@ const getConvoFiles = async (conversationId) => { }; module.exports = { - Conversation, getConvoFiles, searchConversation, deleteNullOrEmptyConversations, @@ -99,10 +100,15 @@ module.exports = { update.conversationId = newConversationId; } - if (req.body.isTemporary) { - const expiredAt = new Date(); - expiredAt.setDate(expiredAt.getDate() + 30); - update.expiredAt = expiredAt; + if (req?.body?.isTemporary) { + try { + const customConfig = await getCustomConfig(); + update.expiredAt = createTempChatExpirationDate(customConfig); + } catch (err) { + logger.error('Error creating temporary chat expiration date:', err); + logger.info(`---\`saveConvo\` context: ${metadata?.context}`); + update.expiredAt = null; + } } else { update.expiredAt = null; } @@ -155,7 +161,6 @@ module.exports = { { cursor, limit = 25, isArchived = false, tags, search, order = 'desc' } = {}, ) => { const filters = [{ user }]; - if (isArchived) { filters.push({ isArchived: true }); } else { @@ -288,7 +293,6 @@ module.exports = { deleteConvos: async (user, filter) => { try { const userFilter = { ...filter, user }; - const conversations = await Conversation.find(userFilter).select('conversationId'); const conversationIds = conversations.map((c) => c.conversationId); diff --git a/api/models/ConversationTag.js b/api/models/ConversationTag.js index f0cac8620e..e6dc96be64 100644 --- a/api/models/ConversationTag.js +++ b/api/models/ConversationTag.js @@ -1,10 +1,5 @@ -const mongoose = require('mongoose'); -const Conversation = require('./schema/convoSchema'); -const logger = require('~/config/winston'); - -const { conversationTagSchema } = require('@librechat/data-schemas'); - -const ConversationTag = mongoose.model('ConversationTag', conversationTagSchema); +const { logger } = require('@librechat/data-schemas'); +const { ConversationTag, Conversation } = require('~/db/models'); /** * Retrieves all conversation tags for a user. 
@@ -140,13 +135,13 @@ const adjustPositions = async (user, oldPosition, newPosition) => { const position = oldPosition < newPosition ? { - $gt: Math.min(oldPosition, newPosition), - $lte: Math.max(oldPosition, newPosition), - } + $gt: Math.min(oldPosition, newPosition), + $lte: Math.max(oldPosition, newPosition), + } : { - $gte: Math.min(oldPosition, newPosition), - $lt: Math.max(oldPosition, newPosition), - }; + $gte: Math.min(oldPosition, newPosition), + $lt: Math.max(oldPosition, newPosition), + }; await ConversationTag.updateMany( { diff --git a/api/models/File.js b/api/models/File.js index 4d94994478..1ee943131d 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -1,9 +1,6 @@ -const mongoose = require('mongoose'); -const { EToolResources } = require('librechat-data-provider'); -const { fileSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -const File = mongoose.model('File', fileSchema); +const { logger } = require('@librechat/data-schemas'); +const { EToolResources, FileContext } = require('librechat-data-provider'); +const { File } = require('~/db/models'); /** * Finds a file by its file_id with additional query options. 
@@ -35,19 +32,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => { * @returns {Promise>} Files that match the criteria */ const getToolFilesByIds = async (fileIds, toolResourceSet) => { - if (!fileIds || !fileIds.length) { + if (!fileIds || !fileIds.length || !toolResourceSet?.size) { return []; } try { const filter = { file_id: { $in: fileIds }, + $or: [], }; - if (toolResourceSet.size) { - filter.$or = []; + if (toolResourceSet.has(EToolResources.ocr)) { + filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents }); } - if (toolResourceSet.has(EToolResources.file_search)) { filter.$or.push({ embedded: true }); } @@ -169,7 +166,6 @@ async function batchUpdateFiles(updates) { } module.exports = { - File, findFileById, getFiles, getToolFilesByIds, diff --git a/api/models/Key.js b/api/models/Key.js deleted file mode 100644 index c69c350a42..0000000000 --- a/api/models/Key.js +++ /dev/null @@ -1,4 +0,0 @@ -const mongoose = require('mongoose'); -const { keySchema } = require('@librechat/data-schemas'); - -module.exports = mongoose.model('Key', keySchema); diff --git a/api/models/Message.js b/api/models/Message.js index 86fd2fd549..c200c5f4d4 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -1,6 +1,8 @@ const { z } = require('zod'); -const Message = require('./schema/messageSchema'); -const { logger } = require('~/config'); +const { logger } = require('@librechat/data-schemas'); +const { createTempChatExpirationDate } = require('@librechat/api'); +const getCustomConfig = require('~/server/services/Config/loadCustomConfig'); +const { Message } = require('~/db/models'); const idSchema = z.string().uuid(); @@ -54,9 +56,14 @@ async function saveMessage(req, params, metadata) { }; if (req?.body?.isTemporary) { - const expiredAt = new Date(); - expiredAt.setDate(expiredAt.getDate() + 30); - update.expiredAt = expiredAt; + try { + const customConfig = await getCustomConfig(); + update.expiredAt = 
createTempChatExpirationDate(customConfig); + } catch (err) { + logger.error('Error creating temporary chat expiration date:', err); + logger.info(`---\`saveMessage\` context: ${metadata?.context}`); + update.expiredAt = null; + } } else { update.expiredAt = null; } @@ -68,7 +75,6 @@ async function saveMessage(req, params, metadata) { logger.info(`---\`saveMessage\` context: ${metadata?.context}`); update.tokenCount = 0; } - const message = await Message.findOneAndUpdate( { messageId: params.messageId, user: req.user.id }, update, @@ -140,7 +146,6 @@ async function bulkSaveMessages(messages, overrideTimestamp = false) { upsert: true, }, })); - const result = await Message.bulkWrite(bulkOps); return result; } catch (err) { @@ -255,6 +260,7 @@ async function updateMessage(req, message, metadata) { text: updatedMessage.text, isCreatedByUser: updatedMessage.isCreatedByUser, tokenCount: updatedMessage.tokenCount, + feedback: updatedMessage.feedback, }; } catch (err) { logger.error('Error updating message:', err); @@ -355,7 +361,6 @@ async function deleteMessages(filter) { } module.exports = { - Message, saveMessage, bulkSaveMessages, recordMessage, diff --git a/api/models/Message.spec.js b/api/models/Message.spec.js index a542130b59..aebaebb442 100644 --- a/api/models/Message.spec.js +++ b/api/models/Message.spec.js @@ -1,32 +1,7 @@ const mongoose = require('mongoose'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const { v4: uuidv4 } = require('uuid'); - -jest.mock('mongoose'); - -const mockFindQuery = { - select: jest.fn().mockReturnThis(), - sort: jest.fn().mockReturnThis(), - lean: jest.fn().mockReturnThis(), - deleteMany: jest.fn().mockResolvedValue({ deletedCount: 1 }), -}; - -const mockSchema = { - findOneAndUpdate: jest.fn(), - updateOne: jest.fn(), - findOne: jest.fn(() => ({ - lean: jest.fn(), - })), - find: jest.fn(() => mockFindQuery), - deleteMany: jest.fn(), -}; - -mongoose.model.mockReturnValue(mockSchema); - 
-jest.mock('~/models/schema/messageSchema', () => mockSchema); - -jest.mock('~/config/winston', () => ({ - error: jest.fn(), -})); +const { messageSchema } = require('@librechat/data-schemas'); const { saveMessage, @@ -35,77 +10,102 @@ const { deleteMessages, updateMessageText, deleteMessagesSince, -} = require('~/models/Message'); +} = require('./Message'); + +/** + * @type {import('mongoose').Model} + */ +let Message; describe('Message Operations', () => { + let mongoServer; let mockReq; - let mockMessage; + let mockMessageData; - beforeEach(() => { - jest.clearAllMocks(); + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Message = mongoose.models.Message || mongoose.model('Message', messageSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + // Clear database + await Message.deleteMany({}); mockReq = { user: { id: 'user123' }, }; - mockMessage = { + mockMessageData = { messageId: 'msg123', conversationId: uuidv4(), text: 'Hello, world!', user: 'user123', }; - - mockSchema.findOneAndUpdate.mockResolvedValue({ - toObject: () => mockMessage, - }); }); describe('saveMessage', () => { it('should save a message for an authenticated user', async () => { - const result = await saveMessage(mockReq, mockMessage); - expect(result).toEqual(mockMessage); - expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( - { messageId: 'msg123', user: 'user123' }, - expect.objectContaining({ user: 'user123' }), - expect.any(Object), - ); + const result = await saveMessage(mockReq, mockMessageData); + + expect(result.messageId).toBe('msg123'); + expect(result.user).toBe('user123'); + expect(result.text).toBe('Hello, world!'); + + // Verify the message was actually saved to the database + const savedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' }); + 
expect(savedMessage).toBeTruthy(); + expect(savedMessage.text).toBe('Hello, world!'); }); it('should throw an error for unauthenticated user', async () => { mockReq.user = null; - await expect(saveMessage(mockReq, mockMessage)).rejects.toThrow('User not authenticated'); + await expect(saveMessage(mockReq, mockMessageData)).rejects.toThrow('User not authenticated'); }); - it('should throw an error for invalid conversation ID', async () => { - mockMessage.conversationId = 'invalid-id'; - await expect(saveMessage(mockReq, mockMessage)).resolves.toBeUndefined(); + it('should handle invalid conversation ID gracefully', async () => { + mockMessageData.conversationId = 'invalid-id'; + const result = await saveMessage(mockReq, mockMessageData); + expect(result).toBeUndefined(); }); }); describe('updateMessageText', () => { it('should update message text for the authenticated user', async () => { + // First save a message + await saveMessage(mockReq, mockMessageData); + + // Then update it await updateMessageText(mockReq, { messageId: 'msg123', text: 'Updated text' }); - expect(mockSchema.updateOne).toHaveBeenCalledWith( - { messageId: 'msg123', user: 'user123' }, - { text: 'Updated text' }, - ); + + // Verify the update + const updatedMessage = await Message.findOne({ messageId: 'msg123', user: 'user123' }); + expect(updatedMessage.text).toBe('Updated text'); }); }); describe('updateMessage', () => { it('should update a message for the authenticated user', async () => { - mockSchema.findOneAndUpdate.mockResolvedValue(mockMessage); + // First save a message + await saveMessage(mockReq, mockMessageData); + const result = await updateMessage(mockReq, { messageId: 'msg123', text: 'Updated text' }); - expect(result).toEqual( - expect.objectContaining({ - messageId: 'msg123', - text: 'Hello, world!', - }), - ); + + expect(result.messageId).toBe('msg123'); + expect(result.text).toBe('Updated text'); + + // Verify in database + const updatedMessage = await Message.findOne({ 
messageId: 'msg123', user: 'user123' }); + expect(updatedMessage.text).toBe('Updated text'); }); it('should throw an error if message is not found', async () => { - mockSchema.findOneAndUpdate.mockResolvedValue(null); await expect( updateMessage(mockReq, { messageId: 'nonexistent', text: 'Test' }), ).rejects.toThrow('Message not found or user not authorized.'); @@ -114,19 +114,45 @@ describe('Message Operations', () => { describe('deleteMessagesSince', () => { it('should delete messages only for the authenticated user', async () => { - mockSchema.findOne().lean.mockResolvedValueOnce({ createdAt: new Date() }); - mockFindQuery.deleteMany.mockResolvedValueOnce({ deletedCount: 1 }); - const result = await deleteMessagesSince(mockReq, { - messageId: 'msg123', - conversationId: 'convo123', + const conversationId = uuidv4(); + + // Create multiple messages in the same conversation + const message1 = await saveMessage(mockReq, { + messageId: 'msg1', + conversationId, + text: 'First message', + user: 'user123', }); - expect(mockSchema.findOne).toHaveBeenCalledWith({ messageId: 'msg123', user: 'user123' }); - expect(mockSchema.find).not.toHaveBeenCalled(); - expect(result).toBeUndefined(); + + const message2 = await saveMessage(mockReq, { + messageId: 'msg2', + conversationId, + text: 'Second message', + user: 'user123', + }); + + const message3 = await saveMessage(mockReq, { + messageId: 'msg3', + conversationId, + text: 'Third message', + user: 'user123', + }); + + // Delete messages since message2 (this should only delete messages created AFTER msg2) + await deleteMessagesSince(mockReq, { + messageId: 'msg2', + conversationId, + }); + + // Verify msg1 and msg2 remain, msg3 is deleted + const remainingMessages = await Message.find({ conversationId, user: 'user123' }); + expect(remainingMessages).toHaveLength(2); + expect(remainingMessages.map((m) => m.messageId)).toContain('msg1'); + expect(remainingMessages.map((m) => m.messageId)).toContain('msg2'); + 
expect(remainingMessages.map((m) => m.messageId)).not.toContain('msg3'); }); it('should return undefined if no message is found', async () => { - mockSchema.findOne().lean.mockResolvedValueOnce(null); const result = await deleteMessagesSince(mockReq, { messageId: 'nonexistent', conversationId: 'convo123', @@ -137,29 +163,71 @@ describe('Message Operations', () => { describe('getMessages', () => { it('should retrieve messages with the correct filter', async () => { - const filter = { conversationId: 'convo123' }; - await getMessages(filter); - expect(mockSchema.find).toHaveBeenCalledWith(filter); - expect(mockFindQuery.sort).toHaveBeenCalledWith({ createdAt: 1 }); - expect(mockFindQuery.lean).toHaveBeenCalled(); + const conversationId = uuidv4(); + + // Save some messages + await saveMessage(mockReq, { + messageId: 'msg1', + conversationId, + text: 'First message', + user: 'user123', + }); + + await saveMessage(mockReq, { + messageId: 'msg2', + conversationId, + text: 'Second message', + user: 'user123', + }); + + const messages = await getMessages({ conversationId }); + expect(messages).toHaveLength(2); + expect(messages[0].text).toBe('First message'); + expect(messages[1].text).toBe('Second message'); }); }); describe('deleteMessages', () => { it('should delete messages with the correct filter', async () => { + // Save some messages for different users + await saveMessage(mockReq, mockMessageData); + await saveMessage( + { user: { id: 'user456' } }, + { + messageId: 'msg456', + conversationId: uuidv4(), + text: 'Other user message', + user: 'user456', + }, + ); + await deleteMessages({ user: 'user123' }); - expect(mockSchema.deleteMany).toHaveBeenCalledWith({ user: 'user123' }); + + // Verify only user123's messages were deleted + const user123Messages = await Message.find({ user: 'user123' }); + const user456Messages = await Message.find({ user: 'user456' }); + + expect(user123Messages).toHaveLength(0); + expect(user456Messages).toHaveLength(1); }); }); 
describe('Conversation Hijacking Prevention', () => { - it('should not allow editing a message in another user\'s conversation', async () => { + it("should not allow editing a message in another user's conversation", async () => { const attackerReq = { user: { id: 'attacker123' } }; - const victimConversationId = 'victim-convo-123'; + const victimConversationId = uuidv4(); const victimMessageId = 'victim-msg-123'; - mockSchema.findOneAndUpdate.mockResolvedValue(null); + // First, save a message as the victim (but we'll try to edit as attacker) + const victimReq = { user: { id: 'victim123' } }; + await saveMessage(victimReq, { + messageId: victimMessageId, + conversationId: victimConversationId, + text: 'Victim message', + user: 'victim123', + }); + // Attacker tries to edit the victim's message await expect( updateMessage(attackerReq, { messageId: victimMessageId, @@ -168,71 +236,82 @@ describe('Message Operations', () => { }), ).rejects.toThrow('Message not found or user not authorized.'); - expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( - { messageId: victimMessageId, user: 'attacker123' }, - expect.anything(), - expect.anything(), - ); + // Verify the original message is unchanged + const originalMessage = await Message.findOne({ + messageId: victimMessageId, + user: 'victim123', + }); + expect(originalMessage.text).toBe('Victim message'); }); - it('should not allow deleting messages from another user\'s conversation', async () => { + it("should not allow deleting messages from another user's conversation", async () => { const attackerReq = { user: { id: 'attacker123' } }; - const victimConversationId = 'victim-convo-123'; + const victimConversationId = uuidv4(); const victimMessageId = 'victim-msg-123'; - mockSchema.findOne().lean.mockResolvedValueOnce(null); // Simulating message not found for this user + // Save a message as the victim + const victimReq = { user: { id: 'victim123' } }; + await saveMessage(victimReq, { + messageId: victimMessageId, 
+ conversationId: victimConversationId, + text: 'Victim message', + user: 'victim123', + }); + + // Attacker tries to delete from victim's conversation const result = await deleteMessagesSince(attackerReq, { messageId: victimMessageId, conversationId: victimConversationId, }); expect(result).toBeUndefined(); - expect(mockSchema.findOne).toHaveBeenCalledWith({ + + // Verify the victim's message still exists + const victimMessage = await Message.findOne({ messageId: victimMessageId, - user: 'attacker123', + user: 'victim123', }); + expect(victimMessage).toBeTruthy(); + expect(victimMessage.text).toBe('Victim message'); }); - it('should not allow inserting a new message into another user\'s conversation', async () => { + it("should not allow inserting a new message into another user's conversation", async () => { const attackerReq = { user: { id: 'attacker123' } }; - const victimConversationId = uuidv4(); // Use a valid UUID + const victimConversationId = uuidv4(); - await expect( - saveMessage(attackerReq, { - conversationId: victimConversationId, - text: 'Inserted malicious message', - messageId: 'new-msg-123', - }), - ).resolves.not.toThrow(); // It should not throw an error + // Attacker tries to save a message - this should succeed but with attacker's user ID + const result = await saveMessage(attackerReq, { + conversationId: victimConversationId, + text: 'Inserted malicious message', + messageId: 'new-msg-123', + user: 'attacker123', + }); - // Check that the message was saved with the attacker's user ID - expect(mockSchema.findOneAndUpdate).toHaveBeenCalledWith( - { messageId: 'new-msg-123', user: 'attacker123' }, - expect.objectContaining({ - user: 'attacker123', - conversationId: victimConversationId, - }), - expect.anything(), - ); + expect(result).toBeTruthy(); + expect(result.user).toBe('attacker123'); + + // Verify the message was saved with the attacker's user ID, not as an anonymous message + const savedMessage = await Message.findOne({ messageId: 
'new-msg-123' }); + expect(savedMessage.user).toBe('attacker123'); + expect(savedMessage.conversationId).toBe(victimConversationId); }); it('should allow retrieving messages from any conversation', async () => { - const victimConversationId = 'victim-convo-123'; + const victimConversationId = uuidv4(); - await getMessages({ conversationId: victimConversationId }); - - expect(mockSchema.find).toHaveBeenCalledWith({ + // Save a message in the victim's conversation + const victimReq = { user: { id: 'victim123' } }; + await saveMessage(victimReq, { + messageId: 'victim-msg', conversationId: victimConversationId, + text: 'Victim message', + user: 'victim123', }); - mockSchema.find.mockReturnValueOnce({ - select: jest.fn().mockReturnThis(), - sort: jest.fn().mockReturnThis(), - lean: jest.fn().mockResolvedValue([{ text: 'Test message' }]), - }); - - const result = await getMessages({ conversationId: victimConversationId }); - expect(result).toEqual([{ text: 'Test message' }]); + // Anyone should be able to retrieve messages by conversation ID + const messages = await getMessages({ conversationId: victimConversationId }); + expect(messages).toHaveLength(1); + expect(messages[0].text).toBe('Victim message'); }); }); }); diff --git a/api/models/Preset.js b/api/models/Preset.js index 970b2958fb..4db3d59066 100644 --- a/api/models/Preset.js +++ b/api/models/Preset.js @@ -1,5 +1,5 @@ -const Preset = require('./schema/presetSchema'); -const { logger } = require('~/config'); +const { logger } = require('@librechat/data-schemas'); +const { Preset } = require('~/db/models'); const getPreset = async (user, presetId) => { try { @@ -11,7 +11,6 @@ const getPreset = async (user, presetId) => { }; module.exports = { - Preset, getPreset, getPresets: async (user, filter) => { try { diff --git a/api/models/Project.js b/api/models/Project.js index 43d7263723..8fd1e556f9 100644 --- a/api/models/Project.js +++ b/api/models/Project.js @@ -1,8 +1,5 @@ -const { model } = require('mongoose'); 
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; -const { projectSchema } = require('@librechat/data-schemas'); - -const Project = model('Project', projectSchema); +const { Project } = require('~/db/models'); /** * Retrieve a project by ID and convert the found project document to a plain object. diff --git a/api/models/Prompt.js b/api/models/Prompt.js index 43dc3ec22b..9499e19c8e 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,5 +1,5 @@ -const mongoose = require('mongoose'); const { ObjectId } = require('mongodb'); +const { logger } = require('@librechat/data-schemas'); const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider'); const { getProjectByName, @@ -7,12 +7,8 @@ const { removeGroupIdsFromProject, removeGroupFromAllProjects, } = require('./Project'); -const { promptGroupSchema, promptSchema } = require('@librechat/data-schemas'); +const { PromptGroup, Prompt } = require('~/db/models'); const { escapeRegExp } = require('~/server/utils'); -const { logger } = require('~/config'); - -const PromptGroup = mongoose.model('PromptGroup', promptGroupSchema); -const Prompt = mongoose.model('Prompt', promptSchema); /** * Create a pipeline for the aggregation to get prompt groups diff --git a/api/models/Role.js b/api/models/Role.js index 07bf5a2ccb..d7f1c0f9cf 100644 --- a/api/models/Role.js +++ b/api/models/Role.js @@ -1,4 +1,3 @@ -const mongoose = require('mongoose'); const { CacheKeys, SystemRoles, @@ -7,11 +6,9 @@ const { permissionsSchema, removeNullishValues, } = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); const getLogStores = require('~/cache/getLogStores'); -const { roleSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -const Role = mongoose.model('Role', roleSchema); +const { Role } = require('~/db/models'); /** * Retrieve a role by name and convert the found role document to a plain object. 
@@ -173,35 +170,6 @@ async function updateAccessPermissions(roleName, permissionsUpdate) { } } -/** - * Initialize default roles in the system. - * Creates the default roles (ADMIN, USER) if they don't exist in the database. - * Updates existing roles with new permission types if they're missing. - * - * @returns {Promise} - */ -const initializeRoles = async function () { - for (const roleName of [SystemRoles.ADMIN, SystemRoles.USER]) { - let role = await Role.findOne({ name: roleName }); - const defaultPerms = roleDefaults[roleName].permissions; - - if (!role) { - // Create new role if it doesn't exist. - role = new Role(roleDefaults[roleName]); - } else { - // Ensure role.permissions is defined. - role.permissions = role.permissions || {}; - // For each permission type in defaults, add it if missing. - for (const permType of Object.keys(defaultPerms)) { - if (role.permissions[permType] == null) { - role.permissions[permType] = defaultPerms[permType]; - } - } - } - await role.save(); - } -}; - /** * Migrates roles from old schema to new schema structure. * This can be called directly to fix existing roles. 
@@ -282,10 +250,8 @@ const migrateRoleSchema = async function (roleName) { }; module.exports = { - Role, getRoleByName, - initializeRoles, updateRoleByName, - updateAccessPermissions, migrateRoleSchema, + updateAccessPermissions, }; diff --git a/api/models/Role.spec.js b/api/models/Role.spec.js index a8b60801ca..c344f719dd 100644 --- a/api/models/Role.spec.js +++ b/api/models/Role.spec.js @@ -6,8 +6,10 @@ const { roleDefaults, PermissionTypes, } = require('librechat-data-provider'); -const { Role, getRoleByName, updateAccessPermissions, initializeRoles } = require('~/models/Role'); +const { getRoleByName, updateAccessPermissions } = require('~/models/Role'); const getLogStores = require('~/cache/getLogStores'); +const { initializeRoles } = require('~/models'); +const { Role } = require('~/db/models'); // Mock the cache jest.mock('~/cache/getLogStores', () => diff --git a/api/models/Session.js b/api/models/Session.js deleted file mode 100644 index 38821b77dd..0000000000 --- a/api/models/Session.js +++ /dev/null @@ -1,275 +0,0 @@ -const mongoose = require('mongoose'); -const signPayload = require('~/server/services/signPayload'); -const { hashToken } = require('~/server/utils/crypto'); -const { sessionSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -const Session = mongoose.model('Session', sessionSchema); - -const { REFRESH_TOKEN_EXPIRY } = process.env ?? {}; -const expires = eval(REFRESH_TOKEN_EXPIRY) ?? 
1000 * 60 * 60 * 24 * 7; // 7 days default - -/** - * Error class for Session-related errors - */ -class SessionError extends Error { - constructor(message, code = 'SESSION_ERROR') { - super(message); - this.name = 'SessionError'; - this.code = code; - } -} - -/** - * Creates a new session for a user - * @param {string} userId - The ID of the user - * @param {Object} options - Additional options for session creation - * @param {Date} options.expiration - Custom expiration date - * @returns {Promise<{session: Session, refreshToken: string}>} - * @throws {SessionError} - */ -const createSession = async (userId, options = {}) => { - if (!userId) { - throw new SessionError('User ID is required', 'INVALID_USER_ID'); - } - - try { - const session = new Session({ - user: userId, - expiration: options.expiration || new Date(Date.now() + expires), - }); - const refreshToken = await generateRefreshToken(session); - return { session, refreshToken }; - } catch (error) { - logger.error('[createSession] Error creating session:', error); - throw new SessionError('Failed to create session', 'CREATE_SESSION_FAILED'); - } -}; - -/** - * Finds a session by various parameters - * @param {Object} params - Search parameters - * @param {string} [params.refreshToken] - The refresh token to search by - * @param {string} [params.userId] - The user ID to search by - * @param {string} [params.sessionId] - The session ID to search by - * @param {Object} [options] - Additional options - * @param {boolean} [options.lean=true] - Whether to return plain objects instead of documents - * @returns {Promise} - * @throws {SessionError} - */ -const findSession = async (params, options = { lean: true }) => { - try { - const query = {}; - - if (!params.refreshToken && !params.userId && !params.sessionId) { - throw new SessionError('At least one search parameter is required', 'INVALID_SEARCH_PARAMS'); - } - - if (params.refreshToken) { - const tokenHash = await hashToken(params.refreshToken); - 
query.refreshTokenHash = tokenHash; - } - - if (params.userId) { - query.user = params.userId; - } - - if (params.sessionId) { - const sessionId = params.sessionId.sessionId || params.sessionId; - if (!mongoose.Types.ObjectId.isValid(sessionId)) { - throw new SessionError('Invalid session ID format', 'INVALID_SESSION_ID'); - } - query._id = sessionId; - } - - // Add expiration check to only return valid sessions - query.expiration = { $gt: new Date() }; - - const sessionQuery = Session.findOne(query); - - if (options.lean) { - return await sessionQuery.lean(); - } - - return await sessionQuery.exec(); - } catch (error) { - logger.error('[findSession] Error finding session:', error); - throw new SessionError('Failed to find session', 'FIND_SESSION_FAILED'); - } -}; - -/** - * Updates session expiration - * @param {Session|string} session - The session or session ID to update - * @param {Date} [newExpiration] - Optional new expiration date - * @returns {Promise} - * @throws {SessionError} - */ -const updateExpiration = async (session, newExpiration) => { - try { - const sessionDoc = typeof session === 'string' ? 
await Session.findById(session) : session; - - if (!sessionDoc) { - throw new SessionError('Session not found', 'SESSION_NOT_FOUND'); - } - - sessionDoc.expiration = newExpiration || new Date(Date.now() + expires); - return await sessionDoc.save(); - } catch (error) { - logger.error('[updateExpiration] Error updating session:', error); - throw new SessionError('Failed to update session expiration', 'UPDATE_EXPIRATION_FAILED'); - } -}; - -/** - * Deletes a session by refresh token or session ID - * @param {Object} params - Delete parameters - * @param {string} [params.refreshToken] - The refresh token of the session to delete - * @param {string} [params.sessionId] - The ID of the session to delete - * @returns {Promise} - * @throws {SessionError} - */ -const deleteSession = async (params) => { - try { - if (!params.refreshToken && !params.sessionId) { - throw new SessionError( - 'Either refreshToken or sessionId is required', - 'INVALID_DELETE_PARAMS', - ); - } - - const query = {}; - - if (params.refreshToken) { - query.refreshTokenHash = await hashToken(params.refreshToken); - } - - if (params.sessionId) { - query._id = params.sessionId; - } - - const result = await Session.deleteOne(query); - - if (result.deletedCount === 0) { - logger.warn('[deleteSession] No session found to delete'); - } - - return result; - } catch (error) { - logger.error('[deleteSession] Error deleting session:', error); - throw new SessionError('Failed to delete session', 'DELETE_SESSION_FAILED'); - } -}; - -/** - * Deletes all sessions for a user - * @param {string} userId - The ID of the user - * @param {Object} [options] - Additional options - * @param {boolean} [options.excludeCurrentSession] - Whether to exclude the current session - * @param {string} [options.currentSessionId] - The ID of the current session to exclude - * @returns {Promise} - * @throws {SessionError} - */ -const deleteAllUserSessions = async (userId, options = {}) => { - try { - if (!userId) { - throw new 
SessionError('User ID is required', 'INVALID_USER_ID'); - } - - // Extract userId if it's passed as an object - const userIdString = userId.userId || userId; - - if (!mongoose.Types.ObjectId.isValid(userIdString)) { - throw new SessionError('Invalid user ID format', 'INVALID_USER_ID_FORMAT'); - } - - const query = { user: userIdString }; - - if (options.excludeCurrentSession && options.currentSessionId) { - query._id = { $ne: options.currentSessionId }; - } - - const result = await Session.deleteMany(query); - - if (result.deletedCount > 0) { - logger.debug( - `[deleteAllUserSessions] Deleted ${result.deletedCount} sessions for user ${userIdString}.`, - ); - } - - return result; - } catch (error) { - logger.error('[deleteAllUserSessions] Error deleting user sessions:', error); - throw new SessionError('Failed to delete user sessions', 'DELETE_ALL_SESSIONS_FAILED'); - } -}; - -/** - * Generates a refresh token for a session - * @param {Session} session - The session to generate a token for - * @returns {Promise} - * @throws {SessionError} - */ -const generateRefreshToken = async (session) => { - if (!session || !session.user) { - throw new SessionError('Invalid session object', 'INVALID_SESSION'); - } - - try { - const expiresIn = session.expiration ? 
session.expiration.getTime() : Date.now() + expires; - - if (!session.expiration) { - session.expiration = new Date(expiresIn); - } - - const refreshToken = await signPayload({ - payload: { - id: session.user, - sessionId: session._id, - }, - secret: process.env.JWT_REFRESH_SECRET, - expirationTime: Math.floor((expiresIn - Date.now()) / 1000), - }); - - session.refreshTokenHash = await hashToken(refreshToken); - await session.save(); - - return refreshToken; - } catch (error) { - logger.error('[generateRefreshToken] Error generating refresh token:', error); - throw new SessionError('Failed to generate refresh token', 'GENERATE_TOKEN_FAILED'); - } -}; - -/** - * Counts active sessions for a user - * @param {string} userId - The ID of the user - * @returns {Promise} - * @throws {SessionError} - */ -const countActiveSessions = async (userId) => { - try { - if (!userId) { - throw new SessionError('User ID is required', 'INVALID_USER_ID'); - } - - return await Session.countDocuments({ - user: userId, - expiration: { $gt: new Date() }, - }); - } catch (error) { - logger.error('[countActiveSessions] Error counting active sessions:', error); - throw new SessionError('Failed to count active sessions', 'COUNT_SESSIONS_FAILED'); - } -}; - -module.exports = { - createSession, - findSession, - updateExpiration, - deleteSession, - deleteAllUserSessions, - generateRefreshToken, - countActiveSessions, - SessionError, -}; diff --git a/api/models/Share.js b/api/models/Share.js deleted file mode 100644 index 8611d01bc0..0000000000 --- a/api/models/Share.js +++ /dev/null @@ -1,351 +0,0 @@ -const mongoose = require('mongoose'); -const { nanoid } = require('nanoid'); -const { Constants } = require('librechat-data-provider'); -const { Conversation } = require('~/models/Conversation'); -const { shareSchema } = require('@librechat/data-schemas'); -const SharedLink = mongoose.model('SharedLink', shareSchema); -const { getMessages } = require('./Message'); -const logger = 
require('~/config/winston'); - -class ShareServiceError extends Error { - constructor(message, code) { - super(message); - this.name = 'ShareServiceError'; - this.code = code; - } -} - -const memoizedAnonymizeId = (prefix) => { - const memo = new Map(); - return (id) => { - if (!memo.has(id)) { - memo.set(id, `${prefix}_${nanoid()}`); - } - return memo.get(id); - }; -}; - -const anonymizeConvoId = memoizedAnonymizeId('convo'); -const anonymizeAssistantId = memoizedAnonymizeId('a'); -const anonymizeMessageId = (id) => - id === Constants.NO_PARENT ? id : memoizedAnonymizeId('msg')(id); - -function anonymizeConvo(conversation) { - if (!conversation) { - return null; - } - - const newConvo = { ...conversation }; - if (newConvo.assistant_id) { - newConvo.assistant_id = anonymizeAssistantId(newConvo.assistant_id); - } - return newConvo; -} - -function anonymizeMessages(messages, newConvoId) { - if (!Array.isArray(messages)) { - return []; - } - - const idMap = new Map(); - return messages.map((message) => { - const newMessageId = anonymizeMessageId(message.messageId); - idMap.set(message.messageId, newMessageId); - - const anonymizedAttachments = message.attachments?.map((attachment) => { - return { - ...attachment, - messageId: newMessageId, - conversationId: newConvoId, - }; - }); - - return { - ...message, - messageId: newMessageId, - parentMessageId: - idMap.get(message.parentMessageId) || anonymizeMessageId(message.parentMessageId), - conversationId: newConvoId, - model: message.model?.startsWith('asst_') - ? 
anonymizeAssistantId(message.model) - : message.model, - attachments: anonymizedAttachments, - }; - }); -} - -async function getSharedMessages(shareId) { - try { - const share = await SharedLink.findOne({ shareId, isPublic: true }) - .populate({ - path: 'messages', - select: '-_id -__v -user', - }) - .select('-_id -__v -user') - .lean(); - - if (!share?.conversationId || !share.isPublic) { - return null; - } - - const newConvoId = anonymizeConvoId(share.conversationId); - const result = { - ...share, - conversationId: newConvoId, - messages: anonymizeMessages(share.messages, newConvoId), - }; - - return result; - } catch (error) { - logger.error('[getShare] Error getting share link', { - error: error.message, - shareId, - }); - throw new ShareServiceError('Error getting share link', 'SHARE_FETCH_ERROR'); - } -} - -async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortDirection, search) { - try { - const query = { user, isPublic }; - - if (pageParam) { - if (sortDirection === 'desc') { - query[sortBy] = { $lt: pageParam }; - } else { - query[sortBy] = { $gt: pageParam }; - } - } - - if (search && search.trim()) { - try { - const searchResults = await Conversation.meiliSearch(search); - - if (!searchResults?.hits?.length) { - return { - links: [], - nextCursor: undefined, - hasNextPage: false, - }; - } - - const conversationIds = searchResults.hits.map((hit) => hit.conversationId); - query['conversationId'] = { $in: conversationIds }; - } catch (searchError) { - logger.error('[getSharedLinks] Meilisearch error', { - error: searchError.message, - user, - }); - return { - links: [], - nextCursor: undefined, - hasNextPage: false, - }; - } - } - - const sort = {}; - sort[sortBy] = sortDirection === 'desc' ? 
-1 : 1; - - if (Array.isArray(query.conversationId)) { - query.conversationId = { $in: query.conversationId }; - } - - const sharedLinks = await SharedLink.find(query) - .sort(sort) - .limit(pageSize + 1) - .select('-__v -user') - .lean(); - - const hasNextPage = sharedLinks.length > pageSize; - const links = sharedLinks.slice(0, pageSize); - - const nextCursor = hasNextPage ? links[links.length - 1][sortBy] : undefined; - - return { - links: links.map((link) => ({ - shareId: link.shareId, - title: link?.title || 'Untitled', - isPublic: link.isPublic, - createdAt: link.createdAt, - conversationId: link.conversationId, - })), - nextCursor, - hasNextPage, - }; - } catch (error) { - logger.error('[getSharedLinks] Error getting shares', { - error: error.message, - user, - }); - throw new ShareServiceError('Error getting shares', 'SHARES_FETCH_ERROR'); - } -} - -async function deleteAllSharedLinks(user) { - try { - const result = await SharedLink.deleteMany({ user }); - return { - message: 'All shared links deleted successfully', - deletedCount: result.deletedCount, - }; - } catch (error) { - logger.error('[deleteAllSharedLinks] Error deleting shared links', { - error: error.message, - user, - }); - throw new ShareServiceError('Error deleting shared links', 'BULK_DELETE_ERROR'); - } -} - -async function createSharedLink(user, conversationId) { - if (!user || !conversationId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const [existingShare, conversationMessages] = await Promise.all([ - SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(), - getMessages({ conversationId }), - ]); - - if (existingShare && existingShare.isPublic) { - throw new ShareServiceError('Share already exists', 'SHARE_EXISTS'); - } else if (existingShare) { - await SharedLink.deleteOne({ conversationId }); - } - - const conversation = await Conversation.findOne({ conversationId }).lean(); - const title = 
conversation?.title || 'Untitled'; - - const shareId = nanoid(); - await SharedLink.create({ - shareId, - conversationId, - messages: conversationMessages, - title, - user, - }); - - return { shareId, conversationId }; - } catch (error) { - logger.error('[createSharedLink] Error creating shared link', { - error: error.message, - user, - conversationId, - }); - throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR'); - } -} - -async function getSharedLink(user, conversationId) { - if (!user || !conversationId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const share = await SharedLink.findOne({ conversationId, user, isPublic: true }) - .select('shareId -_id') - .lean(); - - if (!share) { - return { shareId: null, success: false }; - } - - return { shareId: share.shareId, success: true }; - } catch (error) { - logger.error('[getSharedLink] Error getting shared link', { - error: error.message, - user, - conversationId, - }); - throw new ShareServiceError('Error getting shared link', 'SHARE_FETCH_ERROR'); - } -} - -async function updateSharedLink(user, shareId) { - if (!user || !shareId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean(); - - if (!share) { - throw new ShareServiceError('Share not found', 'SHARE_NOT_FOUND'); - } - - const [updatedMessages] = await Promise.all([ - getMessages({ conversationId: share.conversationId }), - ]); - - const newShareId = nanoid(); - const update = { - messages: updatedMessages, - user, - shareId: newShareId, - }; - - const updatedShare = await SharedLink.findOneAndUpdate({ shareId, user }, update, { - new: true, - upsert: false, - runValidators: true, - }).lean(); - - if (!updatedShare) { - throw new ShareServiceError('Share update failed', 'SHARE_UPDATE_ERROR'); - } - - anonymizeConvo(updatedShare); - - return { shareId: 
newShareId, conversationId: updatedShare.conversationId }; - } catch (error) { - logger.error('[updateSharedLink] Error updating shared link', { - error: error.message, - user, - shareId, - }); - throw new ShareServiceError( - error.code === 'SHARE_UPDATE_ERROR' ? error.message : 'Error updating shared link', - error.code || 'SHARE_UPDATE_ERROR', - ); - } -} - -async function deleteSharedLink(user, shareId) { - if (!user || !shareId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const result = await SharedLink.findOneAndDelete({ shareId, user }).lean(); - - if (!result) { - return null; - } - - return { - success: true, - shareId, - message: 'Share deleted successfully', - }; - } catch (error) { - logger.error('[deleteSharedLink] Error deleting shared link', { - error: error.message, - user, - shareId, - }); - throw new ShareServiceError('Error deleting shared link', 'SHARE_DELETE_ERROR'); - } -} - -module.exports = { - SharedLink, - getSharedLink, - getSharedLinks, - createSharedLink, - updateSharedLink, - deleteSharedLink, - getSharedMessages, - deleteAllSharedLinks, -}; diff --git a/api/models/Token.js b/api/models/Token.js deleted file mode 100644 index c89abb8c84..0000000000 --- a/api/models/Token.js +++ /dev/null @@ -1,199 +0,0 @@ -const mongoose = require('mongoose'); -const { encryptV2 } = require('~/server/utils/crypto'); -const { tokenSchema } = require('@librechat/data-schemas'); -const { logger } = require('~/config'); - -/** - * Token model. - * @type {mongoose.Model} - */ -const Token = mongoose.model('Token', tokenSchema); -/** - * Fixes the indexes for the Token collection from legacy TTL indexes to the new expiresAt index. 
- */ -async function fixIndexes() { - try { - if ( - process.env.NODE_ENV === 'CI' || - process.env.NODE_ENV === 'development' || - process.env.NODE_ENV === 'test' - ) { - return; - } - const indexes = await Token.collection.indexes(); - logger.debug('Existing Token Indexes:', JSON.stringify(indexes, null, 2)); - const unwantedTTLIndexes = indexes.filter( - (index) => index.key.createdAt === 1 && index.expireAfterSeconds !== undefined, - ); - if (unwantedTTLIndexes.length === 0) { - logger.debug('No unwanted Token indexes found.'); - return; - } - for (const index of unwantedTTLIndexes) { - logger.debug(`Dropping unwanted Token index: ${index.name}`); - await Token.collection.dropIndex(index.name); - logger.debug(`Dropped Token index: ${index.name}`); - } - logger.debug('Token index cleanup completed successfully.'); - } catch (error) { - logger.error('An error occurred while fixing Token indexes:', error); - } -} - -fixIndexes(); - -/** - * Creates a new Token instance. - * @param {Object} tokenData - The data for the new Token. - * @param {mongoose.Types.ObjectId} tokenData.userId - The user's ID. It is required. - * @param {String} tokenData.email - The user's email. - * @param {String} tokenData.token - The token. It is required. - * @param {Number} tokenData.expiresIn - The number of seconds until the token expires. - * @returns {Promise} The new Token instance. - * @throws Will throw an error if token creation fails. - */ -async function createToken(tokenData) { - try { - const currentTime = new Date(); - const expiresAt = new Date(currentTime.getTime() + tokenData.expiresIn * 1000); - - const newTokenData = { - ...tokenData, - createdAt: currentTime, - expiresAt, - }; - - return await Token.create(newTokenData); - } catch (error) { - logger.debug('An error occurred while creating token:', error); - throw error; - } -} - -/** - * Finds a Token document that matches the provided query. - * @param {Object} query - The query to match against. 
- * @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user. - * @param {String} query.token - The token value. - * @param {String} [query.email] - The email of the user. - * @param {String} [query.identifier] - Unique, alternative identifier for the token. - * @returns {Promise} The matched Token document, or null if not found. - * @throws Will throw an error if the find operation fails. - */ -async function findToken(query) { - try { - const conditions = []; - - if (query.userId) { - conditions.push({ userId: query.userId }); - } - if (query.token) { - conditions.push({ token: query.token }); - } - if (query.email) { - conditions.push({ email: query.email }); - } - if (query.identifier) { - conditions.push({ identifier: query.identifier }); - } - - const token = await Token.findOne({ - $and: conditions, - }).lean(); - - return token; - } catch (error) { - logger.debug('An error occurred while finding token:', error); - throw error; - } -} - -/** - * Updates a Token document that matches the provided query. - * @param {Object} query - The query to match against. - * @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user. - * @param {String} query.token - The token value. - * @param {String} [query.email] - The email of the user. - * @param {String} [query.identifier] - Unique, alternative identifier for the token. - * @param {Object} updateData - The data to update the Token with. - * @returns {Promise} The updated Token document, or null if not found. - * @throws Will throw an error if the update operation fails. - */ -async function updateToken(query, updateData) { - try { - return await Token.findOneAndUpdate(query, updateData, { new: true }); - } catch (error) { - logger.debug('An error occurred while updating token:', error); - throw error; - } -} - -/** - * Deletes all Token documents that match the provided token, user ID, or email. - * @param {Object} query - The query to match against. 
- * @param {mongoose.Types.ObjectId|String} query.userId - The ID of the user. - * @param {String} query.token - The token value. - * @param {String} [query.email] - The email of the user. - * @param {String} [query.identifier] - Unique, alternative identifier for the token. - * @returns {Promise} The result of the delete operation. - * @throws Will throw an error if the delete operation fails. - */ -async function deleteTokens(query) { - try { - return await Token.deleteMany({ - $or: [ - { userId: query.userId }, - { token: query.token }, - { email: query.email }, - { identifier: query.identifier }, - ], - }); - } catch (error) { - logger.debug('An error occurred while deleting tokens:', error); - throw error; - } -} - -/** - * Handles the OAuth token by creating or updating the token. - * @param {object} fields - * @param {string} fields.userId - The user's ID. - * @param {string} fields.token - The full token to store. - * @param {string} fields.identifier - Unique, alternative identifier for the token. - * @param {number} fields.expiresIn - The number of seconds until the token expires. - * @param {object} fields.metadata - Additional metadata to store with the token. - * @param {string} [fields.type="oauth"] - The type of token. Default is 'oauth'. 
- */ -async function handleOAuthToken({ - token, - userId, - identifier, - expiresIn, - metadata, - type = 'oauth', -}) { - const encrypedToken = await encryptV2(token); - const tokenData = { - type, - userId, - metadata, - identifier, - token: encrypedToken, - expiresIn: parseInt(expiresIn, 10) || 3600, - }; - - const existingToken = await findToken({ userId, identifier }); - if (existingToken) { - return await updateToken({ identifier }, tokenData); - } else { - return await createToken(tokenData); - } -} - -module.exports = { - findToken, - createToken, - updateToken, - deleteTokens, - handleOAuthToken, -}; diff --git a/api/models/ToolCall.js b/api/models/ToolCall.js index 7bc0f157dc..689386114b 100644 --- a/api/models/ToolCall.js +++ b/api/models/ToolCall.js @@ -1,6 +1,4 @@ -const mongoose = require('mongoose'); -const { toolCallSchema } = require('@librechat/data-schemas'); -const ToolCall = mongoose.model('ToolCall', toolCallSchema); +const { ToolCall } = require('~/db/models'); /** * Create a new tool call diff --git a/api/models/Transaction.js b/api/models/Transaction.js index e171241b61..0e0e327857 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -1,9 +1,7 @@ -const mongoose = require('mongoose'); -const { transactionSchema } = require('@librechat/data-schemas'); +const { logger } = require('@librechat/data-schemas'); const { getBalanceConfig } = require('~/server/services/Config'); const { getMultiplier, getCacheMultiplier } = require('./tx'); -const { logger } = require('~/config'); -const Balance = require('./Balance'); +const { Transaction, Balance } = require('~/db/models'); const cancelRate = 1.15; @@ -140,19 +138,19 @@ const updateBalance = async ({ user, incrementValue, setValues }) => { }; /** Method to calculate and set the tokenValue for a transaction */ -transactionSchema.methods.calculateTokenValue = function () { - if (!this.valueKey || !this.tokenType) { - this.tokenValue = this.rawAmount; +function 
calculateTokenValue(txn) { + if (!txn.valueKey || !txn.tokenType) { + txn.tokenValue = txn.rawAmount; } - const { valueKey, tokenType, model, endpointTokenConfig } = this; + const { valueKey, tokenType, model, endpointTokenConfig } = txn; const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig })); - this.rate = multiplier; - this.tokenValue = this.rawAmount * multiplier; - if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') { - this.tokenValue = Math.ceil(this.tokenValue * cancelRate); - this.rate *= cancelRate; + txn.rate = multiplier; + txn.tokenValue = txn.rawAmount * multiplier; + if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { + txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); + txn.rate *= cancelRate; } -}; +} /** * New static method to create an auto-refill transaction that does NOT trigger a balance update. @@ -163,13 +161,13 @@ transactionSchema.methods.calculateTokenValue = function () { * @param {number} txData.rawAmount - The raw amount of tokens. * @returns {Promise} - The created transaction. */ -transactionSchema.statics.createAutoRefillTransaction = async function (txData) { +async function createAutoRefillTransaction(txData) { if (txData.rawAmount != null && isNaN(txData.rawAmount)) { return; } - const transaction = new this(txData); + const transaction = new Transaction(txData); transaction.endpointTokenConfig = txData.endpointTokenConfig; - transaction.calculateTokenValue(); + calculateTokenValue(transaction); await transaction.save(); const balanceResponse = await updateBalance({ @@ -185,21 +183,20 @@ transactionSchema.statics.createAutoRefillTransaction = async function (txData) logger.debug('[Balance.check] Auto-refill performed', result); result.transaction = transaction; return result; -}; +} /** * Static method to create a transaction and update the balance * @param {txData} txData - Transaction data. 
*/ -transactionSchema.statics.create = async function (txData) { - const Transaction = this; +async function createTransaction(txData) { if (txData.rawAmount != null && isNaN(txData.rawAmount)) { return; } const transaction = new Transaction(txData); transaction.endpointTokenConfig = txData.endpointTokenConfig; - transaction.calculateTokenValue(); + calculateTokenValue(transaction); await transaction.save(); @@ -209,7 +206,6 @@ transactionSchema.statics.create = async function (txData) { } let incrementValue = transaction.tokenValue; - const balanceResponse = await updateBalance({ user: transaction.user, incrementValue, @@ -221,21 +217,19 @@ transactionSchema.statics.create = async function (txData) { balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; -}; +} /** * Static method to create a structured transaction and update the balance * @param {txData} txData - Transaction data. */ -transactionSchema.statics.createStructured = async function (txData) { - const Transaction = this; - +async function createStructuredTransaction(txData) { const transaction = new Transaction({ ...txData, endpointTokenConfig: txData.endpointTokenConfig, }); - transaction.calculateStructuredTokenValue(); + calculateStructuredTokenValue(transaction); await transaction.save(); @@ -257,71 +251,69 @@ transactionSchema.statics.createStructured = async function (txData) { balance: balanceResponse.tokenCredits, [transaction.tokenType]: incrementValue, }; -}; +} /** Method to calculate token value for structured tokens */ -transactionSchema.methods.calculateStructuredTokenValue = function () { - if (!this.tokenType) { - this.tokenValue = this.rawAmount; +function calculateStructuredTokenValue(txn) { + if (!txn.tokenType) { + txn.tokenValue = txn.rawAmount; return; } - const { model, endpointTokenConfig } = this; + const { model, endpointTokenConfig } = txn; - if (this.tokenType === 'prompt') { + if (txn.tokenType === 'prompt') { const inputMultiplier = 
getMultiplier({ tokenType: 'prompt', model, endpointTokenConfig }); const writeMultiplier = getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? inputMultiplier; const readMultiplier = getCacheMultiplier({ cacheType: 'read', model, endpointTokenConfig }) ?? inputMultiplier; - this.rateDetail = { + txn.rateDetail = { input: inputMultiplier, write: writeMultiplier, read: readMultiplier, }; const totalPromptTokens = - Math.abs(this.inputTokens || 0) + - Math.abs(this.writeTokens || 0) + - Math.abs(this.readTokens || 0); + Math.abs(txn.inputTokens || 0) + + Math.abs(txn.writeTokens || 0) + + Math.abs(txn.readTokens || 0); if (totalPromptTokens > 0) { - this.rate = - (Math.abs(inputMultiplier * (this.inputTokens || 0)) + - Math.abs(writeMultiplier * (this.writeTokens || 0)) + - Math.abs(readMultiplier * (this.readTokens || 0))) / + txn.rate = + (Math.abs(inputMultiplier * (txn.inputTokens || 0)) + + Math.abs(writeMultiplier * (txn.writeTokens || 0)) + + Math.abs(readMultiplier * (txn.readTokens || 0))) / totalPromptTokens; } else { - this.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens + txn.rate = Math.abs(inputMultiplier); // Default to input rate if no tokens } - this.tokenValue = -( - Math.abs(this.inputTokens || 0) * inputMultiplier + - Math.abs(this.writeTokens || 0) * writeMultiplier + - Math.abs(this.readTokens || 0) * readMultiplier + txn.tokenValue = -( + Math.abs(txn.inputTokens || 0) * inputMultiplier + + Math.abs(txn.writeTokens || 0) * writeMultiplier + + Math.abs(txn.readTokens || 0) * readMultiplier ); - this.rawAmount = -totalPromptTokens; - } else if (this.tokenType === 'completion') { - const multiplier = getMultiplier({ tokenType: this.tokenType, model, endpointTokenConfig }); - this.rate = Math.abs(multiplier); - this.tokenValue = -Math.abs(this.rawAmount) * multiplier; - this.rawAmount = -Math.abs(this.rawAmount); + txn.rawAmount = -totalPromptTokens; + } else if (txn.tokenType === 'completion') { + 
const multiplier = getMultiplier({ tokenType: txn.tokenType, model, endpointTokenConfig }); + txn.rate = Math.abs(multiplier); + txn.tokenValue = -Math.abs(txn.rawAmount) * multiplier; + txn.rawAmount = -Math.abs(txn.rawAmount); } - if (this.context && this.tokenType === 'completion' && this.context === 'incomplete') { - this.tokenValue = Math.ceil(this.tokenValue * cancelRate); - this.rate *= cancelRate; - if (this.rateDetail) { - this.rateDetail = Object.fromEntries( - Object.entries(this.rateDetail).map(([k, v]) => [k, v * cancelRate]), + if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { + txn.tokenValue = Math.ceil(txn.tokenValue * cancelRate); + txn.rate *= cancelRate; + if (txn.rateDetail) { + txn.rateDetail = Object.fromEntries( + Object.entries(txn.rateDetail).map(([k, v]) => [k, v * cancelRate]), ); } } -}; - -const Transaction = mongoose.model('Transaction', transactionSchema); +} /** * Queries and retrieves transactions based on a given filter. @@ -340,4 +332,9 @@ async function getTransactions(filter) { } } -module.exports = { Transaction, getTransactions }; +module.exports = { + getTransactions, + createTransaction, + createAutoRefillTransaction, + createStructuredTransaction, +}; diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index 43f3c004b2..3a1303edec 100644 --- a/api/models/Transaction.spec.js +++ b/api/models/Transaction.spec.js @@ -3,14 +3,13 @@ const { MongoMemoryServer } = require('mongodb-memory-server'); const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { getBalanceConfig } = require('~/server/services/Config'); const { getMultiplier, getCacheMultiplier } = require('./tx'); -const { Transaction } = require('./Transaction'); -const Balance = require('./Balance'); +const { createTransaction } = require('./Transaction'); +const { Balance } = require('~/db/models'); // Mock the custom config module so we can control the balance flag. 
jest.mock('~/server/services/Config'); let mongoServer; - beforeAll(async () => { mongoServer = await MongoMemoryServer.create(); const mongoUri = mongoServer.getUri(); @@ -368,7 +367,7 @@ describe('NaN Handling Tests', () => { }; // Act - const result = await Transaction.create(txData); + const result = await createTransaction(txData); // Assert: No transaction should be created and balance remains unchanged. expect(result).toBeUndefined(); diff --git a/api/models/User.js b/api/models/User.js deleted file mode 100644 index f4e8b0ec5b..0000000000 --- a/api/models/User.js +++ /dev/null @@ -1,6 +0,0 @@ -const mongoose = require('mongoose'); -const { userSchema } = require('@librechat/data-schemas'); - -const User = mongoose.model('User', userSchema); - -module.exports = User; diff --git a/api/models/balanceMethods.js b/api/models/balanceMethods.js index 4b788160aa..7e6321ab2c 100644 --- a/api/models/balanceMethods.js +++ b/api/models/balanceMethods.js @@ -1,9 +1,9 @@ +const { logger } = require('@librechat/data-schemas'); const { ViolationTypes } = require('librechat-data-provider'); -const { Transaction } = require('./Transaction'); +const { createAutoRefillTransaction } = require('./Transaction'); const { logViolation } = require('~/cache'); const { getMultiplier } = require('./tx'); -const { logger } = require('~/config'); -const Balance = require('./Balance'); +const { Balance } = require('~/db/models'); function isInvalidDate(date) { return isNaN(date); @@ -60,7 +60,7 @@ const checkBalanceRecord = async function ({ ) { try { /** @type {{ rate: number, user: string, balance: number, transaction: import('@librechat/data-schemas').ITransaction}} */ - const result = await Transaction.createAutoRefillTransaction({ + const result = await createAutoRefillTransaction({ user: user, tokenType: 'credits', context: 'autoRefill', diff --git a/api/models/convoStructure.spec.js b/api/models/convoStructure.spec.js index e672e0fa1c..33bf0c9b2b 100644 --- 
a/api/models/convoStructure.spec.js +++ b/api/models/convoStructure.spec.js @@ -1,6 +1,7 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { Message, getMessages, bulkSaveMessages } = require('./Message'); +const { getMessages, bulkSaveMessages } = require('./Message'); +const { Message } = require('~/db/models'); // Original version of buildTree function function buildTree({ messages, fileMap }) { @@ -42,7 +43,6 @@ function buildTree({ messages, fileMap }) { } let mongod; - beforeAll(async () => { mongod = await MongoMemoryServer.create(); const uri = mongod.getUri(); diff --git a/api/models/index.js b/api/models/index.js index 73cfa1c96c..7ecb9adcbb 100644 --- a/api/models/index.js +++ b/api/models/index.js @@ -1,13 +1,7 @@ -const { - comparePassword, - deleteUserById, - generateToken, - getUserById, - updateUser, - createUser, - countUsers, - findUser, -} = require('./userMethods'); +const mongoose = require('mongoose'); +const { createMethods } = require('@librechat/data-schemas'); +const methods = createMethods(mongoose); +const { comparePassword } = require('./userMethods'); const { findFileById, createFile, @@ -26,32 +20,12 @@ const { deleteMessagesSince, deleteMessages, } = require('./Message'); -const { - createSession, - findSession, - updateExpiration, - deleteSession, - deleteAllUserSessions, - generateRefreshToken, - countActiveSessions, -} = require('./Session'); const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation'); const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset'); -const { createToken, findToken, updateToken, deleteTokens } = require('./Token'); -const Balance = require('./Balance'); -const User = require('./User'); -const Key = require('./Key'); module.exports = { + ...methods, comparePassword, - deleteUserById, - generateToken, - getUserById, - updateUser, - createUser, - countUsers, - findUser, - findFileById, createFile, 
updateFile, @@ -77,21 +51,4 @@ module.exports = { getPresets, savePreset, deletePresets, - - createToken, - findToken, - updateToken, - deleteTokens, - - createSession, - findSession, - updateExpiration, - deleteSession, - deleteAllUserSessions, - generateRefreshToken, - countActiveSessions, - - User, - Key, - Balance, }; diff --git a/api/models/inviteUser.js b/api/models/inviteUser.js index 6cd699fd66..eeb42841bf 100644 --- a/api/models/inviteUser.js +++ b/api/models/inviteUser.js @@ -1,7 +1,7 @@ const mongoose = require('mongoose'); -const { getRandomValues, hashToken } = require('~/server/utils/crypto'); -const { createToken, findToken } = require('./Token'); -const logger = require('~/config/winston'); +const { getRandomValues } = require('@librechat/api'); +const { logger, hashToken } = require('@librechat/data-schemas'); +const { createToken, findToken } = require('~/models'); /** * @module inviteUser diff --git a/api/models/plugins/mongoMeili.js b/api/models/plugins/mongoMeili.js deleted file mode 100644 index 75e3738e5d..0000000000 --- a/api/models/plugins/mongoMeili.js +++ /dev/null @@ -1,475 +0,0 @@ -const _ = require('lodash'); -const mongoose = require('mongoose'); -const { MeiliSearch } = require('meilisearch'); -const { parseTextParts, ContentTypes } = require('librechat-data-provider'); -const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); -const logger = require('~/config/meiliLogger'); - -// Environment flags -/** - * Flag to indicate if search is enabled based on environment variables. - * @type {boolean} - */ -const searchEnabled = process.env.SEARCH && process.env.SEARCH.toLowerCase() === 'true'; - -/** - * Flag to indicate if MeiliSearch is enabled based on required environment variables. - * @type {boolean} - */ -const meiliEnabled = process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY && searchEnabled; - -/** - * Validates the required options for configuring the mongoMeili plugin. 
- * - * @param {Object} options - The configuration options. - * @param {string} options.host - The MeiliSearch host. - * @param {string} options.apiKey - The MeiliSearch API key. - * @param {string} options.indexName - The name of the index. - * @throws {Error} Throws an error if any required option is missing. - */ -const validateOptions = function (options) { - const requiredKeys = ['host', 'apiKey', 'indexName']; - requiredKeys.forEach((key) => { - if (!options[key]) { - throw new Error(`Missing mongoMeili Option: ${key}`); - } - }); -}; - -/** - * Factory function to create a MeiliMongooseModel class which extends a Mongoose model. - * This class contains static and instance methods to synchronize and manage the MeiliSearch index - * corresponding to the MongoDB collection. - * - * @param {Object} config - Configuration object. - * @param {Object} config.index - The MeiliSearch index object. - * @param {Array} config.attributesToIndex - List of attributes to index. - * @returns {Function} A class definition that will be loaded into the Mongoose schema. - */ -const createMeiliMongooseModel = function ({ index, attributesToIndex }) { - // The primary key is assumed to be the first attribute in the attributesToIndex array. - const primaryKey = attributesToIndex[0]; - - class MeiliMongooseModel { - /** - * Synchronizes the data between the MongoDB collection and the MeiliSearch index. - * - * The synchronization process involves: - * 1. Fetching all documents from the MongoDB collection and MeiliSearch index. - * 2. Comparing documents from both sources. - * 3. Deleting documents from MeiliSearch that no longer exist in MongoDB. - * 4. Adding documents to MeiliSearch that exist in MongoDB but not in the index. - * 5. Updating documents in MeiliSearch if key fields (such as `text` or `title`) differ. - * 6. Updating the `_meiliIndex` field in MongoDB to indicate the indexing status. 
- * - * Note: The function processes documents in batches because MeiliSearch's - * `index.getDocuments` requires an exact limit and `index.addDocuments` does not handle - * partial failures in a batch. - * - * @returns {Promise} Resolves when the synchronization is complete. - */ - static async syncWithMeili() { - try { - let moreDocuments = true; - // Retrieve all MongoDB documents from the collection as plain JavaScript objects. - const mongoDocuments = await this.find().lean(); - - // Helper function to format a document by selecting only the attributes to index - // and omitting keys starting with '$'. - const format = (doc) => - _.omitBy(_.pick(doc, attributesToIndex), (v, k) => k.startsWith('$')); - - // Build a map of MongoDB documents for quick lookup based on the primary key. - const mongoMap = new Map(mongoDocuments.map((doc) => [doc[primaryKey], format(doc)])); - const indexMap = new Map(); - let offset = 0; - const batchSize = 1000; - - // Fetch documents from the MeiliSearch index in batches. - while (moreDocuments) { - const batch = await index.getDocuments({ limit: batchSize, offset }); - if (batch.results.length === 0) { - moreDocuments = false; - } - for (const doc of batch.results) { - indexMap.set(doc[primaryKey], format(doc)); - } - offset += batchSize; - } - - logger.debug('[syncWithMeili]', { indexMap: indexMap.size, mongoMap: mongoMap.size }); - - const updateOps = []; - - // Process documents present in the MeiliSearch index. - for (const [id, doc] of indexMap) { - const update = {}; - update[primaryKey] = id; - if (mongoMap.has(id)) { - // If document exists in MongoDB, check for discrepancies in key fields. - if ( - (doc.text && doc.text !== mongoMap.get(id).text) || - (doc.title && doc.title !== mongoMap.get(id).title) - ) { - logger.debug( - `[syncWithMeili] ${id} had document discrepancy in ${ - doc.text ? 
'text' : 'title' - } field`, - ); - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, - }); - await index.addDocuments([doc]); - } - } else { - // If the document does not exist in MongoDB, delete it from MeiliSearch. - await index.deleteDocument(id); - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: false } } }, - }); - } - } - - // Process documents present in MongoDB. - for (const [id, doc] of mongoMap) { - const update = {}; - update[primaryKey] = id; - // If the document is missing in the Meili index, add it. - if (!indexMap.has(id)) { - await index.addDocuments([doc]); - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, - }); - } else if (doc._meiliIndex === false) { - // If the document exists but is marked as not indexed, update the flag. - updateOps.push({ - updateOne: { filter: update, update: { $set: { _meiliIndex: true } } }, - }); - } - } - - // Execute bulk update operations in MongoDB to update the _meiliIndex flags. - if (updateOps.length > 0) { - await this.collection.bulkWrite(updateOps); - logger.debug( - `[syncWithMeili] Finished indexing ${ - primaryKey === 'messageId' ? 'messages' : 'conversations' - }`, - ); - } - } catch (error) { - logger.error('[syncWithMeili] Error adding document to Meili', error); - } - } - - /** - * Updates settings for the MeiliSearch index. - * - * @param {Object} settings - The settings to update on the MeiliSearch index. - * @returns {Promise} Promise resolving to the update result. - */ - static async setMeiliIndexSettings(settings) { - return await index.updateSettings(settings); - } - - /** - * Searches the MeiliSearch index and optionally populates the results with data from MongoDB. - * - * @param {string} q - The search query. - * @param {Object} params - Additional search parameters for MeiliSearch. - * @param {boolean} populate - Whether to populate search hits with full MongoDB documents. 
- * @returns {Promise} The search results with populated hits if requested. - */ - static async meiliSearch(q, params, populate) { - const data = await index.search(q, params); - - if (populate) { - // Build a query using the primary key values from the search hits. - const query = {}; - query[primaryKey] = _.map(data.hits, (hit) => cleanUpPrimaryKeyValue(hit[primaryKey])); - - // Build a projection object, including only keys that do not start with '$'. - const projection = Object.keys(this.schema.obj).reduce( - (results, key) => { - if (!key.startsWith('$')) { - results[key] = 1; - } - return results; - }, - { _id: 1, __v: 1 }, - ); - - // Retrieve the full documents from MongoDB. - const hitsFromMongoose = await this.find(query, projection).lean(); - - // Merge the MongoDB documents with the search hits. - const populatedHits = data.hits.map(function (hit) { - const query = {}; - query[primaryKey] = hit[primaryKey]; - const originalHit = _.find(hitsFromMongoose, query); - - return { - ...(originalHit ?? {}), - ...hit, - }; - }); - data.hits = populatedHits; - } - - return data; - } - - /** - * Preprocesses the current document for indexing. - * - * This method: - * - Picks only the defined attributes to index. - * - Omits any keys starting with '$'. - * - Replaces pipe characters ('|') in `conversationId` with '--'. - * - Extracts and concatenates text from an array of content items. - * - * @returns {Object} The preprocessed object ready for indexing. - */ - preprocessObjectForIndex() { - const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => - k.startsWith('$'), - ); - if (object.conversationId && object.conversationId.includes('|')) { - object.conversationId = object.conversationId.replace(/\|/g, '--'); - } - - if (object.content && Array.isArray(object.content)) { - object.text = parseTextParts(object.content); - delete object.content; - } - - return object; - } - - /** - * Adds the current document to the MeiliSearch index. 
- * - * The method preprocesses the document, adds it to MeiliSearch, and then updates - * the MongoDB document's `_meiliIndex` flag to true. - * - * @returns {Promise} - */ - async addObjectToMeili() { - const object = this.preprocessObjectForIndex(); - try { - await index.addDocuments([object]); - } catch (error) { - // Error handling can be enhanced as needed. - logger.error('[addObjectToMeili] Error adding document to Meili', error); - } - - await this.collection.updateMany({ _id: this._id }, { $set: { _meiliIndex: true } }); - } - - /** - * Updates the current document in the MeiliSearch index. - * - * @returns {Promise} - */ - async updateObjectToMeili() { - const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) => - k.startsWith('$'), - ); - await index.updateDocuments([object]); - } - - /** - * Deletes the current document from the MeiliSearch index. - * - * @returns {Promise} - */ - async deleteObjectFromMeili() { - await index.deleteDocument(this._id); - } - - /** - * Post-save hook to synchronize the document with MeiliSearch. - * - * If the document is already indexed (i.e. `_meiliIndex` is true), it updates it; - * otherwise, it adds the document to the index. - */ - postSaveHook() { - if (this._meiliIndex) { - this.updateObjectToMeili(); - } else { - this.addObjectToMeili(); - } - } - - /** - * Post-update hook to update the document in MeiliSearch. - * - * This hook is triggered after a document update, ensuring that changes are - * propagated to the MeiliSearch index if the document is indexed. - */ - postUpdateHook() { - if (this._meiliIndex) { - this.updateObjectToMeili(); - } - } - - /** - * Post-remove hook to delete the document from MeiliSearch. - * - * This hook is triggered after a document is removed, ensuring that the document - * is also removed from the MeiliSearch index if it was previously indexed. 
- */ - postRemoveHook() { - if (this._meiliIndex) { - this.deleteObjectFromMeili(); - } - } - } - - return MeiliMongooseModel; -}; - -/** - * Mongoose plugin to synchronize MongoDB collections with a MeiliSearch index. - * - * This plugin: - * - Validates the provided options. - * - Adds a `_meiliIndex` field to the schema to track indexing status. - * - Sets up a MeiliSearch client and creates an index if it doesn't already exist. - * - Loads class methods for syncing, searching, and managing documents in MeiliSearch. - * - Registers Mongoose hooks (post-save, post-update, post-remove, etc.) to maintain index consistency. - * - * @param {mongoose.Schema} schema - The Mongoose schema to which the plugin is applied. - * @param {Object} options - Configuration options. - * @param {string} options.host - The MeiliSearch host. - * @param {string} options.apiKey - The MeiliSearch API key. - * @param {string} options.indexName - The name of the MeiliSearch index. - * @param {string} options.primaryKey - The primary key field for indexing. - */ -module.exports = function mongoMeili(schema, options) { - validateOptions(options); - - // Add _meiliIndex field to the schema to track if a document has been indexed in MeiliSearch. - schema.add({ - _meiliIndex: { - type: Boolean, - required: false, - select: false, - default: false, - }, - }); - - const { host, apiKey, indexName, primaryKey } = options; - - // Setup the MeiliSearch client. - const client = new MeiliSearch({ host, apiKey }); - - // Create the index asynchronously if it doesn't exist. - client.createIndex(indexName, { primaryKey }); - - // Setup the MeiliSearch index for this schema. - const index = client.index(indexName); - - // Collect attributes from the schema that should be indexed. - const attributesToIndex = [ - ..._.reduce( - schema.obj, - function (results, value, key) { - return value.meiliIndex ? [...results, key] : results; - }, - [], - ), - ]; - - // Load the class methods into the schema. 
- schema.loadClass(createMeiliMongooseModel({ index, indexName, client, attributesToIndex })); - - // Register Mongoose hooks to synchronize with MeiliSearch. - - // Post-save: synchronize after a document is saved. - schema.post('save', function (doc) { - doc.postSaveHook(); - }); - - // Post-update: synchronize after a document is updated. - schema.post('update', function (doc) { - doc.postUpdateHook(); - }); - - // Post-remove: synchronize after a document is removed. - schema.post('remove', function (doc) { - doc.postRemoveHook(); - }); - - // Pre-deleteMany hook: remove corresponding documents from MeiliSearch when multiple documents are deleted. - schema.pre('deleteMany', async function (next) { - if (!meiliEnabled) { - return next(); - } - - try { - // Check if the schema has a "messages" field to determine if it's a conversation schema. - if (Object.prototype.hasOwnProperty.call(schema.obj, 'messages')) { - const convoIndex = client.index('convos'); - const deletedConvos = await mongoose.model('Conversation').find(this._conditions).lean(); - const promises = deletedConvos.map((convo) => - convoIndex.deleteDocument(convo.conversationId), - ); - await Promise.all(promises); - } - - // Check if the schema has a "messageId" field to determine if it's a message schema. - if (Object.prototype.hasOwnProperty.call(schema.obj, 'messageId')) { - const messageIndex = client.index('messages'); - const deletedMessages = await mongoose.model('Message').find(this._conditions).lean(); - const promises = deletedMessages.map((message) => - messageIndex.deleteDocument(message.messageId), - ); - await Promise.all(promises); - } - return next(); - } catch (error) { - if (meiliEnabled) { - logger.error( - '[MeiliMongooseModel.deleteMany] There was an issue deleting conversation indexes upon deletion. 
Next startup may be slow due to syncing.', - error, - ); - } - return next(); - } - }); - - // Post-findOneAndUpdate hook: update MeiliSearch index after a document is updated via findOneAndUpdate. - schema.post('findOneAndUpdate', async function (doc) { - if (!meiliEnabled) { - return; - } - - // If the document is unfinished, do not update the index. - if (doc.unfinished) { - return; - } - - let meiliDoc; - // For conversation documents, try to fetch the document from the "convos" index. - if (doc.messages) { - try { - meiliDoc = await client.index('convos').getDocument(doc.conversationId); - } catch (error) { - logger.debug( - '[MeiliMongooseModel.findOneAndUpdate] Convo not found in MeiliSearch and will index ' + - doc.conversationId, - error, - ); - } - } - - // If the MeiliSearch document exists and the title is unchanged, do nothing. - if (meiliDoc && meiliDoc.title === doc.title) { - return; - } - - // Otherwise, trigger a post-save hook to synchronize the document. - doc.postSaveHook(); - }); -}; diff --git a/api/models/schema/convoSchema.js b/api/models/schema/convoSchema.js deleted file mode 100644 index 89cb9c80b5..0000000000 --- a/api/models/schema/convoSchema.js +++ /dev/null @@ -1,18 +0,0 @@ -const mongoose = require('mongoose'); -const mongoMeili = require('../plugins/mongoMeili'); - -const { convoSchema } = require('@librechat/data-schemas'); - -if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { - convoSchema.plugin(mongoMeili, { - host: process.env.MEILI_HOST, - apiKey: process.env.MEILI_MASTER_KEY, - /** Note: Will get created automatically if it doesn't exist already */ - indexName: 'convos', - primaryKey: 'conversationId', - }); -} - -const Conversation = mongoose.models.Conversation || mongoose.model('Conversation', convoSchema); - -module.exports = Conversation; diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js deleted file mode 100644 index cf97b84eea..0000000000 --- 
a/api/models/schema/messageSchema.js +++ /dev/null @@ -1,16 +0,0 @@ -const mongoose = require('mongoose'); -const mongoMeili = require('~/models/plugins/mongoMeili'); -const { messageSchema } = require('@librechat/data-schemas'); - -if (process.env.MEILI_HOST && process.env.MEILI_MASTER_KEY) { - messageSchema.plugin(mongoMeili, { - host: process.env.MEILI_HOST, - apiKey: process.env.MEILI_MASTER_KEY, - indexName: 'messages', - primaryKey: 'messageId', - }); -} - -const Message = mongoose.models.Message || mongoose.model('Message', messageSchema); - -module.exports = Message; diff --git a/api/models/schema/pluginAuthSchema.js b/api/models/schema/pluginAuthSchema.js deleted file mode 100644 index 2066eda4c4..0000000000 --- a/api/models/schema/pluginAuthSchema.js +++ /dev/null @@ -1,6 +0,0 @@ -const mongoose = require('mongoose'); -const { pluginAuthSchema } = require('@librechat/data-schemas'); - -const PluginAuth = mongoose.models.Plugin || mongoose.model('PluginAuth', pluginAuthSchema); - -module.exports = PluginAuth; diff --git a/api/models/schema/presetSchema.js b/api/models/schema/presetSchema.js deleted file mode 100644 index 6d03803ace..0000000000 --- a/api/models/schema/presetSchema.js +++ /dev/null @@ -1,6 +0,0 @@ -const mongoose = require('mongoose'); -const { presetSchema } = require('@librechat/data-schemas'); - -const Preset = mongoose.models.Preset || mongoose.model('Preset', presetSchema); - -module.exports = Preset; diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index 36b71ca9fc..834f740926 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -1,6 +1,5 @@ -const { Transaction } = require('./Transaction'); const { logger } = require('~/config'); - +const { createTransaction, createStructuredTransaction } = require('./Transaction'); /** * Creates up to two transactions to record the spending of tokens. 
* @@ -33,7 +32,7 @@ const spendTokens = async (txData, tokenUsage) => { let prompt, completion; try { if (promptTokens !== undefined) { - prompt = await Transaction.create({ + prompt = await createTransaction({ ...txData, tokenType: 'prompt', rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0), @@ -41,7 +40,7 @@ const spendTokens = async (txData, tokenUsage) => { } if (completionTokens !== undefined) { - completion = await Transaction.create({ + completion = await createTransaction({ ...txData, tokenType: 'completion', rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0), @@ -101,7 +100,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => { try { if (promptTokens) { const { input = 0, write = 0, read = 0 } = promptTokens; - prompt = await Transaction.createStructured({ + prompt = await createStructuredTransaction({ ...txData, tokenType: 'prompt', inputTokens: -input, @@ -111,7 +110,7 @@ const spendStructuredTokens = async (txData, tokenUsage) => { } if (completionTokens) { - completion = await Transaction.create({ + completion = await createTransaction({ ...txData, tokenType: 'completion', rawAmount: -completionTokens, diff --git a/api/models/spendTokens.spec.js b/api/models/spendTokens.spec.js index eacf420330..7ee067e589 100644 --- a/api/models/spendTokens.spec.js +++ b/api/models/spendTokens.spec.js @@ -1,8 +1,9 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { Transaction } = require('./Transaction'); -const Balance = require('./Balance'); const { spendTokens, spendStructuredTokens } = require('./spendTokens'); +const { createTransaction, createAutoRefillTransaction } = require('./Transaction'); + +require('~/db/models'); // Mock the logger to prevent console output during tests jest.mock('~/config', () => ({ @@ -19,11 +20,15 @@ jest.mock('~/server/services/Config'); describe('spendTokens', () => { let mongoServer; let userId; + let Transaction; + let 
Balance; beforeAll(async () => { mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); + await mongoose.connect(mongoServer.getUri()); + + Transaction = mongoose.model('Transaction'); + Balance = mongoose.model('Balance'); }); afterAll(async () => { @@ -197,7 +202,7 @@ describe('spendTokens', () => { // Check that the transaction records show the adjusted values const transactionResults = await Promise.all( transactions.map((t) => - Transaction.create({ + createTransaction({ ...txData, tokenType: t.tokenType, rawAmount: t.rawAmount, @@ -280,7 +285,7 @@ describe('spendTokens', () => { // Check the return values from Transaction.create directly // This is to verify that the incrementValue is not becoming positive - const directResult = await Transaction.create({ + const directResult = await createTransaction({ user: userId, conversationId: 'test-convo-3', model: 'gpt-4', @@ -607,7 +612,7 @@ describe('spendTokens', () => { const promises = []; for (let i = 0; i < numberOfRefills; i++) { promises.push( - Transaction.createAutoRefillTransaction({ + createAutoRefillTransaction({ user: userId, tokenType: 'credits', context: 'concurrent-refill-test', diff --git a/api/models/tx.js b/api/models/tx.js index ddd098b80f..f3ba38652d 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -78,7 +78,7 @@ const tokenValues = Object.assign( 'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 }, 'o4-mini': { prompt: 1.1, completion: 4.4 }, 'o3-mini': { prompt: 1.1, completion: 4.4 }, - o3: { prompt: 10, completion: 40 }, + o3: { prompt: 2, completion: 8 }, 'o1-mini': { prompt: 1.1, completion: 4.4 }, 'o1-preview': { prompt: 15, completion: 60 }, o1: { prompt: 15, completion: 60 }, diff --git a/api/models/userMethods.js b/api/models/userMethods.js index fbcd33aba8..a36409ebcf 100644 --- a/api/models/userMethods.js +++ b/api/models/userMethods.js @@ -1,159 +1,4 @@ const bcrypt = require('bcryptjs'); -const { 
getBalanceConfig } = require('~/server/services/Config'); -const signPayload = require('~/server/services/signPayload'); -const Balance = require('./Balance'); -const User = require('./User'); - -/** - * Retrieve a user by ID and convert the found user document to a plain object. - * - * @param {string} userId - The ID of the user to find and return as a plain object. - * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. - * @returns {Promise} A plain object representing the user document, or `null` if no user is found. - */ -const getUserById = async function (userId, fieldsToSelect = null) { - const query = User.findById(userId); - if (fieldsToSelect) { - query.select(fieldsToSelect); - } - return await query.lean(); -}; - -/** - * Search for a single user based on partial data and return matching user document as plain object. - * @param {Partial} searchCriteria - The partial data to use for searching the user. - * @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document. - * @returns {Promise} A plain object representing the user document, or `null` if no user is found. - */ -const findUser = async function (searchCriteria, fieldsToSelect = null) { - const query = User.findOne(searchCriteria); - if (fieldsToSelect) { - query.select(fieldsToSelect); - } - return await query.lean(); -}; - -/** - * Update a user with new data without overwriting existing properties. - * - * @param {string} userId - The ID of the user to update. - * @param {Object} updateData - An object containing the properties to update. - * @returns {Promise} The updated user document as a plain object, or `null` if no user is found. 
- */ -const updateUser = async function (userId, updateData) { - const updateOperation = { - $set: updateData, - $unset: { expiresAt: '' }, // Remove the expiresAt field to prevent TTL - }; - return await User.findByIdAndUpdate(userId, updateOperation, { - new: true, - runValidators: true, - }).lean(); -}; - -/** - * Creates a new user, optionally with a TTL of 1 week. - * @param {MongoUser} data - The user data to be created, must contain user_id. - * @param {boolean} [disableTTL=true] - Whether to disable the TTL. Defaults to `true`. - * @param {boolean} [returnUser=false] - Whether to return the created user object. - * @returns {Promise} A promise that resolves to the created user document ID or user object. - * @throws {Error} If a user with the same user_id already exists. - */ -const createUser = async (data, disableTTL = true, returnUser = false) => { - const balance = await getBalanceConfig(); - const userData = { - ...data, - expiresAt: disableTTL ? null : new Date(Date.now() + 604800 * 1000), // 1 week in milliseconds - }; - - if (disableTTL) { - delete userData.expiresAt; - } - - const user = await User.create(userData); - - // If balance is enabled, create or update a balance record for the user using global.interfaceConfig.balance - if (balance?.enabled && balance?.startBalance) { - const update = { - $inc: { tokenCredits: balance.startBalance }, - }; - - if ( - balance.autoRefillEnabled && - balance.refillIntervalValue != null && - balance.refillIntervalUnit != null && - balance.refillAmount != null - ) { - update.$set = { - autoRefillEnabled: true, - refillIntervalValue: balance.refillIntervalValue, - refillIntervalUnit: balance.refillIntervalUnit, - refillAmount: balance.refillAmount, - }; - } - - await Balance.findOneAndUpdate({ user: user._id }, update, { upsert: true, new: true }).lean(); - } - - if (returnUser) { - return user.toObject(); - } - return user._id; -}; - -/** - * Count the number of user documents in the collection based on the 
provided filter. - * - * @param {Object} [filter={}] - The filter to apply when counting the documents. - * @returns {Promise} The count of documents that match the filter. - */ -const countUsers = async function (filter = {}) { - return await User.countDocuments(filter); -}; - -/** - * Delete a user by their unique ID. - * - * @param {string} userId - The ID of the user to delete. - * @returns {Promise<{ deletedCount: number }>} An object indicating the number of deleted documents. - */ -const deleteUserById = async function (userId) { - try { - const result = await User.deleteOne({ _id: userId }); - if (result.deletedCount === 0) { - return { deletedCount: 0, message: 'No user found with that ID.' }; - } - return { deletedCount: result.deletedCount, message: 'User was deleted successfully.' }; - } catch (error) { - throw new Error('Error deleting user: ' + error.message); - } -}; - -const { SESSION_EXPIRY } = process.env ?? {}; -const expires = eval(SESSION_EXPIRY) ?? 1000 * 60 * 15; - -/** - * Generates a JWT token for a given user. - * - * @param {MongoUser} user - The user for whom the token is being generated. - * @returns {Promise} A promise that resolves to a JWT token. - */ -const generateToken = async (user) => { - if (!user) { - throw new Error('No user provided'); - } - - return await signPayload({ - payload: { - id: user._id, - username: user.username, - provider: user.provider, - email: user.email, - }, - secret: process.env.JWT_SECRET, - expirationTime: expires / 1000, - }); -}; /** * Compares the provided password with the user's password. 
@@ -167,6 +12,10 @@ const comparePassword = async (user, candidatePassword) => { throw new Error('No user provided'); } + if (!user.password) { + throw new Error('No password, likely an email first registered via Social/OIDC login'); + } + return new Promise((resolve, reject) => { bcrypt.compare(candidatePassword, user.password, (err, isMatch) => { if (err) { @@ -179,11 +28,4 @@ const comparePassword = async (user, candidatePassword) => { module.exports = { comparePassword, - deleteUserById, - generateToken, - getUserById, - countUsers, - createUser, - updateUser, - findUser, }; diff --git a/api/package.json b/api/package.json index 3d3766bde8..1fe8cff2fc 100644 --- a/api/package.json +++ b/api/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/backend", - "version": "v0.7.8", + "version": "v0.7.9-rc1", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", @@ -34,27 +34,27 @@ }, "homepage": "https://librechat.ai", "dependencies": { - "@anthropic-ai/sdk": "^0.37.0", + "@anthropic-ai/sdk": "^0.52.0", "@aws-sdk/client-s3": "^3.758.0", "@aws-sdk/s3-request-presigner": "^3.758.0", "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.27.0", - "@google/generative-ai": "^0.23.0", + "@google/generative-ai": "^0.24.0", "@googleapis/youtube": "^20.0.0", "@keyv/redis": "^4.3.3", - "@langchain/community": "^0.3.44", - "@langchain/core": "^0.3.57", - "@langchain/google-genai": "^0.2.9", - "@langchain/google-vertexai": "^0.2.9", + "@langchain/community": "^0.3.47", + "@langchain/core": "^0.3.60", + "@langchain/google-genai": "^0.2.13", + "@langchain/google-vertexai": "^0.2.13", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.4.37", + "@librechat/agents": "^2.4.51", + "@librechat/api": "*", "@librechat/data-schemas": "*", "@node-saml/passport-saml": "^5.0.0", "@waylaidwanderer/fetch-event-source": "^3.0.1", "axios": "^1.8.2", "bcryptjs": "^2.4.3", - "cohere-ai": "^7.9.1", "compression": 
"^1.7.4", "connect-redis": "^7.1.0", "cookie": "^0.7.2", @@ -81,15 +81,15 @@ "keyv-file": "^5.1.2", "klona": "^2.0.6", "librechat-data-provider": "*", - "librechat-mcp": "*", "lodash": "^4.17.21", "meilisearch": "^0.38.0", "memorystore": "^1.6.7", "mime": "^3.0.0", "module-alias": "^2.2.3", "mongoose": "^8.12.1", - "multer": "^2.0.0", + "multer": "^2.0.1", "nanoid": "^3.3.7", + "node-fetch": "^2.7.0", "nodemailer": "^6.9.15", "ollama": "^0.5.0", "openai": "^4.96.2", @@ -109,8 +109,9 @@ "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", + "undici": "^7.10.0", "winston": "^3.11.0", - "winston-daily-rotate-file": "^4.7.1", + "winston-daily-rotate-file": "^5.0.0", "youtube-transcript": "^1.2.1", "zod": "^3.22.4" }, diff --git a/api/server/cleanup.js b/api/server/cleanup.js index 5bf336eed5..84164eb641 100644 --- a/api/server/cleanup.js +++ b/api/server/cleanup.js @@ -169,9 +169,6 @@ function disposeClient(client) { client.isGenerativeModel = null; } // Properties specific to OpenAIClient - if (client.ChatGPTClient) { - client.ChatGPTClient = null; - } if (client.completionsUrl) { client.completionsUrl = null; } @@ -220,6 +217,9 @@ function disposeClient(client) { if (client.maxResponseTokens) { client.maxResponseTokens = null; } + if (client.processMemory) { + client.processMemory = null; + } if (client.run) { // Break circular references in run if (client.run.Graph) { diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js deleted file mode 100644 index 40b209ef35..0000000000 --- a/api/server/controllers/AskController.js +++ /dev/null @@ -1,282 +0,0 @@ -const { getResponseSender, Constants } = require('librechat-data-provider'); -const { - handleAbortError, - createAbortController, - cleanupAbortController, -} = require('~/server/middleware'); -const { - disposeClient, - processReqData, - clientRegistry, - requestDataMap, -} = require('~/server/cleanup'); -const { sendMessage, createOnProgress } = 
require('~/server/utils'); -const { saveMessage } = require('~/models'); -const { logger } = require('~/config'); - -const AskController = async (req, res, next, initializeClient, addTitle) => { - let { - text, - endpointOption, - conversationId, - modelDisplayLabel, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - let client = null; - let abortKey = null; - let cleanupHandlers = []; - let clientRef = null; - - logger.debug('[AskController]', { - text, - conversationId, - ...endpointOption, - modelsConfig: endpointOption?.modelsConfig ? 'exists' : '', - }); - - let userMessage = null; - let userMessagePromise = null; - let promptTokens = null; - let userMessageId = null; - let responseMessageId = null; - let getAbortData = null; - - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - modelDisplayLabel, - }); - const initialConversationId = conversationId; - const newConvo = !initialConversationId; - const userId = req.user.id; - - let reqDataContext = { - userMessage, - userMessagePromise, - responseMessageId, - promptTokens, - conversationId, - userMessageId, - }; - - const updateReqData = (data = {}) => { - reqDataContext = processReqData(data, reqDataContext); - abortKey = reqDataContext.abortKey; - userMessage = reqDataContext.userMessage; - userMessagePromise = reqDataContext.userMessagePromise; - responseMessageId = reqDataContext.responseMessageId; - promptTokens = reqDataContext.promptTokens; - conversationId = reqDataContext.conversationId; - userMessageId = reqDataContext.userMessageId; - }; - - let { onProgress: progressCallback, getPartialText } = createOnProgress(); - - const performCleanup = () => { - logger.debug('[AskController] Performing cleanup'); - if (Array.isArray(cleanupHandlers)) { - for (const handler of cleanupHandlers) { - try { - if (typeof handler === 'function') { - handler(); - } - } catch (e) { - // Ignore - } - } - } - - if (abortKey) { - 
logger.debug('[AskController] Cleaning up abort controller'); - cleanupAbortController(abortKey); - abortKey = null; - } - - if (client) { - disposeClient(client); - client = null; - } - - reqDataContext = null; - userMessage = null; - userMessagePromise = null; - promptTokens = null; - getAbortData = null; - progressCallback = null; - endpointOption = null; - cleanupHandlers = null; - addTitle = null; - - if (requestDataMap.has(req)) { - requestDataMap.delete(req); - } - logger.debug('[AskController] Cleanup completed'); - }; - - try { - ({ client } = await initializeClient({ req, res, endpointOption })); - if (clientRegistry && client) { - clientRegistry.register(client, { userId }, client); - } - - if (client) { - requestDataMap.set(req, { client }); - } - - clientRef = new WeakRef(client); - - getAbortData = () => { - const currentClient = clientRef?.deref(); - const currentText = - currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText(); - - return { - sender, - conversationId, - messageId: reqDataContext.responseMessageId, - parentMessageId: overrideParentMessageId ?? 
userMessageId, - text: currentText, - userMessage: userMessage, - userMessagePromise: userMessagePromise, - promptTokens: reqDataContext.promptTokens, - }; - }; - - const { onStart, abortController } = createAbortController( - req, - res, - getAbortData, - updateReqData, - ); - - const closeHandler = () => { - logger.debug('[AskController] Request closed'); - if (!abortController || abortController.signal.aborted || abortController.requestCompleted) { - return; - } - abortController.abort(); - logger.debug('[AskController] Request aborted on close'); - }; - - res.on('close', closeHandler); - cleanupHandlers.push(() => { - try { - res.removeListener('close', closeHandler); - } catch (e) { - // Ignore - } - }); - - const messageOptions = { - user: userId, - parentMessageId, - conversationId: reqDataContext.conversationId, - overrideParentMessageId, - getReqData: updateReqData, - onStart, - abortController, - progressCallback, - progressOptions: { - res, - }, - }; - - /** @type {TMessage} */ - let response = await client.sendMessage(text, messageOptions); - response.endpoint = endpointOption.endpoint; - - const databasePromise = response.databasePromise; - delete response.databasePromise; - - const { conversation: convoData = {} } = await databasePromise; - const conversation = { ...convoData }; - conversation.title = - conversation && !conversation.title ? 
null : conversation?.title || 'New Chat'; - - const latestUserMessage = reqDataContext.userMessage; - - if (client?.options?.attachments && latestUserMessage) { - latestUserMessage.files = client.options.attachments; - if (endpointOption?.modelOptions?.model) { - conversation.model = endpointOption.modelOptions.model; - } - delete latestUserMessage.image_urls; - } - - if (!abortController.signal.aborted) { - const finalResponseMessage = { ...response }; - - sendMessage(res, { - final: true, - conversation, - title: conversation.title, - requestMessage: latestUserMessage, - responseMessage: finalResponseMessage, - }); - res.end(); - - if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) { - await saveMessage( - req, - { ...finalResponseMessage, user: userId }, - { context: 'api/server/controllers/AskController.js - response end' }, - ); - } - } - - if (!client?.skipSaveUserMessage && latestUserMessage) { - await saveMessage(req, latestUserMessage, { - context: "api/server/controllers/AskController.js - don't skip saving user message", - }); - } - - if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) { - addTitle(req, { - text, - response: { ...response }, - client, - }) - .then(() => { - logger.debug('[AskController] Title generation started'); - }) - .catch((err) => { - logger.error('[AskController] Error in title generation', err); - }) - .finally(() => { - logger.debug('[AskController] Title generation completed'); - performCleanup(); - }); - } else { - performCleanup(); - } - } catch (error) { - logger.error('[AskController] Error handling request', error); - let partialText = ''; - try { - const currentClient = clientRef?.deref(); - partialText = - currentClient?.getStreamText != null ? 
currentClient.getStreamText() : getPartialText(); - } catch (getTextError) { - logger.error('[AskController] Error calling getText() during error handling', getTextError); - } - - handleAbortError(res, req, error, { - sender, - partialText, - conversationId: reqDataContext.conversationId, - messageId: reqDataContext.responseMessageId, - parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId, - userMessageId: reqDataContext.userMessageId, - }) - .catch((err) => { - logger.error('[AskController] Error in `handleAbortError` during catch block', err); - }) - .finally(() => { - performCleanup(); - }); - } -}; - -module.exports = AskController; diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index a71ce7d59a..3dbb1a2f31 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -1,17 +1,17 @@ -const openIdClient = require('openid-client'); const cookies = require('cookie'); const jwt = require('jsonwebtoken'); +const openIdClient = require('openid-client'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { - registerUser, - resetPassword, - setAuthTokens, requestPasswordReset, setOpenIDAuthTokens, + resetPassword, + setAuthTokens, + registerUser, } = require('~/server/services/AuthService'); -const { findSession, getUserById, deleteAllUserSessions, findUser } = require('~/models'); +const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models'); const { getOpenIdConfig } = require('~/strategies'); -const { logger } = require('~/config'); -const { isEnabled } = require('~/server/utils'); const registrationController = async (req, res) => { try { @@ -96,7 +96,10 @@ const refreshController = async (req, res) => { } // Find the session with the hashed refresh token - const session = await findSession({ userId: userId, refreshToken: refreshToken }); + const 
session = await findSession({ + userId: userId, + refreshToken: refreshToken, + }); if (session && session.expiration > new Date()) { const token = await setAuthTokens(userId, res, session._id); diff --git a/api/server/controllers/Balance.js b/api/server/controllers/Balance.js index 0361045c72..c892a73b0c 100644 --- a/api/server/controllers/Balance.js +++ b/api/server/controllers/Balance.js @@ -1,4 +1,4 @@ -const Balance = require('~/models/Balance'); +const { Balance } = require('~/db/models'); async function balanceController(req, res) { const balanceData = await Balance.findOne( diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index d142d474df..d24e87ce3a 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -1,3 +1,5 @@ +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { getResponseSender } = require('librechat-data-provider'); const { handleAbortError, @@ -10,9 +12,8 @@ const { clientRegistry, requestDataMap, } = require('~/server/cleanup'); -const { sendMessage, createOnProgress } = require('~/server/utils'); +const { createOnProgress } = require('~/server/utils'); const { saveMessage } = require('~/models'); -const { logger } = require('~/config'); const EditController = async (req, res, next, initializeClient) => { let { @@ -84,7 +85,7 @@ const EditController = async (req, res, next, initializeClient) => { } if (abortKey) { - logger.debug('[AskController] Cleaning up abort controller'); + logger.debug('[EditController] Cleaning up abort controller'); cleanupAbortController(abortKey); abortKey = null; } @@ -198,7 +199,7 @@ const EditController = async (req, res, next, initializeClient) => { const finalUserMessage = reqDataContext.userMessage; const finalResponseMessage = { ...response }; - sendMessage(res, { + sendEvent(res, { final: true, conversation, title: conversation.title, diff --git 
a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 674e36002a..f7aad84aeb 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,9 +1,11 @@ +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, AuthType } = require('librechat-data-provider'); +const { getCustomConfig, getCachedTools } = require('~/server/services/Config'); const { getToolkitKey } = require('~/server/services/ToolService'); -const { getCustomConfig } = require('~/server/services/Config'); +const { getMCPManager, getFlowStateManager } = require('~/config'); const { availableTools } = require('~/app/clients/tools'); -const { getMCPManager } = require('~/config'); const { getLogStores } = require('~/cache'); +const { Constants } = require('librechat-data-provider'); /** * Filters out duplicate plugins from the list of plugins. @@ -84,6 +86,45 @@ const getAvailablePluginsController = async (req, res) => { } }; +function createServerToolsCallback() { + /** + * @param {string} serverName + * @param {TPlugin[] | null} serverTools + */ + return async function (serverName, serverTools) { + try { + const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS); + if (!serverName || !mcpToolsCache) { + return; + } + await mcpToolsCache.set(serverName, serverTools); + logger.debug(`MCP tools for ${serverName} added to cache.`); + } catch (error) { + logger.error('Error retrieving MCP tools from cache:', error); + } + }; +} + +function createGetServerTools() { + /** + * Retrieves cached server tools + * @param {string} serverName + * @returns {Promise} + */ + return async function (serverName) { + try { + const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS); + if (!mcpToolsCache) { + return null; + } + return await mcpToolsCache.get(serverName); + } catch (error) { + logger.error('Error retrieving MCP tools from cache:', error); + return null; + } + }; +} + /** * Retrieves and returns a list of 
available tools, either from a cache or by reading a plugin manifest file. * @@ -109,7 +150,16 @@ const getAvailableTools = async (req, res) => { const customConfig = await getCustomConfig(); if (customConfig?.mcpServers != null) { const mcpManager = getMCPManager(); - pluginManifest = await mcpManager.loadManifestTools(pluginManifest); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null; + const serverToolsCallback = createServerToolsCallback(); + const getServerTools = createGetServerTools(); + const mcpTools = await mcpManager.loadManifestTools({ + flowManager, + serverToolsCallback, + getServerTools, + }); + pluginManifest = [...mcpTools, ...pluginManifest]; } /** @type {TPlugin[]} */ @@ -123,17 +173,57 @@ const getAvailableTools = async (req, res) => { } }); - const toolDefinitions = req.app.locals.availableTools; - const tools = authenticatedPlugins.filter( - (plugin) => - toolDefinitions[plugin.pluginKey] !== undefined || - (plugin.toolkit === true && - Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey)), - ); + const toolDefinitions = await getCachedTools({ includeGlobal: true }); - await cache.set(CacheKeys.TOOLS, tools); - res.status(200).json(tools); + const toolsOutput = []; + for (const plugin of authenticatedPlugins) { + const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined; + const isToolkit = + plugin.toolkit === true && + Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey); + + if (!isToolDefined && !isToolkit) { + continue; + } + + const toolToAdd = { ...plugin }; + + if (!plugin.pluginKey.includes(Constants.mcp_delimiter)) { + toolsOutput.push(toolToAdd); + continue; + } + + const parts = plugin.pluginKey.split(Constants.mcp_delimiter); + const serverName = parts[parts.length - 1]; + const serverConfig = customConfig?.mcpServers?.[serverName]; + + if (!serverConfig?.customUserVars) { + 
toolsOutput.push(toolToAdd); + continue; + } + + const customVarKeys = Object.keys(serverConfig.customUserVars); + + if (customVarKeys.length === 0) { + toolToAdd.authConfig = []; + toolToAdd.authenticated = true; + } else { + toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({ + authField: key, + label: value.title || key, + description: value.description || '', + })); + toolToAdd.authenticated = false; + } + + toolsOutput.push(toolToAdd); + } + + const finalTools = filterUniquePlugins(toolsOutput); + await cache.set(CacheKeys.TOOLS, finalTools); + res.status(200).json(finalTools); } catch (error) { + logger.error('[getAvailableTools]', error); res.status(500).json({ message: error.message }); } }; diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js index f5783f45ad..44baf92ee7 100644 --- a/api/server/controllers/TwoFactorController.js +++ b/api/server/controllers/TwoFactorController.js @@ -1,13 +1,13 @@ +const { encryptV3 } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { + verifyTOTP, + getTOTPSecret, + verifyBackupCode, generateTOTPSecret, generateBackupCodes, - verifyTOTP, - verifyBackupCode, - getTOTPSecret, } = require('~/server/services/twoFactorService'); -const { updateUser, getUserById } = require('~/models'); -const { logger } = require('~/config'); -const { encryptV3 } = require('~/server/utils/crypto'); +const { getUserById, updateUser } = require('~/models'); const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, ''); diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 571c454552..69791dd7a5 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -5,8 +5,8 @@ const { webSearchKeys, extractWebSearchEnvVars, } = require('librechat-data-provider'); +const { logger } = 
require('@librechat/data-schemas'); const { - Balance, getFiles, updateUser, deleteFiles, @@ -16,16 +16,15 @@ const { deleteUserById, deleteAllUserSessions, } = require('~/models'); -const User = require('~/models/User'); const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService'); const { updateUserPluginsService, deleteUserKey } = require('~/server/services/UserService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud'); const { processDeleteRequest } = require('~/server/services/Files/process'); -const { deleteAllSharedLinks } = require('~/models/Share'); +const { Transaction, Balance, User } = require('~/db/models'); const { deleteToolCalls } = require('~/models/ToolCall'); -const { Transaction } = require('~/models/Transaction'); -const { logger } = require('~/config'); +const { deleteAllSharedLinks } = require('~/models'); +const { getMCPManager } = require('~/config'); const getUserController = async (req, res) => { /** @type {MongoUser} */ @@ -105,10 +104,22 @@ const updateUserPluginsController = async (req, res) => { } let keys = Object.keys(auth); - if (keys.length === 0 && pluginKey !== Tools.web_search) { + const values = Object.values(auth); // Used in 'install' block + + const isMCPTool = pluginKey.startsWith('mcp_') || pluginKey.includes(Constants.mcp_delimiter); + + // Early exit condition: + // If keys are empty (meaning auth: {} was likely sent for uninstall, or auth was empty for install) + // AND it's not web_search (which has special key handling to populate `keys` for uninstall) + // AND it's NOT (an uninstall action FOR an MCP tool - we need to proceed for this case to clear all its auth) + // THEN return. 
+ if ( + keys.length === 0 && + pluginKey !== Tools.web_search && + !(action === 'uninstall' && isMCPTool) + ) { return res.status(200).send(); } - const values = Object.values(auth); /** @type {number} */ let status = 200; @@ -135,16 +146,53 @@ const updateUserPluginsController = async (req, res) => { } } } else if (action === 'uninstall') { - for (let i = 0; i < keys.length; i++) { - authService = await deleteUserPluginAuth(user.id, keys[i]); + // const isMCPTool was defined earlier + if (isMCPTool && keys.length === 0) { + // This handles the case where auth: {} is sent for an MCP tool uninstall. + // It means "delete all credentials associated with this MCP pluginKey". + authService = await deleteUserPluginAuth(user.id, null, true, pluginKey); if (authService instanceof Error) { - logger.error('[authService]', authService); + logger.error( + `[authService] Error deleting all auth for MCP tool ${pluginKey}:`, + authService, + ); ({ status, message } = authService); } + } else { + // This handles: + // 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}). + // 2. Other tools uninstall (if keys were provided). + // 3. MCP tool uninstall if specific keys were provided in `auth` (not current frontend behavior). + // If keys is empty for non-MCP tools (and not web_search), this loop won't run, and nothing is deleted. 
+ for (let i = 0; i < keys.length; i++) { + authService = await deleteUserPluginAuth(user.id, keys[i]); // Deletes by authField name + if (authService instanceof Error) { + logger.error('[authService] Error deleting specific auth key:', authService); + ({ status, message } = authService); + } + } } } if (status === 200) { + // If auth was updated successfully, disconnect MCP sessions as they might use these credentials + if (pluginKey.startsWith(Constants.mcp_prefix)) { + try { + const mcpManager = getMCPManager(user.id); + if (mcpManager) { + logger.info( + `[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`, + ); + await mcpManager.disconnectUserConnections(user.id); + } + } catch (disconnectError) { + logger.error( + `[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`, + disconnectError, + ); + // Do not fail the request for this, but log it. + } + } return res.status(status).send(); } @@ -166,7 +214,11 @@ const deleteUserController = async (req, res) => { await Balance.deleteMany({ user: user._id }); // delete user balances await deletePresets(user.id); // delete user presets /* TODO: Delete Assistant Threads */ - await deleteConvos(user.id); // delete user convos + try { + await deleteConvos(user.id); // delete user convos + } catch (error) { + logger.error('[deleteUserController] Error deleting user convos, likely no convos', error); + } await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth await deleteUserById(user.id); // delete user await deleteAllSharedLinks(user.id); // delete user shared links diff --git a/api/server/controllers/agents/__tests__/v1.spec.js b/api/server/controllers/agents/__tests__/v1.spec.js new file mode 100644 index 0000000000..b097cd98ce --- /dev/null +++ b/api/server/controllers/agents/__tests__/v1.spec.js @@ -0,0 +1,195 @@ +const { duplicateAgent } = require('../v1'); +const { 
getAgent, createAgent } = require('~/models/Agent'); +const { getActions } = require('~/models/Action'); +const { nanoid } = require('nanoid'); + +jest.mock('~/models/Agent'); +jest.mock('~/models/Action'); +jest.mock('nanoid'); + +describe('duplicateAgent', () => { + let req, res; + + beforeEach(() => { + req = { + params: { id: 'agent_123' }, + user: { id: 'user_456' }, + }; + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + }; + jest.clearAllMocks(); + }); + + it('should duplicate an agent successfully', async () => { + const mockAgent = { + id: 'agent_123', + name: 'Test Agent', + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 'gpt-4', + tools: ['file_search'], + actions: [], + author: 'user_789', + versions: [{ name: 'Test Agent', version: 1 }], + __v: 0, + }; + + const mockNewAgent = { + id: 'agent_new_123', + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 'gpt-4', + tools: ['file_search'], + actions: [], + author: 'user_456', + versions: [ + { + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 'gpt-4', + tools: ['file_search'], + actions: [], + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }; + + getAgent.mockResolvedValue(mockAgent); + getActions.mockResolvedValue([]); + nanoid.mockReturnValue('new_123'); + createAgent.mockResolvedValue(mockNewAgent); + + await duplicateAgent(req, res); + + expect(getAgent).toHaveBeenCalledWith({ id: 'agent_123' }); + expect(getActions).toHaveBeenCalledWith({ agent_id: 'agent_123' }, true); + expect(createAgent).toHaveBeenCalledWith( + expect.objectContaining({ + id: 'agent_new_123', + author: 'user_456', + name: expect.stringContaining('Test Agent ('), + description: 'Test Description', + instructions: 'Test Instructions', + provider: 'openai', + model: 
'gpt-4', + tools: ['file_search'], + actions: [], + }), + ); + + expect(createAgent).toHaveBeenCalledWith( + expect.not.objectContaining({ + versions: expect.anything(), + __v: expect.anything(), + }), + ); + + expect(res.status).toHaveBeenCalledWith(201); + expect(res.json).toHaveBeenCalledWith({ + agent: mockNewAgent, + actions: [], + }); + }); + + it('should ensure duplicated agent has clean versions array without nested fields', async () => { + const mockAgent = { + id: 'agent_123', + name: 'Test Agent', + description: 'Test Description', + versions: [ + { + name: 'Test Agent', + versions: [{ name: 'Nested' }], + __v: 1, + }, + ], + __v: 2, + }; + + const mockNewAgent = { + id: 'agent_new_123', + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + versions: [ + { + name: 'Test Agent (1/2/23, 12:34)', + description: 'Test Description', + createdAt: new Date(), + updatedAt: new Date(), + }, + ], + }; + + getAgent.mockResolvedValue(mockAgent); + getActions.mockResolvedValue([]); + nanoid.mockReturnValue('new_123'); + createAgent.mockResolvedValue(mockNewAgent); + + await duplicateAgent(req, res); + + expect(mockNewAgent.versions).toHaveLength(1); + + const firstVersion = mockNewAgent.versions[0]; + expect(firstVersion).not.toHaveProperty('versions'); + expect(firstVersion).not.toHaveProperty('__v'); + + expect(mockNewAgent).not.toHaveProperty('__v'); + + expect(res.status).toHaveBeenCalledWith(201); + }); + + it('should return 404 if agent not found', async () => { + getAgent.mockResolvedValue(null); + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(404); + expect(res.json).toHaveBeenCalledWith({ + error: 'Agent not found', + status: 'error', + }); + }); + + it('should handle tool_resources.ocr correctly', async () => { + const mockAgent = { + id: 'agent_123', + name: 'Test Agent', + tool_resources: { + ocr: { enabled: true, config: 'test' }, + other: { should: 'not be copied' }, + }, + }; + + 
getAgent.mockResolvedValue(mockAgent); + getActions.mockResolvedValue([]); + nanoid.mockReturnValue('new_123'); + createAgent.mockResolvedValue({ id: 'agent_new_123' }); + + await duplicateAgent(req, res); + + expect(createAgent).toHaveBeenCalledWith( + expect.objectContaining({ + tool_resources: { + ocr: { enabled: true, config: 'test' }, + }, + }), + ); + }); + + it('should handle errors gracefully', async () => { + getAgent.mockRejectedValue(new Error('Database error')); + + await duplicateAgent(req, res); + + expect(res.status).toHaveBeenCalledWith(500); + expect(res.json).toHaveBeenCalledWith({ error: 'Database error' }); + }); +}); diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index cedfc6bd62..60e68b5f2d 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,4 +1,6 @@ const { nanoid } = require('nanoid'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Tools, StepTypes, FileContext } = require('librechat-data-provider'); const { EnvVar, @@ -12,7 +14,6 @@ const { const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { saveBase64Image } = require('~/server/services/Files/process'); -const { logger, sendEvent } = require('~/config'); class ModelEndHandler { /** @@ -240,9 +241,7 @@ function createToolEndCallback({ req, res, artifactPromises }) { if (output.artifact[Tools.web_search]) { artifactPromises.push( (async () => { - const name = `${output.name}_${output.tool_call_id}_${nanoid()}`; const attachment = { - name, type: Tools.web_search, messageId: metadata.run_id, toolCallId: output.tool_call_id, diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 31fd56930e..1bdf809d91 100644 --- a/api/server/controllers/agents/client.js +++ 
b/api/server/controllers/agents/client.js @@ -1,15 +1,16 @@ -// const { HttpsProxyAgent } = require('https-proxy-agent'); -// const { -// Constants, -// ImageDetail, -// EModelEndpoint, -// resolveHeaders, -// validateVisionModel, -// mapModelToAzureConfig, -// } = require('librechat-data-provider'); require('events').EventEmitter.defaultMaxListeners = 100; +const { logger } = require('@librechat/data-schemas'); +const { + sendEvent, + createRun, + Tokenizer, + checkAccess, + memoryInstructions, + createMemoryProcessor, +} = require('@librechat/api'); const { Callback, + Providers, GraphEvents, formatMessage, formatAgentMessages, @@ -19,25 +20,41 @@ const { } = require('@librechat/agents'); const { Constants, + Permissions, VisionModes, ContentTypes, EModelEndpoint, KnownEndpoints, + PermissionTypes, isAgentsEndpoint, AgentCapabilities, bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); -const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config'); -const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); -const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { DynamicStructuredTool } = require('@langchain/core/tools'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); +const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config'); +const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); +const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { getFormattedMemories, deleteMemory, setMemory } = require('~/models'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const Tokenizer = require('~/server/services/Tokenizer'); +const { getProviderConfig } = 
require('~/server/services/Endpoints'); const BaseClient = require('~/app/clients/BaseClient'); -const { logger, sendEvent } = require('~/config'); -const { createRun } = require('./run'); +const { getRoleByName } = require('~/models/Role'); +const { loadAgent } = require('~/models/Agent'); +const { getMCPManager } = require('~/config'); + +const omitTitleOptions = new Set([ + 'stream', + 'thinking', + 'streaming', + 'clientOptions', + 'thinkingConfig', + 'thinkingBudget', + 'includeThoughts', + 'maxOutputTokens', +]); /** * @param {ServerRequest} req @@ -57,12 +74,8 @@ const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deep const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; -// const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory'); -// const { getFormattedMemories } = require('~/models/Memory'); -// const { getCurrentDateTime } = require('~/utils'); - function createTokenCounter(encoding) { - return (message) => { + return function (message) { const countTokens = (text) => Tokenizer.getTokenCount(text, encoding); return getTokenCountForMessage(message, countTokens); }; @@ -123,6 +136,8 @@ class AgentClient extends BaseClient { this.usage; /** @type {Record} */ this.indexTokenCountMap = {}; + /** @type {(messages: BaseMessage[]) => Promise} */ + this.processMemory; } /** @@ -137,55 +152,10 @@ class AgentClient extends BaseClient { } /** - * - * Checks if the model is a vision model based on request attachments and sets the appropriate options: - * - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request. - * - Sets `this.isVisionModel` to `true` if vision request. - * - Deletes `this.modelOptions.stop` if vision request. 
+ * `AgentClient` is not opinionated about vision requests, so we don't do anything here * @param {MongoFile[]} attachments */ - checkVisionRequest(attachments) { - // if (!attachments) { - // return; - // } - // const availableModels = this.options.modelsConfig?.[this.options.endpoint]; - // if (!availableModels) { - // return; - // } - // let visionRequestDetected = false; - // for (const file of attachments) { - // if (file?.type?.includes('image')) { - // visionRequestDetected = true; - // break; - // } - // } - // if (!visionRequestDetected) { - // return; - // } - // this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); - // if (this.isVisionModel) { - // delete this.modelOptions.stop; - // return; - // } - // for (const model of availableModels) { - // if (!validateVisionModel({ model, availableModels })) { - // continue; - // } - // this.modelOptions.model = model; - // this.isVisionModel = true; - // delete this.modelOptions.stop; - // return; - // } - // if (!availableModels.includes(this.defaultVisionModel)) { - // return; - // } - // if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) { - // return; - // } - // this.modelOptions.model = this.defaultVisionModel; - // this.isVisionModel = true; - // delete this.modelOptions.stop; - } + checkVisionRequest() {} getSaveOptions() { // TODO: @@ -269,24 +239,6 @@ class AgentClient extends BaseClient { .filter(Boolean) .join('\n') .trim(); - // this.systemMessage = getCurrentDateTime(); - // const { withKeys, withoutKeys } = await getFormattedMemories({ - // userId: this.options.req.user.id, - // }); - // processMemory({ - // userId: this.options.req.user.id, - // message: this.options.req.body.text, - // parentMessageId, - // memory: withKeys, - // thread_id: this.conversationId, - // }).catch((error) => { - // logger.error('Memory Agent failed to process memory', error); - // }); - - // this.systemMessage += '\n\n' + memoryInstructions; - // 
if (withoutKeys) { - // this.systemMessage += `\n\n# Existing memory about the user:\n${withoutKeys}`; - // } if (this.options.attachments) { const attachments = await this.options.attachments; @@ -370,6 +322,37 @@ class AgentClient extends BaseClient { systemContent = this.augmentedPrompt + systemContent; } + // Inject MCP server instructions if available + const ephemeralAgent = this.options.req.body.ephemeralAgent; + let mcpServers = []; + + // Check for ephemeral agent MCP servers + if (ephemeralAgent && ephemeralAgent.mcp && ephemeralAgent.mcp.length > 0) { + mcpServers = ephemeralAgent.mcp; + } + // Check for regular agent MCP tools + else if (this.options.agent && this.options.agent.tools) { + mcpServers = this.options.agent.tools + .filter( + (tool) => + tool instanceof DynamicStructuredTool && tool.name.includes(Constants.mcp_delimiter), + ) + .map((tool) => tool.name.split(Constants.mcp_delimiter).pop()) + .filter(Boolean); + } + + if (mcpServers.length > 0) { + try { + const mcpInstructions = getMCPManager().formatInstructionsForContext(mcpServers); + if (mcpInstructions) { + systemContent = [systemContent, mcpInstructions].filter(Boolean).join('\n\n'); + logger.debug('[AgentClient] Injected MCP instructions for servers:', mcpServers); + } + } catch (error) { + logger.error('[AgentClient] Failed to inject MCP instructions:', error); + } + } + if (systemContent) { this.options.agent.instructions = systemContent; } @@ -399,9 +382,158 @@ class AgentClient extends BaseClient { opts.getReqData({ promptTokens }); } + const withoutKeys = await this.useMemory(); + if (withoutKeys) { + systemContent += `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`; + } + + if (systemContent) { + this.options.agent.instructions = systemContent; + } + return result; } + /** + * @returns {Promise} + */ + async useMemory() { + const user = this.options.req.user; + if (user.personalization?.memories === false) { + return; + } + const hasAccess = await 
checkAccess({ + user, + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE], + getRoleByName, + }); + + if (!hasAccess) { + logger.debug( + `[api/server/controllers/agents/client.js #useMemory] User ${user.id} does not have USE permission for memories`, + ); + return; + } + /** @type {TCustomConfig['memory']} */ + const memoryConfig = this.options.req?.app?.locals?.memory; + if (!memoryConfig || memoryConfig.disabled === true) { + return; + } + + /** @type {Agent} */ + let prelimAgent; + const allowedProviders = new Set( + this.options.req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders, + ); + try { + if (memoryConfig.agent?.id != null && memoryConfig.agent.id !== this.options.agent.id) { + prelimAgent = await loadAgent({ + req: this.options.req, + agent_id: memoryConfig.agent.id, + endpoint: EModelEndpoint.agents, + }); + } else if ( + memoryConfig.agent?.id == null && + memoryConfig.agent?.model != null && + memoryConfig.agent?.provider != null + ) { + prelimAgent = { id: Constants.EPHEMERAL_AGENT_ID, ...memoryConfig.agent }; + } + } catch (error) { + logger.error( + '[api/server/controllers/agents/client.js #useMemory] Error loading agent for memory', + error, + ); + } + + const agent = await initializeAgent({ + req: this.options.req, + res: this.options.res, + agent: prelimAgent, + allowedProviders, + }); + + if (!agent) { + logger.warn( + '[api/server/controllers/agents/client.js #useMemory] No agent found for memory', + memoryConfig, + ); + return; + } + + const llmConfig = Object.assign( + { + provider: agent.provider, + model: agent.model, + }, + agent.model_parameters, + ); + + /** @type {import('@librechat/api').MemoryConfig} */ + const config = { + validKeys: memoryConfig.validKeys, + instructions: agent.instructions, + llmConfig, + tokenLimit: memoryConfig.tokenLimit, + }; + + const userId = this.options.req.user.id + ''; + const messageId = this.responseMessageId + ''; + const conversationId = this.conversationId + ''; + 
const [withoutKeys, processMemory] = await createMemoryProcessor({ + userId, + config, + messageId, + conversationId, + memoryMethods: { + setMemory, + deleteMemory, + getFormattedMemories, + }, + res: this.options.res, + }); + + this.processMemory = processMemory; + return withoutKeys; + } + + /** + * @param {BaseMessage[]} messages + * @returns {Promise} + */ + async runMemory(messages) { + try { + if (this.processMemory == null) { + return; + } + /** @type {TCustomConfig['memory']} */ + const memoryConfig = this.options.req?.app?.locals?.memory; + const messageWindowSize = memoryConfig?.messageWindowSize ?? 5; + + let messagesToProcess = [...messages]; + if (messages.length > messageWindowSize) { + for (let i = messages.length - messageWindowSize; i >= 0; i--) { + const potentialWindow = messages.slice(i, i + messageWindowSize); + if (potentialWindow[0]?.role === 'user') { + messagesToProcess = [...potentialWindow]; + break; + } + } + + if (messagesToProcess.length === messages.length) { + messagesToProcess = [...messages.slice(-messageWindowSize)]; + } + } + + const bufferString = getBufferString(messagesToProcess); + const bufferMessage = new HumanMessage(`# Current Chat:\n\n${bufferString}`); + return await this.processMemory([bufferMessage]); + } catch (error) { + logger.error('Memory Agent failed to process memory', error); + } + } + /** @type {sendCompletion} */ async sendCompletion(payload, opts = {}) { await this.chatCompletion({ @@ -544,100 +676,13 @@ class AgentClient extends BaseClient { let config; /** @type {ReturnType} */ let run; + /** @type {Promise<(TAttachment | null)[] | undefined>} */ + let memoryPromise; try { if (!abortController) { abortController = new AbortController(); } - // if (this.options.headers) { - // opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers }; - // } - - // if (this.options.proxy) { - // opts.httpAgent = new HttpsProxyAgent(this.options.proxy); - // } - - // if (this.isVisionModel) { - // 
modelOptions.max_tokens = 4000; - // } - - // /** @type {TAzureConfig | undefined} */ - // const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; - - // if ( - // (this.azure && this.isVisionModel && azureConfig) || - // (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI) - // ) { - // const { modelGroupMap, groupMap } = azureConfig; - // const { - // azureOptions, - // baseURL, - // headers = {}, - // serverless, - // } = mapModelToAzureConfig({ - // modelName: modelOptions.model, - // modelGroupMap, - // groupMap, - // }); - // opts.defaultHeaders = resolveHeaders(headers); - // this.langchainProxy = extractBaseURL(baseURL); - // this.apiKey = azureOptions.azureOpenAIApiKey; - - // const groupName = modelGroupMap[modelOptions.model].group; - // this.options.addParams = azureConfig.groupMap[groupName].addParams; - // this.options.dropParams = azureConfig.groupMap[groupName].dropParams; - // // Note: `forcePrompt` not re-assigned as only chat models are vision models - - // this.azure = !serverless && azureOptions; - // this.azureEndpoint = - // !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); - // } - - // if (this.azure || this.options.azure) { - // /* Azure Bug, extremely short default `max_tokens` response */ - // if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') { - // modelOptions.max_tokens = 4000; - // } - - // /* Azure does not accept `model` in the body, so we need to remove it. */ - // delete modelOptions.model; - - // opts.baseURL = this.langchainProxy - // ? constructAzureURL({ - // baseURL: this.langchainProxy, - // azureOptions: this.azure, - // }) - // : this.azureEndpoint.split(/(? 
{ - // delete modelOptions[param]; - // }); - // logger.debug('[api/server/controllers/agents/client.js #chatCompletion] dropped params', { - // dropParams: this.options.dropParams, - // modelOptions, - // }); - // } - /** @type {TCustomConfig['endpoints']['agents']} */ const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents]; @@ -647,13 +692,16 @@ class AgentClient extends BaseClient { last_agent_index: this.agentConfigs?.size ?? 0, user_id: this.user ?? this.options.req.user?.id, hide_sequential_outputs: this.options.agent.hide_sequential_outputs, + user: this.options.req.user, }, - recursionLimit: agentsEConfig?.recursionLimit, + recursionLimit: agentsEConfig?.recursionLimit ?? 25, signal: abortController.signal, streamMode: 'values', version: 'v2', }; + const getUserMCPAuthMap = await createGetMCPAuthMap(); + const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name)); let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages( payload, @@ -734,6 +782,10 @@ class AgentClient extends BaseClient { messages = addCacheControl(messages); } + if (i === 0) { + memoryPromise = this.runMemory(messages); + } + run = await createRun({ agent, req: this.options.req, @@ -769,10 +821,23 @@ class AgentClient extends BaseClient { run.Graph.contentData = contentData; } - const encoding = this.getEncoding(); + try { + if (getUserMCPAuthMap) { + config.configurable.userMCPAuthMap = await getUserMCPAuthMap({ + tools: agent.tools, + userId: this.options.req.user.id, + }); + } + } catch (err) { + logger.error( + `[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`, + err, + ); + } + await run.processStream({ messages }, config, { keepContent: i !== 0, - tokenCounter: createTokenCounter(encoding), + tokenCounter: createTokenCounter(this.getEncoding()), indexTokenCountMap: currentIndexCountMap, maxContextTokens: agent.maxContextTokens, callbacks: { @@ -887,6 
+952,12 @@ class AgentClient extends BaseClient { }); try { + if (memoryPromise) { + const attachments = await memoryPromise; + if (attachments && attachments.length > 0) { + this.artifactPromises.push(...attachments); + } + } await this.recordCollectedUsage({ context: 'message' }); } catch (err) { logger.error( @@ -895,6 +966,12 @@ class AgentClient extends BaseClient { ); } } catch (err) { + if (memoryPromise) { + const attachments = await memoryPromise; + if (attachments && attachments.length > 0) { + this.artifactPromises.push(...attachments); + } + } logger.error( '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', err, @@ -923,23 +1000,26 @@ class AgentClient extends BaseClient { throw new Error('Run not initialized'); } const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); - const endpoint = this.options.agent.endpoint; - const { req, res } = this.options; + const { req, res, agent } = this.options; + const endpoint = agent.endpoint; + /** @type {import('@librechat/agents').ClientOptions} */ let clientOptions = { maxTokens: 75, + model: agent.model_parameters.model, }; - let endpointConfig = req.app.locals[endpoint]; + + const { getOptions, overrideProvider, customEndpointConfig } = + await getProviderConfig(endpoint); + + /** @type {TEndpoint | undefined} */ + const endpointConfig = req.app.locals[endpoint] ?? 
customEndpointConfig; if (!endpointConfig) { - try { - endpointConfig = await getCustomEndpointConfig(endpoint); - } catch (err) { - logger.error( - '[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config', - err, - ); - } + logger.warn( + '[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config', + ); } + if ( endpointConfig && endpointConfig.titleModel && @@ -947,30 +1027,50 @@ class AgentClient extends BaseClient { ) { clientOptions.model = endpointConfig.titleModel; } + + const options = await getOptions({ + req, + res, + optionsOnly: true, + overrideEndpoint: endpoint, + overrideModel: clientOptions.model, + endpointOption: { model_parameters: clientOptions }, + }); + + let provider = options.provider ?? overrideProvider ?? agent.provider; if ( endpoint === EModelEndpoint.azureOpenAI && - clientOptions.model && - this.options.agent.model_parameters.model !== clientOptions.model + options.llmConfig?.azureOpenAIApiInstanceName == null ) { - clientOptions = - ( - await initOpenAI({ - req, - res, - optionsOnly: true, - overrideModel: clientOptions.model, - overrideEndpoint: endpoint, - endpointOption: { - model_parameters: clientOptions, - }, - }) - )?.llmConfig ?? 
clientOptions; + provider = Providers.OPENAI; } - if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) { + + /** @type {import('@librechat/agents').ClientOptions} */ + clientOptions = { ...options.llmConfig }; + if (options.configOptions) { + clientOptions.configuration = options.configOptions; + } + + // Ensure maxTokens is set for non-o1 models + if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) { + clientOptions.maxTokens = 75; + } else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) { delete clientOptions.maxTokens; } + + clientOptions = Object.assign( + Object.fromEntries( + Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)), + ), + ); + + if (provider === Providers.GOOGLE) { + clientOptions.json = true; + } + try { const titleResult = await this.run.generateTitle({ + provider, inputText: text, contentParts: this.contentParts, clientOptions, @@ -988,8 +1088,10 @@ class AgentClient extends BaseClient { let input_tokens, output_tokens; if (item.usage) { - input_tokens = item.usage.input_tokens || item.usage.inputTokens; - output_tokens = item.usage.output_tokens || item.usage.outputTokens; + input_tokens = + item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens; + output_tokens = + item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens; } else if (item.tokenUsage) { input_tokens = item.tokenUsage.promptTokens; output_tokens = item.tokenUsage.completionTokens; diff --git a/api/server/controllers/agents/errors.js b/api/server/controllers/agents/errors.js index fb4de45085..b3bb1cea65 100644 --- a/api/server/controllers/agents/errors.js +++ b/api/server/controllers/agents/errors.js @@ -1,10 +1,10 @@ // errorHandler.js -const { logger } = require('~/config'); -const getLogStores = require('~/cache/getLogStores'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, ViolationTypes } = 
require('librechat-data-provider'); +const { sendResponse } = require('~/server/middleware/error'); const { recordUsage } = require('~/server/services/Threads'); const { getConvo } = require('~/models/Conversation'); -const { sendResponse } = require('~/server/utils'); +const getLogStores = require('~/cache/getLogStores'); /** * @typedef {Object} ErrorHandlerContext @@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch } else if (/Files.*are invalid/.test(error.message)) { const errorMessage = `Files are invalid, or may not have uploaded yet.${ endpoint === 'azureAssistants' - ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' + ? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload." : '' }`; return sendResponse(req, res, messageData, errorMessage); diff --git a/api/server/controllers/agents/llm.js b/api/server/controllers/agents/llm.js deleted file mode 100644 index 438a38b6cb..0000000000 --- a/api/server/controllers/agents/llm.js +++ /dev/null @@ -1,106 +0,0 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); -const { resolveHeaders } = require('librechat-data-provider'); -const { createLLM } = require('~/app/clients/llm'); - -/** - * Initializes and returns a Language Learning Model (LLM) instance. - * - * @param {Object} options - Configuration options for the LLM. - * @param {string} options.model - The model identifier. - * @param {string} options.modelName - The specific name of the model. - * @param {number} options.temperature - The temperature setting for the model. - * @param {number} options.presence_penalty - The presence penalty for the model. - * @param {number} options.frequency_penalty - The frequency penalty for the model. - * @param {number} options.max_tokens - The maximum number of tokens for the model output. 
- * @param {boolean} options.streaming - Whether to use streaming for the model output. - * @param {Object} options.context - The context for the conversation. - * @param {number} options.tokenBuffer - The token buffer size. - * @param {number} options.initialMessageCount - The initial message count. - * @param {string} options.conversationId - The ID of the conversation. - * @param {string} options.user - The user identifier. - * @param {string} options.langchainProxy - The langchain proxy URL. - * @param {boolean} options.useOpenRouter - Whether to use OpenRouter. - * @param {Object} options.options - Additional options. - * @param {Object} options.options.headers - Custom headers for the request. - * @param {string} options.options.proxy - Proxy URL. - * @param {Object} options.options.req - The request object. - * @param {Object} options.options.res - The response object. - * @param {boolean} options.options.debug - Whether to enable debug mode. - * @param {string} options.apiKey - The API key for authentication. - * @param {Object} options.azure - Azure-specific configuration. - * @param {Object} options.abortController - The AbortController instance. - * @returns {Object} The initialized LLM instance. 
- */ -function initializeLLM(options) { - const { - model, - modelName, - temperature, - presence_penalty, - frequency_penalty, - max_tokens, - streaming, - user, - langchainProxy, - useOpenRouter, - options: { headers, proxy }, - apiKey, - azure, - } = options; - - const modelOptions = { - modelName: modelName || model, - temperature, - presence_penalty, - frequency_penalty, - user, - }; - - if (max_tokens) { - modelOptions.max_tokens = max_tokens; - } - - const configOptions = {}; - - if (langchainProxy) { - configOptions.basePath = langchainProxy; - } - - if (useOpenRouter) { - configOptions.basePath = 'https://openrouter.ai/api/v1'; - configOptions.baseOptions = { - headers: { - 'HTTP-Referer': 'https://librechat.ai', - 'X-Title': 'LibreChat', - }, - }; - } - - if (headers && typeof headers === 'object' && !Array.isArray(headers)) { - configOptions.baseOptions = { - headers: resolveHeaders({ - ...headers, - ...configOptions?.baseOptions?.headers, - }), - }; - } - - if (proxy) { - configOptions.httpAgent = new HttpsProxyAgent(proxy); - configOptions.httpsAgent = new HttpsProxyAgent(proxy); - } - - const llm = createLLM({ - modelOptions, - configOptions, - openAIApiKey: apiKey, - azure, - streaming, - }); - - return llm; -} - -module.exports = { - initializeLLM, -}; diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index fcee62edc7..2c8e424b5d 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -1,3 +1,5 @@ +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Constants } = require('librechat-data-provider'); const { handleAbortError, @@ -5,17 +7,18 @@ const { cleanupAbortController, } = require('~/server/middleware'); const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup'); -const { sendMessage } = require('~/server/utils'); const { saveMessage } = require('~/models'); -const 
{ logger } = require('~/config'); const AgentController = async (req, res, next, initializeClient, addTitle) => { let { text, endpointOption, conversationId, + isContinued = false, + editedContent = null, parentMessageId = null, overrideParentMessageId = null, + responseMessageId: editedResponseMessageId = null, } = req.body; let sender; @@ -67,7 +70,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { handler(); } } catch (e) { - // Ignore cleanup errors + logger.error('[AgentController] Error in cleanup handler', e); } } } @@ -155,7 +158,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { try { res.removeListener('close', closeHandler); } catch (e) { - // Ignore + logger.error('[AgentController] Error removing close listener', e); } }); @@ -163,10 +166,14 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { user: userId, onStart, getReqData, + isContinued, + editedContent, conversationId, parentMessageId, abortController, overrideParentMessageId, + isEdited: !!editedContent, + responseMessageId: editedResponseMessageId, progressOptions: { res, }, @@ -206,7 +213,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { // Create a new response object with minimal copies const finalResponse = { ...response }; - sendMessage(res, { + sendEvent(res, { final: true, conversation, title: conversation.title, @@ -228,7 +235,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { // Save user message if needed if (!client.skipSaveUserMessage) { await saveMessage(req, userMessage, { - context: 'api/server/controllers/agents/request.js - don\'t skip saving user message', + context: "api/server/controllers/agents/request.js - don't skip saving user message", }); } diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js deleted file mode 100644 index 2452e66233..0000000000 --- 
a/api/server/controllers/agents/run.js +++ /dev/null @@ -1,94 +0,0 @@ -const { Run, Providers } = require('@librechat/agents'); -const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider'); - -/** - * @typedef {import('@librechat/agents').t} t - * @typedef {import('@librechat/agents').StandardGraphConfig} StandardGraphConfig - * @typedef {import('@librechat/agents').StreamEventData} StreamEventData - * @typedef {import('@librechat/agents').EventHandler} EventHandler - * @typedef {import('@librechat/agents').GraphEvents} GraphEvents - * @typedef {import('@librechat/agents').LLMConfig} LLMConfig - * @typedef {import('@librechat/agents').IState} IState - */ - -const customProviders = new Set([ - Providers.XAI, - Providers.OLLAMA, - Providers.DEEPSEEK, - Providers.OPENROUTER, -]); - -/** - * Creates a new Run instance with custom handlers and configuration. - * - * @param {Object} options - The options for creating the Run instance. - * @param {ServerRequest} [options.req] - The server request. - * @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated. - * @param {Agent} options.agent - The agent for this run. - * @param {AbortSignal} options.signal - The signal for this run. - * @param {Record | undefined} [options.customHandlers] - Custom event handlers. - * @param {boolean} [options.streaming=true] - Whether to use streaming. - * @param {boolean} [options.streamUsage=true] - Whether to stream usage information. - * @returns {Promise>} A promise that resolves to a new Run instance. - */ -async function createRun({ - runId, - agent, - signal, - customHandlers, - streaming = true, - streamUsage = true, -}) { - const provider = providerEndpointMap[agent.provider] ?? 
agent.provider; - /** @type {LLMConfig} */ - const llmConfig = Object.assign( - { - provider, - streaming, - streamUsage, - }, - agent.model_parameters, - ); - - /** Resolves issues with new OpenAI usage field */ - if ( - customProviders.has(agent.provider) || - (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider) - ) { - llmConfig.streamUsage = false; - llmConfig.usage = true; - } - - /** @type {'reasoning_content' | 'reasoning'} */ - let reasoningKey; - if ( - llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) || - (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) - ) { - reasoningKey = 'reasoning'; - } - - /** @type {StandardGraphConfig} */ - const graphConfig = { - signal, - llmConfig, - reasoningKey, - tools: agent.tools, - instructions: agent.instructions, - additional_instructions: agent.additional_instructions, - // toolEnd: agent.end_after_tools, - }; - - // TEMPORARY FOR TESTING - if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) { - graphConfig.streamBuffer = 2000; - } - - return Run.create({ - runId, - graphConfig, - customHandlers, - }); -} - -module.exports = { createRun }; diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 244b6e8e23..764a2e05d4 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -1,9 +1,9 @@ const fs = require('fs').promises; const { nanoid } = require('nanoid'); +const { logger } = require('@librechat/data-schemas'); const { Tools, Constants, - FileContext, FileSources, SystemRoles, EToolResources, @@ -16,15 +16,16 @@ const { deleteAgent, getListAgents, } = require('~/models/Agent'); -const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const { 
refreshS3Url } = require('~/server/services/Files/S3/crud'); +const { filterFile } = require('~/server/services/Files/process'); const { updateAction, getActions } = require('~/models/Action'); +const { getCachedTools } = require('~/server/services/Config'); const { updateAgentProjects } = require('~/models/Agent'); const { getProjectByName } = require('~/models/Project'); -const { deleteFileByFilter } = require('~/models/File'); const { revertAgentVersion } = require('~/models/Agent'); -const { logger } = require('~/config'); +const { deleteFileByFilter } = require('~/models/File'); const systemTools = { [Tools.execute_code]: true, @@ -46,8 +47,9 @@ const createAgentHandler = async (req, res) => { agentData.tools = []; + const availableTools = await getCachedTools({ includeGlobal: true }); for (const tool of tools) { - if (req.app.locals.availableTools[tool]) { + if (availableTools[tool]) { agentData.tools.push(tool); } @@ -168,12 +170,18 @@ const updateAgentHandler = async (req, res) => { }); } + /** @type {boolean} */ + const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0; + let updatedAgent = Object.keys(updateData).length > 0 - ? await updateAgent({ id }, updateData, { updatingUserId: req.user.id }) + ? 
await updateAgent({ id }, updateData, { + updatingUserId: req.user.id, + skipVersioning: isProjectUpdate, + }) : existingAgent; - if (projectIds || removeProjectIds) { + if (isProjectUpdate) { updatedAgent = await updateAgentProjects({ user: req.user, agentId: id, @@ -234,6 +242,8 @@ const duplicateAgentHandler = async (req, res) => { createdAt: _createdAt, updatedAt: _updatedAt, tool_resources: _tool_resources = {}, + versions: _versions, + __v: _v, ...cloneData } = agent; cloneData.name = `${agent.name} (${new Date().toLocaleString('en-US', { @@ -373,12 +383,27 @@ const uploadAgentAvatarHandler = async (req, res) => { } const buffer = await fs.readFile(req.file.path); - const image = await uploadImageBuffer({ - req, - context: FileContext.avatar, - metadata: { buffer }, + + const fileStrategy = req.app.locals.fileStrategy; + + const resizedBuffer = await resizeAvatar({ + userId: req.user.id, + input: buffer, }); + const { processAvatar } = getStrategyFunctions(fileStrategy); + const avatarUrl = await processAvatar({ + buffer: resizedBuffer, + userId: req.user.id, + manual: 'false', + agentId: agent_id, + }); + + const image = { + filepath: avatarUrl, + source: fileStrategy, + }; + let _avatar; try { const agent = await getAgent({ id: agent_id }); @@ -403,7 +428,7 @@ const uploadAgentAvatarHandler = async (req, res) => { const data = { avatar: { filepath: image.filepath, - source: req.app.locals.fileStrategy, + source: image.source, }, }; @@ -423,7 +448,7 @@ const uploadAgentAvatarHandler = async (req, res) => { try { await fs.unlink(req.file.path); logger.debug('[/:agent_id/avatar] Temp. image upload file deleted'); - } catch (error) { + } catch { logger.debug('[/:agent_id/avatar] Temp. 
image upload file already deleted'); } } diff --git a/api/server/controllers/assistants/chatV1.js b/api/server/controllers/assistants/chatV1.js index 9129a6a1c1..b4fe0d9013 100644 --- a/api/server/controllers/assistants/chatV1.js +++ b/api/server/controllers/assistants/chatV1.js @@ -1,4 +1,7 @@ const { v4 } = require('uuid'); +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Time, Constants, @@ -19,20 +22,20 @@ const { addThreadMetadata, saveAssistantMessage, } = require('~/server/services/Threads'); -const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts'); const { createRun, StreamRunManager } = require('~/server/services/Runs'); const { addTitle } = require('~/server/services/Endpoints/assistants'); const { createRunBody } = require('~/server/services/createRunBody'); +const { sendResponse } = require('~/server/middleware/error'); const { getTransactions } = require('~/models/Transaction'); const { checkBalance } = require('~/models/balanceMethods'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); +const { countTokens } = require('~/server/utils'); const { getModelMaxTokens } = require('~/utils'); const { getOpenAIClient } = require('./helpers'); -const { logger } = require('~/config'); /** * @route POST / @@ -471,7 +474,7 @@ const chatV1 = async (req, res) => { await Promise.all(promises); const sendInitialResponse = () => { - sendMessage(res, { + sendEvent(res, { sync: true, conversationId, // messages: previousMessages, @@ -587,7 +590,7 @@ const chatV1 = async (req, res) => { iconURL: endpointOption.iconURL, }; 
- sendMessage(res, { + sendEvent(res, { final: true, conversation, requestMessage: { diff --git a/api/server/controllers/assistants/chatV2.js b/api/server/controllers/assistants/chatV2.js index 309e5a86c4..e1ba93bc21 100644 --- a/api/server/controllers/assistants/chatV2.js +++ b/api/server/controllers/assistants/chatV2.js @@ -1,4 +1,7 @@ const { v4 } = require('uuid'); +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Time, Constants, @@ -22,15 +25,14 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors') const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { createRun, StreamRunManager } = require('~/server/services/Runs'); const { addTitle } = require('~/server/services/Endpoints/assistants'); -const { sendMessage, sleep, countTokens } = require('~/server/utils'); const { createRunBody } = require('~/server/services/createRunBody'); const { getTransactions } = require('~/models/Transaction'); const { checkBalance } = require('~/models/balanceMethods'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); +const { countTokens } = require('~/server/utils'); const { getModelMaxTokens } = require('~/utils'); const { getOpenAIClient } = require('./helpers'); -const { logger } = require('~/config'); /** * @route POST / @@ -309,7 +311,7 @@ const chatV2 = async (req, res) => { await Promise.all(promises); const sendInitialResponse = () => { - sendMessage(res, { + sendEvent(res, { sync: true, conversationId, // messages: previousMessages, @@ -432,7 +434,7 @@ const chatV2 = async (req, res) => { iconURL: endpointOption.iconURL, }; - sendMessage(res, { + sendEvent(res, { final: true, conversation, requestMessage: { diff --git a/api/server/controllers/assistants/errors.js b/api/server/controllers/assistants/errors.js index a4b880bf04..182b230fba 
100644 --- a/api/server/controllers/assistants/errors.js +++ b/api/server/controllers/assistants/errors.js @@ -1,10 +1,10 @@ // errorHandler.js -const { sendResponse } = require('~/server/utils'); -const { logger } = require('~/config'); -const getLogStores = require('~/cache/getLogStores'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider'); -const { getConvo } = require('~/models/Conversation'); const { recordUsage, checkMessageGaps } = require('~/server/services/Threads'); +const { sendResponse } = require('~/server/middleware/error'); +const { getConvo } = require('~/models/Conversation'); +const getLogStores = require('~/cache/getLogStores'); /** * @typedef {Object} ErrorHandlerContext @@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch } else if (/Files.*are invalid/.test(error.message)) { const errorMessage = `Files are invalid, or may not have uploaded yet.${ endpoint === 'azureAssistants' - ? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.' + ? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload." 
: '' }`; return sendResponse(req, res, messageData, errorMessage); diff --git a/api/server/controllers/assistants/v1.js b/api/server/controllers/assistants/v1.js index 8fb73167c1..e723cda4fc 100644 --- a/api/server/controllers/assistants/v1.js +++ b/api/server/controllers/assistants/v1.js @@ -1,4 +1,5 @@ const fs = require('fs').promises; +const { logger } = require('@librechat/data-schemas'); const { FileContext } = require('librechat-data-provider'); const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); @@ -6,9 +7,9 @@ const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { deleteAssistantActions } = require('~/server/services/ActionService'); const { updateAssistantDoc, getAssistants } = require('~/models/Assistant'); const { getOpenAIClient, fetchAssistants } = require('./helpers'); +const { getCachedTools } = require('~/server/services/Config'); const { manifestToolMap } = require('~/app/clients/tools'); const { deleteFileByFilter } = require('~/models/File'); -const { logger } = require('~/config'); /** * Create an assistant. 
@@ -30,21 +31,20 @@ const createAssistant = async (req, res) => { delete assistantData.conversation_starters; delete assistantData.append_current_datetime; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); + assistantData.tools = tools .map((tool) => { if (typeof tool !== 'string') { return tool; } - const toolDefinitions = req.app.locals.availableTools; const toolDef = toolDefinitions[tool]; if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { - return ( - Object.entries(toolDefinitions) - .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars - .map(([_, val]) => val) - ); + return Object.entries(toolDefinitions) + .filter(([key]) => key.startsWith(`${tool}_`)) + + .map(([_, val]) => val); } return toolDef; @@ -135,21 +135,21 @@ const patchAssistant = async (req, res) => { append_current_datetime, ...updateData } = req.body; + + const toolDefinitions = await getCachedTools({ includeGlobal: true }); + updateData.tools = (updateData.tools ?? 
[]) .map((tool) => { if (typeof tool !== 'string') { return tool; } - const toolDefinitions = req.app.locals.availableTools; const toolDef = toolDefinitions[tool]; if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { - return ( - Object.entries(toolDefinitions) - .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars - .map(([_, val]) => val) - ); + return Object.entries(toolDefinitions) + .filter(([key]) => key.startsWith(`${tool}_`)) + + .map(([_, val]) => val); } return toolDef; diff --git a/api/server/controllers/assistants/v2.js b/api/server/controllers/assistants/v2.js index 3bf83a626f..98441ba70a 100644 --- a/api/server/controllers/assistants/v2.js +++ b/api/server/controllers/assistants/v2.js @@ -1,10 +1,11 @@ +const { logger } = require('@librechat/data-schemas'); const { ToolCallTypes } = require('librechat-data-provider'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { validateAndUpdateTool } = require('~/server/services/ActionService'); +const { getCachedTools } = require('~/server/services/Config'); const { updateAssistantDoc } = require('~/models/Assistant'); const { manifestToolMap } = require('~/app/clients/tools'); const { getOpenAIClient } = require('./helpers'); -const { logger } = require('~/config'); /** * Create an assistant. 
@@ -27,21 +28,20 @@ const createAssistant = async (req, res) => { delete assistantData.conversation_starters; delete assistantData.append_current_datetime; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); + assistantData.tools = tools .map((tool) => { if (typeof tool !== 'string') { return tool; } - const toolDefinitions = req.app.locals.availableTools; const toolDef = toolDefinitions[tool]; if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { - return ( - Object.entries(toolDefinitions) - .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars - .map(([_, val]) => val) - ); + return Object.entries(toolDefinitions) + .filter(([key]) => key.startsWith(`${tool}_`)) + + .map(([_, val]) => val); } return toolDef; @@ -125,13 +125,13 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => { let hasFileSearch = false; for (const tool of updateData.tools ?? []) { - const toolDefinitions = req.app.locals.availableTools; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); let actualTool = typeof tool === 'string' ? 
toolDefinitions[tool] : tool; if (!actualTool && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { actualTool = Object.entries(toolDefinitions) .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars + .map(([_, val]) => val); } else if (!actualTool) { continue; diff --git a/api/server/controllers/auth/TwoFactorAuthController.js b/api/server/controllers/auth/TwoFactorAuthController.js index 15cde8122a..b37c89a998 100644 --- a/api/server/controllers/auth/TwoFactorAuthController.js +++ b/api/server/controllers/auth/TwoFactorAuthController.js @@ -1,12 +1,12 @@ const jwt = require('jsonwebtoken'); +const { logger } = require('@librechat/data-schemas'); const { verifyTOTP, - verifyBackupCode, getTOTPSecret, + verifyBackupCode, } = require('~/server/services/twoFactorService'); const { setAuthTokens } = require('~/server/services/AuthService'); -const { getUserById } = require('~/models/userMethods'); -const { logger } = require('~/config'); +const { getUserById } = require('~/models'); /** * Verifies the 2FA code during login using a temporary token. 
diff --git a/api/server/controllers/tools.js b/api/server/controllers/tools.js index 254ecb4f94..8d5d2e9ce6 100644 --- a/api/server/controllers/tools.js +++ b/api/server/controllers/tools.js @@ -1,5 +1,7 @@ const { nanoid } = require('nanoid'); const { EnvVar } = require('@librechat/agents'); +const { checkAccess } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Tools, AuthType, @@ -13,9 +15,8 @@ const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { loadTools } = require('~/app/clients/tools/util'); -const { checkAccess } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); const { getMessage } = require('~/models/Message'); -const { logger } = require('~/config'); const fieldsMap = { [Tools.execute_code]: [EnvVar.CODE_API_KEY], @@ -79,6 +80,7 @@ const verifyToolAuth = async (req, res) => { throwError: false, }); } catch (error) { + logger.error('Error loading auth values', error); res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED }); return; } @@ -132,7 +134,12 @@ const callTool = async (req, res) => { logger.debug(`[${toolId}/call] User: ${req.user.id}`); let hasAccess = true; if (toolAccessPermType[toolId]) { - hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]); + hasAccess = await checkAccess({ + user: req.user, + permissionType: toolAccessPermType[toolId], + permissions: [Permissions.USE], + getRoleByName, + }); } if (!hasAccess) { logger.warn( diff --git a/api/server/index.js b/api/server/index.js index c7525f9b91..ac79a627e9 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -1,21 +1,22 @@ require('dotenv').config(); +const fs = require('fs'); const path = require('path'); require('module-alias')({ base: 
path.resolve(__dirname, '..') }); const cors = require('cors'); const axios = require('axios'); const express = require('express'); -const compression = require('compression'); const passport = require('passport'); -const mongoSanitize = require('express-mongo-sanitize'); -const fs = require('fs'); +const compression = require('compression'); const cookieParser = require('cookie-parser'); -const { jwtLogin, passportLogin } = require('~/strategies'); -const { connectDb, indexSync } = require('~/lib/db'); -const { isEnabled } = require('~/server/utils'); -const { ldapLogin } = require('~/strategies'); -const { logger } = require('~/config'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const mongoSanitize = require('express-mongo-sanitize'); +const { connectDb, indexSync } = require('~/db'); + const validateImageRequest = require('./middleware/validateImageRequest'); +const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies'); const errorController = require('./controllers/ErrorController'); +const initializeMCP = require('./services/initializeMCP'); const configureSocialLogins = require('./socialLogins'); const AppService = require('./services/AppService'); const staticCache = require('./utils/staticCache'); @@ -36,8 +37,11 @@ const startServer = async () => { axios.defaults.headers.common['Accept-Encoding'] = 'gzip'; } await connectDb(); + logger.info('Connected to MongoDB'); - await indexSync(); + indexSync().catch((err) => { + logger.error('[indexSync] Background sync failed:', err); + }); app.disable('x-powered-by'); app.set('trust proxy', trusted_proxy); @@ -93,7 +97,6 @@ const startServer = async () => { app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); app.use('/api/user', routes.user); - app.use('/api/ask', routes.ask); app.use('/api/search', routes.search); app.use('/api/edit', routes.edit); app.use('/api/messages', routes.messages); @@ -114,9 +117,9 @@ const 
startServer = async () => { app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); - app.use('/api/bedrock', routes.bedrock); - + app.use('/api/memories', routes.memories); app.use('/api/tags', routes.tags); + app.use('/api/mcp', routes.mcp); app.use((req, res) => { res.set({ @@ -140,6 +143,8 @@ const startServer = async () => { } else { logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`); } + + initializeMCP(app); }); }; @@ -182,5 +187,5 @@ process.on('uncaughtException', (err) => { process.exit(1); }); -// export app for easier testing purposes +/** Export app for easier testing purposes */ module.exports = app; diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index bfc28f513d..c5fc48780c 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,13 +1,13 @@ -// abortMiddleware.js +const { logger } = require('@librechat/data-schemas'); +const { countTokens, isEnabled, sendEvent } = require('@librechat/api'); const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider'); -const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); const clearPendingReq = require('~/cache/clearPendingReq'); +const { sendError } = require('~/server/middleware/error'); const { spendTokens } = require('~/models/spendTokens'); const abortControllers = require('./abortControllers'); const { saveMessage, getConvo } = require('~/models'); const { abortRun } = require('./abortRun'); -const { logger } = require('~/config'); const abortDataMap = new WeakMap(); @@ -101,7 +101,7 @@ async function abortMessage(req, res) { cleanupAbortController(abortKey); if (res.headersSent && finalEvent) { - return sendMessage(res, finalEvent); + return sendEvent(res, finalEvent); } 
res.setHeader('Content-Type', 'application/json'); @@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => { * @param {string} responseMessageId */ const onStart = (userMessage, responseMessageId) => { - sendMessage(res, { message: userMessage, created: true }); + sendEvent(res, { message: userMessage, created: true }); const abortKey = userMessage?.conversationId ?? req.user.id; getReqData({ abortKey }); @@ -327,7 +327,7 @@ const handleAbortError = async (res, req, error, data) => { errorText = `{"type":"${ErrorTypes.INVALID_REQUEST}"}`; } - if (error?.message?.includes('does not support \'system\'')) { + if (error?.message?.includes("does not support 'system'")) { errorText = `{"type":"${ErrorTypes.NO_SYSTEM_MESSAGES}"}`; } diff --git a/api/server/middleware/abortRun.js b/api/server/middleware/abortRun.js index 01b34aacc2..2846c6eefc 100644 --- a/api/server/middleware/abortRun.js +++ b/api/server/middleware/abortRun.js @@ -1,11 +1,11 @@ +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider'); const { initializeClient } = require('~/server/services/Endpoints/assistants'); const { checkMessageGaps, recordUsage } = require('~/server/services/Threads'); const { deleteMessages } = require('~/models/Message'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); -const { sendMessage } = require('~/server/utils'); -const { logger } = require('~/config'); const three_minutes = 1000 * 60 * 3; @@ -34,7 +34,7 @@ async function abortRun(req, res) { const [thread_id, run_id] = runValues.split(':'); if (!run_id) { - logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id }); + logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id }); return res.status(204).send({ message: 'Run not found' }); } else if (run_id === 
'cancelled') { logger.warn('[abortRun] Run already cancelled', { thread_id }); @@ -93,7 +93,7 @@ async function abortRun(req, res) { }; if (res.headersSent && finalEvent) { - return sendMessage(res, finalEvent); + return sendEvent(res, finalEvent); } res.json(finalEvent); diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index 8394223b5e..d302bf8743 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -1,13 +1,12 @@ +const { logger } = require('@librechat/data-schemas'); const { - parseCompactConvo, + EndpointURLs, EModelEndpoint, isAgentsEndpoint, - EndpointURLs, + parseCompactConvo, } = require('librechat-data-provider'); const azureAssistants = require('~/server/services/Endpoints/azureAssistants'); -const { getModelsConfig } = require('~/server/controllers/ModelController'); const assistants = require('~/server/services/Endpoints/assistants'); -const gptPlugins = require('~/server/services/Endpoints/gptPlugins'); const { processFiles } = require('~/server/services/Files/process'); const anthropic = require('~/server/services/Endpoints/anthropic'); const bedrock = require('~/server/services/Endpoints/bedrock'); @@ -25,7 +24,6 @@ const buildFunction = { [EModelEndpoint.bedrock]: bedrock.buildOptions, [EModelEndpoint.azureOpenAI]: openAI.buildOptions, [EModelEndpoint.anthropic]: anthropic.buildOptions, - [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions, [EModelEndpoint.assistants]: assistants.buildOptions, [EModelEndpoint.azureAssistants]: azureAssistants.buildOptions, }; @@ -36,6 +34,9 @@ async function buildEndpointOption(req, res, next) { try { parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body }); } catch (error) { + logger.warn( + `Error parsing conversation for endpoint ${endpoint}${error?.message ? 
`: ${error.message}` : ''}`, + ); return handleError(res, { text: 'Error parsing conversation' }); } @@ -57,15 +58,6 @@ async function buildEndpointOption(req, res, next) { return handleError(res, { text: 'Model spec mismatch' }); } - if ( - currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins && - currentModelSpec.preset.tools - ) { - return handleError(res, { - text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`, - }); - } - try { currentModelSpec.preset.spec = spec; if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') { @@ -77,6 +69,7 @@ async function buildEndpointOption(req, res, next) { conversation: currentModelSpec.preset, }); } catch (error) { + logger.error(`Error parsing model spec for endpoint ${endpoint}`, error); return handleError(res, { text: 'Error parsing model spec' }); } } @@ -84,20 +77,23 @@ async function buildEndpointOption(req, res, next) { try { const isAgents = isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]); - const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)]; - const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn; + const builder = isAgents + ? (...args) => buildFunction[EModelEndpoint.agents](req, ...args) + : buildFunction[endpointType ?? 
endpoint]; // TODO: use object params req.body.endpointOption = await builder(endpoint, parsedBody, endpointType); - // TODO: use `getModelsConfig` only when necessary - const modelsConfig = await getModelsConfig(req); - req.body.endpointOption.modelsConfig = modelsConfig; if (req.body.files && !isAgents) { req.body.endpointOption.attachments = processFiles(req.body.files); } + next(); } catch (error) { + logger.error( + `Error building endpoint option for endpoint ${endpoint} with type ${endpointType}`, + error, + ); return handleError(res, { text: 'Error building endpoint option' }); } } diff --git a/api/server/middleware/checkBan.js b/api/server/middleware/checkBan.js index 4e0593192a..91c31ab66a 100644 --- a/api/server/middleware/checkBan.js +++ b/api/server/middleware/checkBan.js @@ -1,12 +1,12 @@ const { Keyv } = require('keyv'); const uap = require('ua-parser-js'); +const { logger } = require('@librechat/data-schemas'); const { ViolationTypes } = require('librechat-data-provider'); const { isEnabled, removePorts } = require('~/server/utils'); const keyvMongo = require('~/cache/keyvMongo'); const denyRequest = require('./denyRequest'); const { getLogStores } = require('~/cache'); const { findUser } = require('~/models'); -const { logger } = require('~/config'); const banCache = new Keyv({ store: keyvMongo, namespace: ViolationTypes.BAN, ttl: 0 }); const message = 'Your account has been temporarily banned due to violations of our service.'; diff --git a/api/server/middleware/checkInviteUser.js b/api/server/middleware/checkInviteUser.js index e1ad271b55..42e1faba5b 100644 --- a/api/server/middleware/checkInviteUser.js +++ b/api/server/middleware/checkInviteUser.js @@ -1,5 +1,5 @@ const { getInvite } = require('~/models/inviteUser'); -const { deleteTokens } = require('~/models/Token'); +const { deleteTokens } = require('~/models'); async function checkInviteUser(req, res, next) { const token = req.body.token; diff --git a/api/server/middleware/denyRequest.js 
b/api/server/middleware/denyRequest.js index 62efb1aeaf..20360519cf 100644 --- a/api/server/middleware/denyRequest.js +++ b/api/server/middleware/denyRequest.js @@ -1,6 +1,7 @@ const crypto = require('crypto'); +const { sendEvent } = require('@librechat/api'); const { getResponseSender, Constants } = require('librechat-data-provider'); -const { sendMessage, sendError } = require('~/server/utils'); +const { sendError } = require('~/server/middleware/error'); const { saveMessage } = require('~/models'); /** @@ -36,7 +37,7 @@ const denyRequest = async (req, res, errorMessage) => { isCreatedByUser: true, text, }; - sendMessage(res, { message: userMessage, created: true }); + sendEvent(res, { message: userMessage, created: true }); const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT; diff --git a/api/server/utils/streamResponse.js b/api/server/middleware/error.js similarity index 76% rename from api/server/utils/streamResponse.js rename to api/server/middleware/error.js index bb8d63b229..db445c1d43 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/middleware/error.js @@ -1,31 +1,9 @@ const crypto = require('crypto'); +const { logger } = require('@librechat/data-schemas'); const { parseConvo } = require('librechat-data-provider'); +const { sendEvent, handleError } = require('@librechat/api'); const { saveMessage, getMessages } = require('~/models/Message'); const { getConvo } = require('~/models/Conversation'); -const { logger } = require('~/config'); - -/** - * Sends error data in Server Sent Events format and ends the response. - * @param {object} res - The server response. - * @param {string} message - The error message. - */ -const handleError = (res, message) => { - res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`); - res.end(); -}; - -/** - * Sends message data in Server Sent Events format. - * @param {Express.Response} res - - The server response. 
- * @param {string | Object} message - The message to be sent. - * @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'. - */ -const sendMessage = (res, message, event = 'message') => { - if (typeof message === 'string' && message.length === 0) { - return; - } - res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`); -}; /** * Processes an error with provided options, saves the error message and sends a corresponding SSE response @@ -91,7 +69,7 @@ const sendError = async (req, res, options, callback) => { convo = parseConvo(errorMessage); } - return sendMessage(res, { + return sendEvent(res, { final: true, requestMessage: query?.[0] ? query[0] : requestMessage, responseMessage: errorMessage, @@ -120,12 +98,10 @@ const sendResponse = (req, res, data, errorMessage) => { if (errorMessage) { return sendError(req, res, { ...data, text: errorMessage }); } - return sendMessage(res, data); + return sendEvent(res, data); }; module.exports = { - sendResponse, - handleError, - sendMessage, sendError, + sendResponse, }; diff --git a/api/server/middleware/roles/checkAdmin.js b/api/server/middleware/roles/admin.js similarity index 100% rename from api/server/middleware/roles/checkAdmin.js rename to api/server/middleware/roles/admin.js diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/generateCheckAccess.js deleted file mode 100644 index cabbd405b0..0000000000 --- a/api/server/middleware/roles/generateCheckAccess.js +++ /dev/null @@ -1,78 +0,0 @@ -const { getRoleByName } = require('~/models/Role'); -const { logger } = require('~/config'); - -/** - * Core function to check if a user has one or more required permissions - * - * @param {object} user - The user object - * @param {PermissionTypes} permissionType - The type of permission to check - * @param {Permissions[]} permissions - The list of specific permissions to check - * @param {Record} [bodyProps] - An optional object where 
keys are permissions and values are arrays of properties to check - * @param {object} [checkObject] - The object to check properties against - * @returns {Promise} Whether the user has the required permissions - */ -const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => { - if (!user) { - return false; - } - - const role = await getRoleByName(user.role); - if (role && role.permissions && role.permissions[permissionType]) { - const hasAnyPermission = permissions.some((permission) => { - if (role.permissions[permissionType][permission]) { - return true; - } - - if (bodyProps[permission] && checkObject) { - return bodyProps[permission].some((prop) => - Object.prototype.hasOwnProperty.call(checkObject, prop), - ); - } - - return false; - }); - - return hasAnyPermission; - } - - return false; -}; - -/** - * Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties. - * - * @param {PermissionTypes} permissionType - The type of permission to check. - * @param {Permissions[]} permissions - The list of specific permissions to check. - * @param {Record} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check. - * @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise} Express middleware function. 
- */ -const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => { - return async (req, res, next) => { - try { - const hasAccess = await checkAccess( - req.user, - permissionType, - permissions, - bodyProps, - req.body, - ); - - if (hasAccess) { - return next(); - } - - logger.warn( - `[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`, - ); - return res.status(403).json({ message: 'Forbidden: Insufficient permissions' }); - } catch (error) { - logger.error(error); - return res.status(500).json({ message: `Server error: ${error.message}` }); - } - }; -}; - -module.exports = { - checkAccess, - generateCheckAccess, -}; diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js index a9fc5b2a08..f01b884e5a 100644 --- a/api/server/middleware/roles/index.js +++ b/api/server/middleware/roles/index.js @@ -1,8 +1,5 @@ -const checkAdmin = require('./checkAdmin'); -const { checkAccess, generateCheckAccess } = require('./generateCheckAccess'); +const checkAdmin = require('./admin'); module.exports = { checkAdmin, - checkAccess, - generateCheckAccess, }; diff --git a/api/server/middleware/setBalanceConfig.js b/api/server/middleware/setBalanceConfig.js index 98d3cf1145..5dd9757965 100644 --- a/api/server/middleware/setBalanceConfig.js +++ b/api/server/middleware/setBalanceConfig.js @@ -1,6 +1,6 @@ +const { logger } = require('@librechat/data-schemas'); const { getBalanceConfig } = require('~/server/services/Config'); -const Balance = require('~/models/Balance'); -const { logger } = require('~/config'); +const { Balance } = require('~/db/models'); /** * Middleware to synchronize user balance settings with current balance configuration. 
diff --git a/api/server/middleware/validate/convoAccess.js b/api/server/middleware/validate/convoAccess.js index 43cca0097d..afd2aeacef 100644 --- a/api/server/middleware/validate/convoAccess.js +++ b/api/server/middleware/validate/convoAccess.js @@ -1,8 +1,8 @@ +const { isEnabled } = require('@librechat/api'); const { Constants, ViolationTypes, Time } = require('librechat-data-provider'); const { searchConversation } = require('~/models/Conversation'); const denyRequest = require('~/server/middleware/denyRequest'); const { logViolation, getLogStores } = require('~/cache'); -const { isEnabled } = require('~/server/utils'); const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {}; diff --git a/api/server/routes/actions.js b/api/server/routes/actions.js index dc474d1a67..9f94f617ce 100644 --- a/api/server/routes/actions.js +++ b/api/server/routes/actions.js @@ -1,8 +1,10 @@ const express = require('express'); const jwt = require('jsonwebtoken'); +const { getAccessToken } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); -const { getAccessToken } = require('~/server/services/TokenService'); -const { logger, getFlowStateManager } = require('~/config'); +const { findToken, updateToken, createToken } = require('~/models'); +const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); const router = express.Router(); @@ -28,18 +30,19 @@ router.get('/:action_id/oauth/callback', async (req, res) => { try { decodedState = jwt.verify(state, JWT_SECRET); } catch (err) { + logger.error('Error verifying state parameter:', err); await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter'); - return res.status(400).send('Invalid or expired state parameter'); + return res.redirect('/oauth/error?error=invalid_state'); } if (decodedState.action_id !== action_id) { await flowManager.failFlow(identifier, 
'oauth', 'Mismatched action ID in state parameter'); - return res.status(400).send('Mismatched action ID in state parameter'); + return res.redirect('/oauth/error?error=invalid_state'); } if (!decodedState.user) { await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter'); - return res.status(400).send('Invalid user ID in state parameter'); + return res.redirect('/oauth/error?error=invalid_state'); } identifier = `${decodedState.user}:${action_id}`; const flowState = await flowManager.getFlowState(identifier, 'oauth'); @@ -47,90 +50,34 @@ router.get('/:action_id/oauth/callback', async (req, res) => { throw new Error('OAuth flow not found'); } - const tokenData = await getAccessToken({ - code, - userId: decodedState.user, - identifier, - client_url: flowState.metadata.client_url, - redirect_uri: flowState.metadata.redirect_uri, - /** Encrypted values */ - encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id, - encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret, - }); + const tokenData = await getAccessToken( + { + code, + userId: decodedState.user, + identifier, + client_url: flowState.metadata.client_url, + redirect_uri: flowState.metadata.redirect_uri, + token_exchange_method: flowState.metadata.token_exchange_method, + /** Encrypted values */ + encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id, + encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret, + }, + { + findToken, + updateToken, + createToken, + }, + ); await flowManager.completeFlow(identifier, 'oauth', tokenData); - res.send(` - - - - Authentication Successful - - - - - -
-

Authentication Successful

-

- Your authentication was successful. This window will close in - 3 seconds. -

-
- - - - `); + + /** Redirect to React success page */ + const serverName = flowState.metadata?.action_name || `Action ${action_id}`; + const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`; + res.redirect(redirectUrl); } catch (error) { logger.error('Error in OAuth callback:', error); await flowManager.failFlow(identifier, 'oauth', error); - res.status(500).send('Authentication failed. Please try again.'); + res.redirect('/oauth/error?error=callback_failed'); } }); diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index 89d6a9dc42..2f11486a0e 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -1,14 +1,28 @@ const express = require('express'); const { nanoid } = require('nanoid'); -const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); +const { generateCheckAccess } = require('@librechat/api'); +const { + SystemRoles, + Permissions, + PermissionTypes, + actionDelimiter, + removeNullishValues, +} = require('librechat-data-provider'); const { encryptMetadata, domainParser } = require('~/server/services/ActionService'); const { updateAction, getActions, deleteAction } = require('~/models/Action'); const { isActionDomainAllowed } = require('~/server/services/domains'); const { getAgent, updateAgent } = require('~/models/Agent'); -const { logger } = require('~/config'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); +const checkAgentCreate = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); + // If the user has ADMIN role // then action edition is possible even if not owner of the assistant const isAdmin = (req) => { @@ -41,7 +55,7 @@ router.get('/', async (req, res) => { * @param {ActionMetadata} req.body.metadata - Metadata for the action. 
* @returns {Object} 200 - success response - application/json */ -router.post('/:agent_id', async (req, res) => { +router.post('/:agent_id', checkAgentCreate, async (req, res) => { try { const { agent_id } = req.params; @@ -149,7 +163,7 @@ router.post('/:agent_id', async (req, res) => { * @param {string} req.params.action_id - The ID of the action to delete. * @returns {Object} 200 - success response - application/json */ -router.delete('/:agent_id/:action_id', async (req, res) => { +router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => { try { const { agent_id, action_id } = req.params; const admin = isAdmin(req); diff --git a/api/server/routes/agents/chat.js b/api/server/routes/agents/chat.js index ef66ef7896..0e07c83bd1 100644 --- a/api/server/routes/agents/chat.js +++ b/api/server/routes/agents/chat.js @@ -1,22 +1,28 @@ const express = require('express'); +const { generateCheckAccess, skipAgentCheck } = require('@librechat/api'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { setHeaders, moderateText, // validateModel, - generateCheckAccess, validateConvoAccess, buildEndpointOption, } = require('~/server/middleware'); const { initializeClient } = require('~/server/services/Endpoints/agents'); const AgentController = require('~/server/controllers/agents/request'); const addTitle = require('~/server/services/Endpoints/agents/title'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); router.use(moderateText); -const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); +const checkAgentAccess = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE], + skipCheck: skipAgentCheck, + getRoleByName, +}); router.use(checkAgentAccess); router.use(validateConvoAccess); diff --git a/api/server/routes/agents/v1.js b/api/server/routes/agents/v1.js index 657aa79414..0455b23948 100644 --- 
a/api/server/routes/agents/v1.js +++ b/api/server/routes/agents/v1.js @@ -1,29 +1,36 @@ const express = require('express'); +const { generateCheckAccess } = require('@librechat/api'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); -const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); +const { requireJwtAuth } = require('~/server/middleware'); const v1 = require('~/server/controllers/agents/v1'); +const { getRoleByName } = require('~/models/Role'); const actions = require('./actions'); const tools = require('./tools'); const router = express.Router(); const avatar = express.Router(); -const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); -const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [ - Permissions.USE, - Permissions.CREATE, -]); +const checkAgentAccess = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); +const checkAgentCreate = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); -const checkGlobalAgentShare = generateCheckAccess( - PermissionTypes.AGENTS, - [Permissions.USE, Permissions.CREATE], - { +const checkGlobalAgentShare = generateCheckAccess({ + permissionType: PermissionTypes.AGENTS, + permissions: [Permissions.USE, Permissions.CREATE], + bodyProps: { [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], }, -); + getRoleByName, +}); router.use(requireJwtAuth); -router.use(checkAgentAccess); /** * Agent actions route. 
diff --git a/api/server/routes/ask/addToCache.js b/api/server/routes/ask/addToCache.js deleted file mode 100644 index a2f427098f..0000000000 --- a/api/server/routes/ask/addToCache.js +++ /dev/null @@ -1,63 +0,0 @@ -const { Keyv } = require('keyv'); -const { KeyvFile } = require('keyv-file'); -const { logger } = require('~/config'); - -const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => { - try { - const conversationsCache = new Keyv({ - store: new KeyvFile({ filename: './data/cache.json' }), - namespace: 'chatgpt', // should be 'bing' for bing/sydney - }); - - const { - conversationId, - messageId: userMessageId, - parentMessageId: userParentMessageId, - text: userText, - } = userMessage; - const { - messageId: responseMessageId, - parentMessageId: responseParentMessageId, - text: responseText, - } = responseMessage; - - let conversation = await conversationsCache.get(conversationId); - // used to generate a title for the conversation if none exists - // let isNewConversation = false; - if (!conversation) { - conversation = { - messages: [], - createdAt: Date.now(), - }; - // isNewConversation = true; - } - - const roles = (options) => { - if (endpoint === 'openAI') { - return options?.chatGptLabel || 'ChatGPT'; - } - }; - - let _userMessage = { - id: userMessageId, - parentMessageId: userParentMessageId, - role: 'User', - message: userText, - }; - - let _responseMessage = { - id: responseMessageId, - parentMessageId: responseParentMessageId, - role: roles(endpointOption), - message: responseText, - }; - - conversation.messages.push(_userMessage, _responseMessage); - - await conversationsCache.set(conversationId, conversation); - } catch (error) { - logger.error('[addToCache] Error adding conversation to cache', error); - } -}; - -module.exports = addToCache; diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js deleted file mode 100644 index afe1720d84..0000000000 --- 
a/api/server/routes/ask/anthropic.js +++ /dev/null @@ -1,25 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { addTitle, initializeClient } = require('~/server/services/Endpoints/anthropic'); -const { - setHeaders, - handleAbort, - validateModel, - validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); - -const router = express.Router(); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/custom.js b/api/server/routes/ask/custom.js deleted file mode 100644 index 8fc343cf17..0000000000 --- a/api/server/routes/ask/custom.js +++ /dev/null @@ -1,25 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { initializeClient } = require('~/server/services/Endpoints/custom'); -const { addTitle } = require('~/server/services/Endpoints/openAI'); -const { - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); - -const router = express.Router(); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js deleted file mode 100644 index 16c7e265f4..0000000000 --- a/api/server/routes/ask/google.js +++ /dev/null @@ -1,24 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { initializeClient, addTitle } = require('~/server/services/Endpoints/google'); -const { - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, -} = 
require('~/server/middleware'); - -const router = express.Router(); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js deleted file mode 100644 index a40022848a..0000000000 --- a/api/server/routes/ask/gptPlugins.js +++ /dev/null @@ -1,241 +0,0 @@ -const express = require('express'); -const { getResponseSender, Constants } = require('librechat-data-provider'); -const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); -const { sendMessage, createOnProgress } = require('~/server/utils'); -const { addTitle } = require('~/server/services/Endpoints/openAI'); -const { saveMessage, updateMessage } = require('~/models'); -const { - handleAbort, - createAbortController, - handleAbortError, - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, - moderateText, -} = require('~/server/middleware'); -const { validateTools } = require('~/app'); -const { logger } = require('~/config'); - -const router = express.Router(); - -router.use(moderateText); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - endpointOption, - conversationId, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); - - let userMessage; - let userMessagePromise; - let promptTokens; - let userMessageId; - let responseMessageId; - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - }); - const newConvo = !conversationId; - const user = req.user.id; - - const plugins = []; - - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - 
userMessage = data[key]; - userMessageId = data[key].messageId; - } else if (key === 'userMessagePromise') { - userMessagePromise = data[key]; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } else if (!conversationId && key === 'conversationId') { - conversationId = data[key]; - } - } - }; - - let streaming = null; - let timer = null; - - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - onProgress: () => { - if (timer) { - clearTimeout(timer); - } - - streaming = new Promise((resolve) => { - timer = setTimeout(() => { - resolve(); - }, 250); - }); - }, - }); - - const pluginMap = new Map(); - const onAgentAction = async (action, runId) => { - pluginMap.set(runId, action.tool); - sendIntermediateMessage(res, { - plugins, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - }; - - const onToolStart = async (tool, input, runId, parentRunId) => { - const pluginName = pluginMap.get(parentRunId); - const latestPlugin = { - runId, - loading: true, - inputs: [input], - latest: pluginName, - outputs: null, - }; - - if (streaming) { - await streaming; - } - const extraTokens = ':::plugin:::\n'; - plugins.push(latestPlugin); - sendIntermediateMessage( - res, - { plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId }, - extraTokens, - ); - }; - - const onToolEnd = async (output, runId) => { - if (streaming) { - await streaming; - } - - const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); - - if (pluginIndex !== -1) { - plugins[pluginIndex].loading = false; - plugins[pluginIndex].outputs = output; - } - }; - - const getAbortData = () => ({ - sender, - conversationId, - userMessagePromise, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? 
userMessageId, - text: getPartialText(), - plugins: plugins.map((p) => ({ ...p, loading: false })), - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - const onChainEnd = () => { - if (!client.skipSaveUserMessage) { - saveMessage( - req, - { ...userMessage, user }, - { context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' }, - ); - } - sendIntermediateMessage(res, { - plugins, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - }; - - let response = await client.sendMessage(text, { - user, - conversationId, - parentMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - onStart, - getPartialText, - ...endpointOption, - progressCallback, - progressOptions: { - res, - // parentMessageId: overrideParentMessageId || userMessageId, - plugins, - }, - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - logger.debug('[/ask/gptPlugins]', response); - - const { conversation = {} } = await response.databasePromise; - delete response.databasePromise; - conversation.title = - conversation && !conversation.title ? 
null : conversation?.title || 'New Chat'; - - sendMessage(res, { - title: conversation.title, - final: true, - conversation, - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - if (parentMessageId === Constants.NO_PARENT && newConvo) { - addTitle(req, { - text, - response, - client, - }); - } - - response.plugins = plugins.map((p) => ({ ...p, loading: false })); - if (response.plugins?.length > 0) { - await updateMessage( - req, - { ...response, user }, - { context: 'api/server/routes/ask/gptPlugins.js - save plugins used' }, - ); - } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender, - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); - } - }, -); - -module.exports = router; diff --git a/api/server/routes/ask/index.js b/api/server/routes/ask/index.js deleted file mode 100644 index 525bd8e29d..0000000000 --- a/api/server/routes/ask/index.js +++ /dev/null @@ -1,47 +0,0 @@ -const express = require('express'); -const { EModelEndpoint } = require('librechat-data-provider'); -const { - uaParser, - checkBan, - requireJwtAuth, - messageIpLimiter, - concurrentLimiter, - messageUserLimiter, - validateConvoAccess, -} = require('~/server/middleware'); -const { isEnabled } = require('~/server/utils'); -const gptPlugins = require('./gptPlugins'); -const anthropic = require('./anthropic'); -const custom = require('./custom'); -const google = require('./google'); -const openAI = require('./openAI'); - -const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? 
{}; - -const router = express.Router(); - -router.use(requireJwtAuth); -router.use(checkBan); -router.use(uaParser); - -if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { - router.use(concurrentLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_IP)) { - router.use(messageIpLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_USER)) { - router.use(messageUserLimiter); -} - -router.use(validateConvoAccess); - -router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI); -router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins); -router.use(`/${EModelEndpoint.anthropic}`, anthropic); -router.use(`/${EModelEndpoint.google}`, google); -router.use(`/${EModelEndpoint.custom}`, custom); - -module.exports = router; diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js deleted file mode 100644 index dadf00def4..0000000000 --- a/api/server/routes/ask/openAI.js +++ /dev/null @@ -1,27 +0,0 @@ -const express = require('express'); -const AskController = require('~/server/controllers/AskController'); -const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI'); -const { - handleAbort, - setHeaders, - validateModel, - validateEndpoint, - buildEndpointOption, - moderateText, -} = require('~/server/middleware'); - -const router = express.Router(); -router.use(moderateText); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/bedrock/chat.js b/api/server/routes/bedrock/chat.js deleted file mode 100644 index 263ca96002..0000000000 --- a/api/server/routes/bedrock/chat.js +++ /dev/null @@ -1,37 +0,0 @@ -const express = require('express'); - -const router = express.Router(); -const { - setHeaders, - handleAbort, - moderateText, - // validateModel, - // validateEndpoint, - buildEndpointOption, -} = require('~/server/middleware'); 
-const { initializeClient } = require('~/server/services/Endpoints/bedrock'); -const AgentController = require('~/server/controllers/agents/request'); -const addTitle = require('~/server/services/Endpoints/agents/title'); - -router.use(moderateText); - -/** - * @route POST / - * @desc Chat with an assistant - * @access Public - * @param {express.Request} req - The request object, containing the request data. - * @param {express.Response} res - The response object, used to send back a response. - * @returns {void} - */ -router.post( - '/', - // validateModel, - // validateEndpoint, - buildEndpointOption, - setHeaders, - async (req, res, next) => { - await AgentController(req, res, next, initializeClient, addTitle); - }, -); - -module.exports = router; diff --git a/api/server/routes/bedrock/index.js b/api/server/routes/bedrock/index.js deleted file mode 100644 index ce440a7c0e..0000000000 --- a/api/server/routes/bedrock/index.js +++ /dev/null @@ -1,35 +0,0 @@ -const express = require('express'); -const { - uaParser, - checkBan, - requireJwtAuth, - messageIpLimiter, - concurrentLimiter, - messageUserLimiter, -} = require('~/server/middleware'); -const { isEnabled } = require('~/server/utils'); -const chat = require('./chat'); - -const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? 
{}; - -const router = express.Router(); - -router.use(requireJwtAuth); -router.use(checkBan); -router.use(uaParser); - -if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) { - router.use(concurrentLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_IP)) { - router.use(messageIpLimiter); -} - -if (isEnabled(LIMIT_MESSAGE_USER)) { - router.use(messageUserLimiter); -} - -router.use('/chat', chat); - -module.exports = router; diff --git a/api/server/routes/config.js b/api/server/routes/config.js index a53a636d05..dd93037dd9 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,10 +1,11 @@ const express = require('express'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider'); +const { getCustomConfig } = require('~/server/services/Config/getCustomConfig'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getProjectByName } = require('~/models/Project'); const { isEnabled } = require('~/server/utils'); const { getLogStores } = require('~/cache'); -const { logger } = require('~/config'); const router = express.Router(); const emailLoginEnabled = @@ -21,6 +22,7 @@ const publicSharedLinksEnabled = router.get('/', async function (req, res) { const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); if (cachedStartupConfig) { res.send(cachedStartupConfig); @@ -96,6 +98,18 @@ router.get('/', async function (req, res) { bundlerURL: process.env.SANDPACK_BUNDLER_URL, staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL, }; + + payload.mcpServers = {}; + const config = await getCustomConfig(); + if (config?.mcpServers != null) { + for (const serverName in config.mcpServers) { + const serverConfig = config.mcpServers[serverName]; + payload.mcpServers[serverName] = { + customUserVars: serverConfig?.customUserVars || {}, + }; + } + } + /** @type {TCustomConfig['webSearch']} */ const 
webSearchConfig = req.app.locals.webSearch; if ( diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 87bac6ed29..eb7e2c5c27 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -65,8 +65,14 @@ router.post('/gen_title', async (req, res) => { let title = await titleCache.get(key); if (!title) { - await sleep(2500); - title = await titleCache.get(key); + // Retry every 1s for up to 20s + for (let i = 0; i < 20; i++) { + await sleep(1000); + title = await titleCache.get(key); + if (title) { + break; + } + } } if (title) { diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js deleted file mode 100644 index 94d9b91d0b..0000000000 --- a/api/server/routes/edit/gptPlugins.js +++ /dev/null @@ -1,207 +0,0 @@ -const express = require('express'); -const { getResponseSender } = require('librechat-data-provider'); -const { - setHeaders, - moderateText, - validateModel, - handleAbortError, - validateEndpoint, - buildEndpointOption, - createAbortController, -} = require('~/server/middleware'); -const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); -const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); -const { saveMessage, updateMessage } = require('~/models'); -const { validateTools } = require('~/app'); -const { logger } = require('~/config'); - -const router = express.Router(); - -router.use(moderateText); - -router.post( - '/', - validateEndpoint, - validateModel, - buildEndpointOption, - setHeaders, - async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - logger.debug('[/edit/gptPlugins]', { - text, - generation, - isContinued, - conversationId, - ...endpointOption, - }); - - let userMessage; - let userMessagePromise; - let promptTokens; - const sender = 
getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - }); - const userMessageId = parentMessageId; - const user = req.user.id; - - const plugin = { - loading: true, - inputs: [], - latest: null, - outputs: null, - }; - - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - } else if (key === 'userMessagePromise') { - userMessagePromise = data[key]; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } - } - }; - - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - generation, - onProgress: () => { - if (plugin.loading === true) { - plugin.loading = false; - } - }, - }); - - const onChainEnd = (data) => { - let { intermediateSteps: steps } = data; - plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; - plugin.loading = false; - saveMessage( - req, - { ...userMessage, user }, - { context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' }, - ); - sendIntermediateMessage(res, { - plugin, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - // logger.debug('CHAIN END', plugin.outputs); - }; - - const getAbortData = () => ({ - sender, - conversationId, - userMessagePromise, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? 
userMessageId, - text: getPartialText(), - plugin: { ...plugin, loading: false }, - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - const onAgentAction = (action, start = false) => { - const formattedAction = formatAction(action); - plugin.inputs.push(formattedAction); - plugin.latest = formattedAction.plugin; - if (!start && !client.skipSaveUserMessage) { - saveMessage( - req, - { ...userMessage, user }, - { context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' }, - ); - } - sendIntermediateMessage(res, { - plugin, - parentMessageId: userMessage.messageId, - messageId: responseMessageId, - }); - // logger.debug('PLUGIN ACTION', formattedAction); - }; - - let response = await client.sendMessage(text, { - user, - generation, - isContinued, - isEdited: true, - conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onStart, - ...endpointOption, - progressCallback, - progressOptions: { - res, - plugin, - // parentMessageId: overrideParentMessageId || userMessageId, - }, - abortController, - }); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); - - const { conversation = {} } = await response.databasePromise; - delete response.databasePromise; - conversation.title = - conversation && !conversation.title ? 
null : conversation?.title || 'New Chat'; - - sendMessage(res, { - title: conversation.title, - final: true, - conversation, - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - - response.plugin = { ...plugin, loading: false }; - await updateMessage( - req, - { ...response, user }, - { context: 'api/server/routes/edit/gptPlugins.js' }, - ); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender, - messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, - }); - } - }, -); - -module.exports = router; diff --git a/api/server/routes/edit/index.js b/api/server/routes/edit/index.js index f1d47af3f9..92a1e63f63 100644 --- a/api/server/routes/edit/index.js +++ b/api/server/routes/edit/index.js @@ -3,7 +3,6 @@ const openAI = require('./openAI'); const custom = require('./custom'); const google = require('./google'); const anthropic = require('./anthropic'); -const gptPlugins = require('./gptPlugins'); const { isEnabled } = require('~/server/utils'); const { EModelEndpoint } = require('librechat-data-provider'); const { @@ -39,7 +38,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) { router.use(validateConvoAccess); router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI); -router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins); router.use(`/${EModelEndpoint.anthropic}`, anthropic); router.use(`/${EModelEndpoint.google}`, google); router.use(`/${EModelEndpoint.custom}`, custom); diff --git a/api/server/routes/files/files.js b/api/server/routes/files/files.js index bb2ae0bbe5..bdfdca65cf 100644 --- a/api/server/routes/files/files.js +++ b/api/server/routes/files/files.js @@ -283,7 +283,10 @@ router.post('/', async (req, res) => { message += ': ' + error.message; } - if (error.message?.includes('Invalid file format')) { + if ( + error.message?.includes('Invalid file format') || + error.message?.includes('No OCR result') 
+ ) { message = error.message; } diff --git a/api/server/routes/files/multer.js b/api/server/routes/files/multer.js index f23ecd2823..257c309fa2 100644 --- a/api/server/routes/files/multer.js +++ b/api/server/routes/files/multer.js @@ -2,8 +2,8 @@ const fs = require('fs'); const path = require('path'); const crypto = require('crypto'); const multer = require('multer'); +const { sanitizeFilename } = require('@librechat/api'); const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider'); -const { sanitizeFilename } = require('~/server/utils/handleText'); const { getCustomConfig } = require('~/server/services/Config'); const storage = multer.diskStorage({ diff --git a/api/server/routes/files/multer.spec.js b/api/server/routes/files/multer.spec.js new file mode 100644 index 0000000000..2fb9147aef --- /dev/null +++ b/api/server/routes/files/multer.spec.js @@ -0,0 +1,573 @@ +/* eslint-disable no-unused-vars */ +/* eslint-disable jest/no-done-callback */ +const fs = require('fs'); +const os = require('os'); +const path = require('path'); +const crypto = require('crypto'); +const { createMulterInstance, storage, importFileFilter } = require('./multer'); + +// Mock only the config service that requires external dependencies +jest.mock('~/server/services/Config', () => ({ + getCustomConfig: jest.fn(() => + Promise.resolve({ + fileConfig: { + endpoints: { + openAI: { + supportedMimeTypes: ['image/jpeg', 'image/png', 'application/pdf'], + }, + default: { + supportedMimeTypes: ['image/jpeg', 'image/png', 'text/plain'], + }, + }, + serverFileSizeLimit: 10000000, // 10MB + }, + }), + ), +})); + +describe('Multer Configuration', () => { + let tempDir; + let mockReq; + let mockFile; + + beforeEach(() => { + // Create a temporary directory for each test + tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'multer-test-')); + + mockReq = { + user: { id: 'test-user-123' }, + app: { + locals: { + paths: { + uploads: tempDir, + }, + }, + }, + body: {}, + 
originalUrl: '/api/files/upload', + }; + + mockFile = { + originalname: 'test-file.jpg', + mimetype: 'image/jpeg', + size: 1024, + }; + + // Clear mocks + jest.clearAllMocks(); + }); + + afterEach(() => { + // Clean up temporary directory + if (fs.existsSync(tempDir)) { + fs.rmSync(tempDir, { recursive: true, force: true }); + } + }); + + describe('Storage Configuration', () => { + describe('destination function', () => { + it('should create the correct destination path', (done) => { + const cb = jest.fn((err, destination) => { + expect(err).toBeNull(); + expect(destination).toBe(path.join(tempDir, 'temp', 'test-user-123')); + expect(fs.existsSync(destination)).toBe(true); + done(); + }); + + storage.getDestination(mockReq, mockFile, cb); + }); + + it("should create directory recursively if it doesn't exist", (done) => { + const deepPath = path.join(tempDir, 'deep', 'nested', 'path'); + mockReq.app.locals.paths.uploads = deepPath; + + const cb = jest.fn((err, destination) => { + expect(err).toBeNull(); + expect(destination).toBe(path.join(deepPath, 'temp', 'test-user-123')); + expect(fs.existsSync(destination)).toBe(true); + done(); + }); + + storage.getDestination(mockReq, mockFile, cb); + }); + }); + + describe('filename function', () => { + it('should generate a UUID for req.file_id', (done) => { + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(mockReq.file_id).toBeDefined(); + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i, + ); + done(); + }); + + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should decode URI components in filename', (done) => { + const encodedFile = { + ...mockFile, + originalname: encodeURIComponent('test file with spaces.jpg'), + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(encodedFile.originalname).toBe('test file with spaces.jpg'); + done(); + }); + + storage.getFilename(mockReq, encodedFile, cb); + }); + + 
it('should call real sanitizeFilename with properly encoded filename', (done) => { + // Test with a properly URI-encoded filename that needs sanitization + const unsafeFile = { + ...mockFile, + originalname: encodeURIComponent('test@#$%^&*()file with spaces!.jpg'), + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + // The actual sanitizeFilename should have cleaned this up after decoding + expect(filename).not.toContain('@'); + expect(filename).not.toContain('#'); + expect(filename).not.toContain('*'); + expect(filename).not.toContain('!'); + // Should still preserve dots and hyphens + expect(filename).toContain('.jpg'); + done(); + }); + + storage.getFilename(mockReq, unsafeFile, cb); + }); + + it('should handle very long filenames with actual crypto', (done) => { + const longFile = { + ...mockFile, + originalname: 'a'.repeat(300) + '.jpg', + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(filename.length).toBeLessThanOrEqual(255); + expect(filename).toMatch(/\.jpg$/); // Should still end with .jpg + // Should contain a hex suffix if truncated + if (filename.length === 255) { + expect(filename).toMatch(/-[a-f0-9]{6}\.jpg$/); + } + done(); + }); + + storage.getFilename(mockReq, longFile, cb); + }); + + it('should generate unique file_id for each call', (done) => { + let firstFileId; + + const firstCb = jest.fn((err, filename) => { + expect(err).toBeNull(); + firstFileId = mockReq.file_id; + + // Reset req for second call + delete mockReq.file_id; + + const secondCb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(mockReq.file_id).toBeDefined(); + expect(mockReq.file_id).not.toBe(firstFileId); + done(); + }); + + storage.getFilename(mockReq, mockFile, secondCb); + }); + + storage.getFilename(mockReq, mockFile, firstCb); + }); + }); + }); + + describe('Import File Filter', () => { + it('should accept JSON files by mimetype', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 
'application/json', + originalname: 'data.json', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + + it('should accept files with .json extension', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'data.json', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + + it('should reject non-JSON files', (done) => { + const textFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'document.txt', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeInstanceOf(Error); + expect(err.message).toBe('Only JSON files are allowed'); + expect(result).toBe(false); + done(); + }); + + importFileFilter(mockReq, textFile, cb); + }); + + it('should handle files with uppercase .JSON extension', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'DATA.JSON', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + }); + + describe('File Filter with Real defaultFileConfig', () => { + it('should use real fileConfig.checkType for validation', async () => { + // Test with actual librechat-data-provider functions + const { + fileConfig, + imageMimeTypes, + applicationMimeTypes, + } = require('librechat-data-provider'); + + // Test that the real checkType function works with regex patterns + expect(fileConfig.checkType('image/jpeg', [imageMimeTypes])).toBe(true); + expect(fileConfig.checkType('video/mp4', [imageMimeTypes])).toBe(false); + expect(fileConfig.checkType('application/pdf', [applicationMimeTypes])).toBe(true); + expect(fileConfig.checkType('application/pdf', [])).toBe(false); + }); + + it('should handle audio files for speech-to-text 
endpoint with real config', async () => { + mockReq.originalUrl = '/api/speech/stt'; + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + }); + + it('should reject unsupported file types using real config', async () => { + // Mock defaultFileConfig for this specific test + const originalCheckType = require('librechat-data-provider').fileConfig.checkType; + const mockCheckType = jest.fn().mockReturnValue(false); + require('librechat-data-provider').fileConfig.checkType = mockCheckType; + + try { + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + + // Test the actual file filter behavior would reject unsupported files + expect(mockCheckType).toBeDefined(); + } finally { + // Restore original function + require('librechat-data-provider').fileConfig.checkType = originalCheckType; + } + }); + + it('should use real mergeFileConfig function', async () => { + const { mergeFileConfig, mbToBytes } = require('librechat-data-provider'); + + // Test with actual merge function - note that it converts MB to bytes + const testConfig = { + serverFileSizeLimit: 5, // 5 MB + endpoints: { + custom: { + supportedMimeTypes: ['text/plain'], + }, + }, + }; + + const result = mergeFileConfig(testConfig); + + // The function converts MB to bytes, so 5 MB becomes 5 * 1024 * 1024 bytes + expect(result.serverFileSizeLimit).toBe(mbToBytes(5)); + expect(result.endpoints.custom.supportedMimeTypes).toBeDefined(); + // Should still have the default endpoints + expect(result.endpoints.default).toBeDefined(); + }); + }); + + describe('createMulterInstance with Real Functions', () => { + it('should create a multer instance with correct configuration', async () => { + const multerInstance = await createMulterInstance(); + + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + expect(typeof 
multerInstance.array).toBe('function'); + expect(typeof multerInstance.fields).toBe('function'); + }); + + it('should use real config merging', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + const multerInstance = await createMulterInstance(); + + expect(getCustomConfig).toHaveBeenCalled(); + expect(multerInstance).toBeDefined(); + }); + + it('should create multer instance with expected interface', async () => { + const multerInstance = await createMulterInstance(); + + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + expect(typeof multerInstance.array).toBe('function'); + expect(typeof multerInstance.fields).toBe('function'); + }); + }); + + describe('Real Crypto Integration', () => { + it('should use actual crypto.randomUUID()', (done) => { + // Spy on crypto.randomUUID to ensure it's called + const uuidSpy = jest.spyOn(crypto, 'randomUUID'); + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(uuidSpy).toHaveBeenCalled(); + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i, + ); + + uuidSpy.mockRestore(); + done(); + }); + + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should generate different UUIDs on subsequent calls', (done) => { + const uuids = []; + let callCount = 0; + const totalCalls = 5; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + uuids.push(mockReq.file_id); + callCount++; + + if (callCount === totalCalls) { + // Check that all UUIDs are unique + const uniqueUuids = new Set(uuids); + expect(uniqueUuids.size).toBe(totalCalls); + done(); + } else { + // Reset for next call + delete mockReq.file_id; + storage.getFilename(mockReq, mockFile, cb); + } + }); + + // Start the chain + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should generate cryptographically secure UUIDs', (done) => { + const generatedUuids = new Set(); + let callCount = 0; + 
const totalCalls = 10; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + + // Verify UUID format and uniqueness + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i, + ); + + generatedUuids.add(mockReq.file_id); + callCount++; + + if (callCount === totalCalls) { + // All UUIDs should be unique + expect(generatedUuids.size).toBe(totalCalls); + done(); + } else { + // Reset for next call + delete mockReq.file_id; + storage.getFilename(mockReq, mockFile, cb); + } + }); + + // Start the chain + storage.getFilename(mockReq, mockFile, cb); + }); + }); + + describe('Error Handling', () => { + it('should handle CVE-2024-28870: empty field name DoS vulnerability', async () => { + // Test for the CVE where empty field name could cause unhandled exception + const multerInstance = await createMulterInstance(); + + // Create a mock request with empty field name (the vulnerability scenario) + const mockReqWithEmptyField = { + ...mockReq, + headers: { + 'content-type': 'multipart/form-data', + }, + }; + + const mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + end: jest.fn(), + }; + + // This should not crash or throw unhandled exceptions + const uploadMiddleware = multerInstance.single(''); // Empty field name + + const mockNext = jest.fn((err) => { + // If there's an error, it should be handled gracefully, not crash + if (err) { + expect(err).toBeInstanceOf(Error); + // The error should be handled, not crash the process + } + }); + + // This should complete without crashing the process + expect(() => { + uploadMiddleware(mockReqWithEmptyField, mockRes, mockNext); + }).not.toThrow(); + }); + + it('should handle file system errors when directory creation fails', (done) => { + // Test with a non-existent parent directory to simulate fs issues + const invalidPath = '/nonexistent/path/that/should/not/exist'; + mockReq.app.locals.paths.uploads = invalidPath; + + try { + // Call 
getDestination which should fail due to permission/path issues + storage.getDestination(mockReq, mockFile, (err, destination) => { + // If callback is reached, we didn't get the expected error + done(new Error('Expected mkdirSync to throw an error but callback was called')); + }); + // If we get here without throwing, something unexpected happened + done(new Error('Expected mkdirSync to throw an error but no error was thrown')); + } catch (error) { + // This is the expected behavior - mkdirSync throws synchronously for invalid paths + // On Linux, this typically returns EACCES (permission denied) + // On macOS/Darwin, this returns ENOENT (no such file or directory) + expect(['EACCES', 'ENOENT']).toContain(error.code); + done(); + } + }); + + it('should handle malformed filenames with real sanitization', (done) => { + const malformedFile = { + ...mockFile, + originalname: null, // This should be handled gracefully + }; + + const cb = jest.fn((err, filename) => { + // The function should handle this gracefully + expect(typeof err === 'object' || err === null).toBe(true); + done(); + }); + + try { + storage.getFilename(mockReq, malformedFile, cb); + } catch (error) { + // If it throws, that's also acceptable behavior + done(); + } + }); + + it('should handle edge cases in filename sanitization', (done) => { + const edgeCaseFiles = [ + { originalname: '', expected: /_/ }, + { originalname: '.hidden', expected: /^_\.hidden/ }, + { originalname: '../../../etc/passwd', expected: /passwd/ }, + { originalname: 'file\x00name.txt', expected: /file_name\.txt/ }, + ]; + + let testCount = 0; + + const testNextFile = (fileData) => { + const fileToTest = { ...mockFile, originalname: fileData.originalname }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(filename).toMatch(fileData.expected); + + testCount++; + if (testCount === edgeCaseFiles.length) { + done(); + } else { + testNextFile(edgeCaseFiles[testCount]); + } + }); + + 
storage.getFilename(mockReq, fileToTest, cb); + }; + + testNextFile(edgeCaseFiles[0]); + }); + }); + + describe('Real Configuration Testing', () => { + it('should handle missing custom config gracefully with real mergeFileConfig', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + // Mock getCustomConfig to return undefined + getCustomConfig.mockResolvedValueOnce(undefined); + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + }); + + it('should properly integrate real fileConfig with custom endpoints', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + // Mock a custom config with additional endpoints + getCustomConfig.mockResolvedValueOnce({ + fileConfig: { + endpoints: { + anthropic: { + supportedMimeTypes: ['text/plain', 'image/png'], + }, + }, + serverFileSizeLimit: 20, // 20 MB + }, + }); + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + + // Verify that getCustomConfig was called (we can't spy on the actual merge function easily) + expect(getCustomConfig).toHaveBeenCalled(); + }); + }); +}); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 449759383d..ec97ba3986 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -4,11 +4,11 @@ const tokenizer = require('./tokenizer'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); +const memories = require('./memories'); const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); const plugins = require('./plugins'); -const bedrock = require('./bedrock'); const actions = require('./actions'); const banner = require('./banner'); const search = require('./search'); @@ -25,10 +25,9 @@ const auth = require('./auth'); const edit = 
require('./edit'); const keys = require('./keys'); const user = require('./user'); -const ask = require('./ask'); +const mcp = require('./mcp'); module.exports = { - ask, edit, auth, keys, @@ -44,16 +43,17 @@ module.exports = { search, config, models, - bedrock, prompts, plugins, actions, presets, balance, messages, + memories, endpoints, tokenizer, assistants, categories, staticRoute, + mcp, }; diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js new file mode 100644 index 0000000000..3dfed4d240 --- /dev/null +++ b/api/server/routes/mcp.js @@ -0,0 +1,205 @@ +const { Router } = require('express'); +const { MCPOAuthHandler } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys } = require('librechat-data-provider'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getFlowStateManager } = require('~/config'); +const { getLogStores } = require('~/cache'); + +const router = Router(); + +/** + * Initiate OAuth flow + * This endpoint is called when the user clicks the auth link in the UI + */ +router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { + try { + const { serverName } = req.params; + const { userId, flowId } = req.query; + const user = req.user; + + // Verify the userId matches the authenticated user + if (userId !== user.id) { + return res.status(403).json({ error: 'User mismatch' }); + } + + logger.debug('[MCP OAuth] Initiate request', { serverName, userId, flowId }); + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + /** Flow state to retrieve OAuth config */ + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (!flowState) { + logger.error('[MCP OAuth] Flow state not found', { flowId }); + return res.status(404).json({ error: 'Flow not found' }); + } + + const { serverUrl, oauth: oauthConfig } = flowState.metadata || {}; + if (!serverUrl || !oauthConfig) { + 
logger.error('[MCP OAuth] Missing server URL or OAuth config in flow state'); + return res.status(400).json({ error: 'Invalid flow state' }); + } + + const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow( + serverName, + serverUrl, + userId, + oauthConfig, + ); + + logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl }); + + // Redirect user to the authorization URL + res.redirect(authorizationUrl); + } catch (error) { + logger.error('[MCP OAuth] Failed to initiate OAuth', error); + res.status(500).json({ error: 'Failed to initiate OAuth' }); + } +}); + +/** + * OAuth callback handler + * This handles the OAuth callback after the user has authorized the application + */ +router.get('/:serverName/oauth/callback', async (req, res) => { + try { + const { serverName } = req.params; + const { code, state, error: oauthError } = req.query; + + logger.debug('[MCP OAuth] Callback received', { + serverName, + code: code ? 'present' : 'missing', + state, + error: oauthError, + }); + + if (oauthError) { + logger.error('[MCP OAuth] OAuth error received', { error: oauthError }); + return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`); + } + + if (!code || typeof code !== 'string') { + logger.error('[MCP OAuth] Missing or invalid code'); + return res.redirect('/oauth/error?error=missing_code'); + } + + if (!state || typeof state !== 'string') { + logger.error('[MCP OAuth] Missing or invalid state'); + return res.redirect('/oauth/error?error=missing_state'); + } + + // Extract flow ID from state + const flowId = state; + logger.debug('[MCP OAuth] Using flow ID from state', { flowId }); + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId); + const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager); + + if (!flowState) { + logger.error('[MCP 
OAuth] Flow state not found for flowId:', flowId); + return res.redirect('/oauth/error?error=invalid_state'); + } + + logger.debug('[MCP OAuth] Flow state details', { + serverName: flowState.serverName, + userId: flowState.userId, + hasMetadata: !!flowState.metadata, + hasClientInfo: !!flowState.clientInfo, + hasCodeVerifier: !!flowState.codeVerifier, + }); + + // Complete the OAuth flow + logger.debug('[MCP OAuth] Completing OAuth flow'); + const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager); + logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); + + // For system-level OAuth, we need to store the tokens and retry the connection + if (flowState.userId === 'system') { + logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`); + } + + /** ID of the flow that the tool/connection is waiting for */ + const toolFlowId = flowState.metadata?.toolFlowId; + if (toolFlowId) { + logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId }); + await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens); + } + + /** Redirect to success page with flowId and serverName */ + const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`; + res.redirect(redirectUrl); + } catch (error) { + logger.error('[MCP OAuth] OAuth callback error', error); + res.redirect('/oauth/error?error=callback_failed'); + } +}); + +/** + * Get OAuth tokens for a completed flow + * This is primarily for user-level OAuth flows + */ +router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => { + try { + const { flowId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + // Allow system flows or user-owned flows + if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + 
const flowManager = getFlowStateManager(flowsCache); + + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (!flowState) { + return res.status(404).json({ error: 'Flow not found' }); + } + + if (flowState.status !== 'COMPLETED') { + return res.status(400).json({ error: 'Flow not completed' }); + } + + res.json({ tokens: flowState.result }); + } catch (error) { + logger.error('[MCP OAuth] Failed to get tokens', error); + res.status(500).json({ error: 'Failed to get tokens' }); + } +}); + +/** + * Check OAuth flow status + * This endpoint can be used to poll the status of an OAuth flow + */ +router.get('/oauth/status/:flowId', async (req, res) => { + try { + const { flowId } = req.params; + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (!flowState) { + return res.status(404).json({ error: 'Flow not found' }); + } + + res.json({ + status: flowState.status, + completed: flowState.status === 'COMPLETED', + failed: flowState.status === 'FAILED', + error: flowState.error, + }); + } catch (error) { + logger.error('[MCP OAuth] Failed to get flow status', error); + res.status(500).json({ error: 'Failed to get flow status' }); + } +}); + +module.exports = router; diff --git a/api/server/routes/memories.js b/api/server/routes/memories.js new file mode 100644 index 0000000000..fe520de000 --- /dev/null +++ b/api/server/routes/memories.js @@ -0,0 +1,237 @@ +const express = require('express'); +const { Tokenizer, generateCheckAccess } = require('@librechat/api'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + getAllUserMemories, + toggleUserMemories, + createMemory, + deleteMemory, + setMemory, +} = require('~/models'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); + +const router = express.Router(); + +const 
checkMemoryRead = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.READ], + getRoleByName, +}); +const checkMemoryCreate = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); +const checkMemoryUpdate = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.UPDATE], + getRoleByName, +}); +const checkMemoryDelete = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.UPDATE], + getRoleByName, +}); +const checkMemoryOptOut = generateCheckAccess({ + permissionType: PermissionTypes.MEMORIES, + permissions: [Permissions.USE, Permissions.OPT_OUT], + getRoleByName, +}); + +router.use(requireJwtAuth); + +/** + * GET /memories + * Returns all memories for the authenticated user, sorted by updated_at (newest first). + * Also includes memory usage percentage based on token limit. + */ +router.get('/', checkMemoryRead, async (req, res) => { + try { + const memories = await getAllUserMemories(req.user.id); + + const sortedMemories = memories.sort( + (a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime(), + ); + + const totalTokens = memories.reduce((sum, memory) => { + return sum + (memory.tokenCount || 0); + }, 0); + + const memoryConfig = req.app.locals?.memory; + const tokenLimit = memoryConfig?.tokenLimit; + + let usagePercentage = null; + if (tokenLimit && tokenLimit > 0) { + usagePercentage = Math.min(100, Math.round((totalTokens / tokenLimit) * 100)); + } + + res.json({ + memories: sortedMemories, + totalTokens, + tokenLimit: tokenLimit || null, + usagePercentage, + }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * POST /memories + * Creates a new memory entry for the authenticated user. 
+ * Body: { key: string, value: string } + * Returns 201 and { created: true, memory: } when successful. + */ +router.post('/', checkMemoryCreate, async (req, res) => { + const { key, value } = req.body; + + if (typeof key !== 'string' || key.trim() === '') { + return res.status(400).json({ error: 'Key is required and must be a non-empty string.' }); + } + + if (typeof value !== 'string' || value.trim() === '') { + return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); + } + + try { + const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); + + const memories = await getAllUserMemories(req.user.id); + + // Check token limit + const memoryConfig = req.app.locals?.memory; + const tokenLimit = memoryConfig?.tokenLimit; + + if (tokenLimit) { + const currentTotalTokens = memories.reduce( + (sum, memory) => sum + (memory.tokenCount || 0), + 0, + ); + if (currentTotalTokens + tokenCount > tokenLimit) { + return res.status(400).json({ + error: `Adding this memory would exceed the token limit of ${tokenLimit}. Current usage: ${currentTotalTokens} tokens.`, + }); + } + } + + const result = await createMemory({ + userId: req.user.id, + key: key.trim(), + value: value.trim(), + tokenCount, + }); + + if (!result.ok) { + return res.status(500).json({ error: 'Failed to create memory.' }); + } + + const updatedMemories = await getAllUserMemories(req.user.id); + const newMemory = updatedMemories.find((m) => m.key === key.trim()); + + res.status(201).json({ created: true, memory: newMemory }); + } catch (error) { + if (error.message && error.message.includes('already exists')) { + return res.status(409).json({ error: 'Memory with this key already exists.' }); + } + res.status(500).json({ error: error.message }); + } +}); + +/** + * PATCH /memories/preferences + * Updates the user's memory preferences (e.g., enabling/disabling memories). 
+ * Body: { memories: boolean } + * Returns 200 and { updated: true, preferences: { memories: boolean } } when successful. + */ +router.patch('/preferences', checkMemoryOptOut, async (req, res) => { + const { memories } = req.body; + + if (typeof memories !== 'boolean') { + return res.status(400).json({ error: 'memories must be a boolean value.' }); + } + + try { + const updatedUser = await toggleUserMemories(req.user.id, memories); + + if (!updatedUser) { + return res.status(404).json({ error: 'User not found.' }); + } + + res.json({ + updated: true, + preferences: { + memories: updatedUser.personalization?.memories ?? true, + }, + }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * PATCH /memories/:key + * Updates the value of an existing memory entry for the authenticated user. + * Body: { value: string } + * Returns 200 and { updated: true, memory: } when successful. + */ +router.patch('/:key', checkMemoryUpdate, async (req, res) => { + const { key } = req.params; + const { value } = req.body || {}; + + if (typeof value !== 'string' || value.trim() === '') { + return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); + } + + try { + const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); + + const memories = await getAllUserMemories(req.user.id); + const existingMemory = memories.find((m) => m.key === key); + + if (!existingMemory) { + return res.status(404).json({ error: 'Memory not found.' }); + } + + const result = await setMemory({ + userId: req.user.id, + key, + value, + tokenCount, + }); + + if (!result.ok) { + return res.status(500).json({ error: 'Failed to update memory.' 
}); + } + + const updatedMemories = await getAllUserMemories(req.user.id); + const updatedMemory = updatedMemories.find((m) => m.key === key); + + res.json({ updated: true, memory: updatedMemory }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * DELETE /memories/:key + * Deletes a memory entry for the authenticated user. + * Returns 200 and { deleted: true } when successful. + */ +router.delete('/:key', checkMemoryDelete, async (req, res) => { + const { key } = req.params; + + try { + const result = await deleteMemory({ userId: req.user.id, key }); + + if (!result.ok) { + return res.status(404).json({ error: 'Memory not found.' }); + } + + res.json({ deleted: true }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +module.exports = router; diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index d5980ae55b..0a277a1bd6 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -1,4 +1,5 @@ const express = require('express'); +const { logger } = require('@librechat/data-schemas'); const { ContentTypes } = require('librechat-data-provider'); const { saveConvo, @@ -13,8 +14,7 @@ const { requireJwtAuth, validateMessageReq } = require('~/server/middleware'); const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc'); const { getConvosQueried } = require('~/models/Conversation'); const { countTokens } = require('~/server/utils'); -const { Message } = require('~/models/Message'); -const { logger } = require('~/config'); +const { Message } = require('~/db/models'); const router = express.Router(); router.use(requireJwtAuth); @@ -40,7 +40,11 @@ router.get('/', async (req, res) => { const sortOrder = sortDirection === 'asc' ? 
1 : -1; if (conversationId && messageId) { - const message = await Message.findOne({ conversationId, messageId, user: user }).lean(); + const message = await Message.findOne({ + conversationId, + messageId, + user: user, + }).lean(); response = { messages: message ? [message] : [], nextCursor: null }; } else if (conversationId) { const filter = { conversationId, user: user }; @@ -231,12 +235,13 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = return res.status(400).json({ error: 'Content part not found' }); } - if (updatedContent[index].type !== ContentTypes.TEXT) { + const currentPartType = updatedContent[index].type; + if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) { return res.status(400).json({ error: 'Cannot update non-text content' }); } - const oldText = updatedContent[index].text; - updatedContent[index] = { type: ContentTypes.TEXT, text }; + const oldText = updatedContent[index][currentPartType]; + updatedContent[index] = { type: currentPartType, [currentPartType]: text }; let tokenCount = message.tokenCount; if (tokenCount !== undefined) { @@ -253,6 +258,31 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = } }); +router.put('/:conversationId/:messageId/feedback', validateMessageReq, async (req, res) => { + try { + const { conversationId, messageId } = req.params; + const { feedback } = req.body; + + const updatedMessage = await updateMessage( + req, + { + messageId, + feedback: feedback || null, + }, + { context: 'updateFeedback' }, + ); + + res.json({ + messageId, + conversationId, + feedback: updatedMessage.feedback, + }); + } catch (error) { + logger.error('Error updating message feedback:', error); + res.status(500).json({ error: 'Failed to update feedback' }); + } +}); + router.delete('/:conversationId/:messageId', validateMessageReq, async (req, res) => { try { const { messageId } = req.params; diff --git a/api/server/routes/oauth.js 
b/api/server/routes/oauth.js index bc8d120ef5..afc4a05b75 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -47,7 +47,9 @@ const oauthHandler = async (req, res) => { router.get('/error', (req, res) => { // A single error message is pushed by passport when authentication fails. - logger.error('Error in OAuth authentication:', { message: req.session.messages.pop() }); + logger.error('Error in OAuth authentication:', { + message: req.session?.messages?.pop() || 'Unknown error', + }); // Redirect to login page with auth_failed parameter to prevent infinite redirect loops res.redirect(`${domains.client}/login?redirect=false`); diff --git a/api/server/routes/prompts.js b/api/server/routes/prompts.js index e3ab5bf5d3..c18418cba5 100644 --- a/api/server/routes/prompts.js +++ b/api/server/routes/prompts.js @@ -1,5 +1,7 @@ const express = require('express'); -const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); +const { generateCheckAccess } = require('@librechat/api'); +const { Permissions, SystemRoles, PermissionTypes } = require('librechat-data-provider'); const { getPrompt, getPrompts, @@ -14,24 +16,30 @@ const { // updatePromptLabels, makePromptProduction, } = require('~/models/Prompt'); -const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); -const { logger } = require('~/config'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); -const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]); -const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [ - Permissions.USE, - Permissions.CREATE, -]); +const checkPromptAccess = generateCheckAccess({ + permissionType: PermissionTypes.PROMPTS, + permissions: [Permissions.USE], + getRoleByName, +}); +const checkPromptCreate = generateCheckAccess({ + 
permissionType: PermissionTypes.PROMPTS, + permissions: [Permissions.USE, Permissions.CREATE], + getRoleByName, +}); -const checkGlobalPromptShare = generateCheckAccess( - PermissionTypes.PROMPTS, - [Permissions.USE, Permissions.CREATE], - { +const checkGlobalPromptShare = generateCheckAccess({ + permissionType: PermissionTypes.PROMPTS, + permissions: [Permissions.USE, Permissions.CREATE], + bodyProps: { [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], }, -); + getRoleByName, +}); router.use(requireJwtAuth); router.use(checkPromptAccess); diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js index 17768c7de6..aefbfcec0c 100644 --- a/api/server/routes/roles.js +++ b/api/server/routes/roles.js @@ -1,6 +1,7 @@ const express = require('express'); const { promptPermissionsSchema, + memoryPermissionsSchema, agentPermissionsSchema, PermissionTypes, roleDefaults, @@ -118,4 +119,43 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => { } }); +/** + * PUT /api/roles/:roleName/memories + * Update memory permissions for a specific role + */ +router.put('/:roleName/memories', checkAdmin, async (req, res) => { + const { roleName: _r } = req.params; + // TODO: TEMP, use a better parsing for roleName + const roleName = _r.toUpperCase(); + /** @type {TRole['permissions']['MEMORIES']} */ + const updates = req.body; + + try { + const parsedUpdates = memoryPermissionsSchema.partial().parse(updates); + + const role = await getRoleByName(roleName); + if (!role) { + return res.status(404).send({ message: 'Role not found' }); + } + + const currentPermissions = + role.permissions?.[PermissionTypes.MEMORIES] || role[PermissionTypes.MEMORIES] || {}; + + const mergedUpdates = { + permissions: { + ...role.permissions, + [PermissionTypes.MEMORIES]: { + ...currentPermissions, + ...parsedUpdates, + }, + }, + }; + + const updatedRole = await updateRoleByName(roleName, mergedUpdates); + res.status(200).send(updatedRole); + } catch (error) { + return 
res.status(400).send({ message: 'Invalid memory permissions.', error: error.errors }); + } +}); + module.exports = router; diff --git a/api/server/routes/share.js b/api/server/routes/share.js index e551f4a354..14c25271fc 100644 --- a/api/server/routes/share.js +++ b/api/server/routes/share.js @@ -1,15 +1,15 @@ const express = require('express'); - +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { - getSharedLink, getSharedMessages, createSharedLink, updateSharedLink, - getSharedLinks, deleteSharedLink, -} = require('~/models/Share'); + getSharedLinks, + getSharedLink, +} = require('~/models'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); -const { isEnabled } = require('~/server/utils'); const router = express.Router(); /** @@ -35,6 +35,7 @@ if (allowSharedLinks) { res.status(404).end(); } } catch (error) { + logger.error('Error getting shared messages:', error); res.status(500).json({ message: 'Error getting shared messages' }); } }, @@ -54,9 +55,7 @@ router.get('/', requireJwtAuth, async (req, res) => { sortDirection: ['asc', 'desc'].includes(req.query.sortDirection) ? req.query.sortDirection : 'desc', - search: req.query.search - ? decodeURIComponent(req.query.search.trim()) - : undefined, + search: req.query.search ? 
decodeURIComponent(req.query.search.trim()) : undefined, }; const result = await getSharedLinks( @@ -75,7 +74,7 @@ router.get('/', requireJwtAuth, async (req, res) => { hasNextPage: result.hasNextPage, }); } catch (error) { - console.error('Error getting shared links:', error); + logger.error('Error getting shared links:', error); res.status(500).json({ message: 'Error getting shared links', error: error.message, @@ -93,6 +92,7 @@ router.get('/link/:conversationId', requireJwtAuth, async (req, res) => { conversationId: req.params.conversationId, }); } catch (error) { + logger.error('Error getting shared link:', error); res.status(500).json({ message: 'Error getting shared link' }); } }); @@ -106,6 +106,7 @@ router.post('/:conversationId', requireJwtAuth, async (req, res) => { res.status(404).end(); } } catch (error) { + logger.error('Error creating shared link:', error); res.status(500).json({ message: 'Error creating shared link' }); } }); @@ -119,6 +120,7 @@ router.patch('/:shareId', requireJwtAuth, async (req, res) => { res.status(404).end(); } } catch (error) { + logger.error('Error updating shared link:', error); res.status(500).json({ message: 'Error updating shared link' }); } }); @@ -133,7 +135,8 @@ router.delete('/:shareId', requireJwtAuth, async (req, res) => { return res.status(200).json(result); } catch (error) { - return res.status(400).json({ message: error.message }); + logger.error('Error deleting shared link:', error); + return res.status(400).json({ message: 'Error deleting shared link' }); } }); diff --git a/api/server/routes/tags.js b/api/server/routes/tags.js index d3e27d3711..0a4ee5084c 100644 --- a/api/server/routes/tags.js +++ b/api/server/routes/tags.js @@ -1,18 +1,24 @@ const express = require('express'); +const { logger } = require('@librechat/data-schemas'); +const { generateCheckAccess } = require('@librechat/api'); const { PermissionTypes, Permissions } = require('librechat-data-provider'); const { - getConversationTags, + 
updateTagsForConversation, updateConversationTag, createConversationTag, deleteConversationTag, - updateTagsForConversation, + getConversationTags, } = require('~/models/ConversationTag'); -const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); -const { logger } = require('~/config'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); const router = express.Router(); -const checkBookmarkAccess = generateCheckAccess(PermissionTypes.BOOKMARKS, [Permissions.USE]); +const checkBookmarkAccess = generateCheckAccess({ + permissionType: PermissionTypes.BOOKMARKS, + permissions: [Permissions.USE], + getRoleByName, +}); router.use(requireJwtAuth); router.use(checkBookmarkAccess); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 740a77092a..b9555a752c 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -1,7 +1,15 @@ const jwt = require('jsonwebtoken'); const { nanoid } = require('nanoid'); const { tool } = require('@langchain/core/tools'); +const { logger } = require('@librechat/data-schemas'); const { GraphEvents, sleep } = require('@librechat/agents'); +const { + sendEvent, + encryptV2, + decryptV2, + logAxiosError, + refreshAccessToken, +} = require('@librechat/api'); const { Time, CacheKeys, @@ -12,13 +20,10 @@ const { isImageVisionTool, actionDomainSeparator, } = require('librechat-data-provider'); -const { refreshAccessToken } = require('~/server/services/TokenService'); -const { logger, getFlowStateManager, sendEvent } = require('~/config'); -const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); +const { findToken, updateToken, createToken } = require('~/models'); const { getActions, deleteActions } = require('~/models/Action'); const { deleteAssistant } = require('~/models/Assistant'); -const { findToken } = require('~/models/Token'); -const { logAxiosError } = require('~/utils'); 
+const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); const JWT_SECRET = process.env.JWT_SECRET; @@ -208,6 +213,7 @@ async function createActionTool({ userId: userId, client_url: metadata.auth.client_url, redirect_uri: `${process.env.DOMAIN_SERVER}/api/actions/${action_id}/oauth/callback`, + token_exchange_method: metadata.auth.token_exchange_method, /** Encrypted values */ encrypted_oauth_client_id: encrypted.oauth_client_id, encrypted_oauth_client_secret: encrypted.oauth_client_secret, @@ -256,14 +262,22 @@ async function createActionTool({ try { const refresh_token = await decryptV2(refreshTokenData.token); const refreshTokens = async () => - await refreshAccessToken({ - userId, - identifier, - refresh_token, - client_url: metadata.auth.client_url, - encrypted_oauth_client_id: encrypted.oauth_client_id, - encrypted_oauth_client_secret: encrypted.oauth_client_secret, - }); + await refreshAccessToken( + { + userId, + identifier, + refresh_token, + client_url: metadata.auth.client_url, + encrypted_oauth_client_id: encrypted.oauth_client_id, + token_exchange_method: metadata.auth.token_exchange_method, + encrypted_oauth_client_secret: encrypted.oauth_client_secret, + }, + { + findToken, + updateToken, + createToken, + }, + ); const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); const refreshData = await flowManager.createFlowWithHandler( diff --git a/api/server/services/AppService.interface.spec.js b/api/server/services/AppService.interface.spec.js index 0bf9d07dcc..90168d4778 100644 --- a/api/server/services/AppService.interface.spec.js +++ b/api/server/services/AppService.interface.spec.js @@ -1,5 +1,7 @@ -jest.mock('~/models/Role', () => ({ +jest.mock('~/models', () => ({ initializeRoles: jest.fn(), +})); +jest.mock('~/models/Role', () => ({ updateAccessPermissions: jest.fn(), getRoleByName: jest.fn(), updateRoleByName: jest.fn(), diff --git 
a/api/server/services/AppService.js b/api/server/services/AppService.js index 6a1cdfc695..6b7ff7417f 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -1,11 +1,12 @@ const { FileSources, loadOCRConfig, - processMCPEnv, EModelEndpoint, + loadMemoryConfig, getConfigDefaults, loadWebSearchConfig, } = require('librechat-data-provider'); +const { agentsConfigSetup } = require('@librechat/api'); const { checkHealth, checkConfig, @@ -24,10 +25,9 @@ const { azureConfigSetup } = require('./start/azureOpenAI'); const { processModelSpecs } = require('./start/modelSpecs'); const { initializeS3 } = require('./Files/S3/initialize'); const { loadAndFormatTools } = require('./ToolService'); -const { agentsConfigSetup } = require('./start/agents'); -const { initializeRoles } = require('~/models/Role'); const { isEnabled } = require('~/server/utils'); -const { getMCPManager } = require('~/config'); +const { initializeRoles } = require('~/models'); +const { setCachedTools } = require('./Config'); const paths = require('~/config/paths'); /** @@ -44,6 +44,7 @@ const AppService = async (app) => { const ocr = loadOCRConfig(config.ocr); const webSearch = loadWebSearchConfig(config.webSearch); checkWebSearchConfig(webSearch); + const memory = loadMemoryConfig(config.memory); const filteredTools = config.filteredTools; const includedTools = config.includedTools; const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy; @@ -74,11 +75,10 @@ const AppService = async (app) => { directory: paths.structuredTools, }); - if (config.mcpServers != null) { - const mcpManager = getMCPManager(); - await mcpManager.initializeMCP(config.mcpServers, processMCPEnv); - await mcpManager.mapAvailableTools(availableTools); - } + await setCachedTools(availableTools, { isGlobal: true }); + + // Store MCP config for later initialization + const mcpConfig = config.mcpServers || null; const socialLogins = config?.registration?.socialLogins ?? 
configDefaults?.registration?.socialLogins; @@ -88,20 +88,26 @@ const AppService = async (app) => { const defaultLocals = { ocr, paths, + memory, webSearch, fileStrategy, socialLogins, filteredTools, includedTools, - availableTools, imageOutputType, interfaceConfig, turnstileConfig, balance, + mcpConfig, }; + const agentsDefaults = agentsConfigSetup(config); + if (!Object.keys(config).length) { - app.locals = defaultLocals; + app.locals = { + ...defaultLocals, + [EModelEndpoint.agents]: agentsDefaults, + }; return; } @@ -136,9 +142,7 @@ const AppService = async (app) => { ); } - if (endpoints?.[EModelEndpoint.agents]) { - endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config); - } + endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config, agentsDefaults); const endpointKeys = [ EModelEndpoint.openAI, diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 0c7fac6ed3..7edccc2c0d 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -2,8 +2,10 @@ const { FileSources, EModelEndpoint, EImageOutputType, + AgentCapabilities, defaultSocialLogins, validateAzureGroups, + defaultAgentCapabilities, deprecatedAzureVariables, conflictingAzureVariables, } = require('librechat-data-provider'); @@ -24,10 +26,31 @@ jest.mock('./Config/loadCustomConfig', () => { jest.mock('./Files/Firebase/initialize', () => ({ initializeFirebase: jest.fn(), })); -jest.mock('~/models/Role', () => ({ +jest.mock('~/models', () => ({ initializeRoles: jest.fn(), +})); +jest.mock('~/models/Role', () => ({ updateAccessPermissions: jest.fn(), })); +jest.mock('./Config', () => ({ + setCachedTools: jest.fn(), + getCachedTools: jest.fn().mockResolvedValue({ + ExampleTool: { + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + 
required: ['param1'], + }, + }, + }, + }), +})); jest.mock('./ToolService', () => ({ loadAndFormatTools: jest.fn().mockReturnValue({ ExampleTool: { @@ -117,22 +140,9 @@ describe('AppService', () => { sidePanel: true, presets: true, }), + mcpConfig: null, turnstileConfig: mockedTurnstileConfig, modelSpecs: undefined, - availableTools: { - ExampleTool: { - type: 'function', - function: expect.objectContaining({ - description: 'Example tool function', - name: 'exampleFunction', - parameters: expect.objectContaining({ - type: 'object', - properties: expect.any(Object), - required: expect.arrayContaining(['param1']), - }), - }), - }, - }, paths: expect.anything(), ocr: expect.anything(), imageOutputType: expect.any(String), @@ -149,6 +159,11 @@ describe('AppService', () => { safeSearch: 1, serperApiKey: '${SERPER_API_KEY}', }, + memory: undefined, + agents: { + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }, }); }); @@ -214,14 +229,41 @@ describe('AppService', () => { it('should load and format tools accurately with defined structure', async () => { const { loadAndFormatTools } = require('./ToolService'); + const { setCachedTools, getCachedTools } = require('./Config'); + await AppService(app); expect(loadAndFormatTools).toHaveBeenCalledWith({ + adminFilter: undefined, + adminIncluded: undefined, directory: expect.anything(), }); - expect(app.locals.availableTools.ExampleTool).toBeDefined(); - expect(app.locals.availableTools.ExampleTool).toEqual({ + // Verify setCachedTools was called with the tools + expect(setCachedTools).toHaveBeenCalledWith( + { + ExampleTool: { + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + required: ['param1'], + }, + }, + }, + }, + { isGlobal: true }, + ); + + // Verify we can retrieve the tools from cache + const 
cachedTools = await getCachedTools({ includeGlobal: true }); + expect(cachedTools.ExampleTool).toBeDefined(); + expect(cachedTools.ExampleTool).toEqual({ type: 'function', function: { description: 'Example tool function', @@ -266,6 +308,71 @@ describe('AppService', () => { ); }); + it('should correctly configure Agents endpoint based on custom config', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.agents]: { + disableBuilder: true, + recursionLimit: 10, + maxRecursionLimit: 20, + allowedProviders: ['openai', 'anthropic'], + capabilities: [AgentCapabilities.tools, AgentCapabilities.actions], + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: true, + recursionLimit: 10, + maxRecursionLimit: 20, + allowedProviders: expect.arrayContaining(['openai', 'anthropic']), + capabilities: expect.arrayContaining([AgentCapabilities.tools, AgentCapabilities.actions]), + }), + ); + }); + + it('should configure Agents endpoint with defaults when no config is provided', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({})); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }), + ); + }); + + it('should configure Agents endpoint with defaults when endpoints exist but agents is not defined', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.openAI]: { + titleConvo: true, + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + 
expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }), + ); + }); + it('should correctly configure minimum Azure OpenAI Assistant values', async () => { const assistantGroups = [azureGroups[0], { ...azureGroups[1], assistants: true }]; require('./Config/loadCustomConfig').mockImplementationOnce(() => @@ -461,7 +568,6 @@ describe('AppService updating app.locals and issuing warnings', () => { expect(app.locals).toBeDefined(); expect(app.locals.paths).toBeDefined(); - expect(app.locals.availableTools).toBeDefined(); expect(app.locals.fileStrategy).toEqual(FileSources.local); expect(app.locals.socialLogins).toEqual(defaultSocialLogins); expect(app.locals.balance).toEqual( @@ -494,7 +600,6 @@ describe('AppService updating app.locals and issuing warnings', () => { expect(app.locals).toBeDefined(); expect(app.locals.paths).toBeDefined(); - expect(app.locals.availableTools).toBeDefined(); expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy); expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins); expect(app.locals.balance).toEqual(customConfig.balance); diff --git a/api/server/services/AssistantService.js b/api/server/services/AssistantService.js index 2db0a56b6b..5354b2e33a 100644 --- a/api/server/services/AssistantService.js +++ b/api/server/services/AssistantService.js @@ -1,4 +1,7 @@ const { klona } = require('klona'); +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { StepTypes, RunStatus, @@ -11,11 +14,10 @@ const { } = require('librechat-data-provider'); const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { processRequiredActions } = require('~/server/services/ToolService'); -const { createOnProgress, sendMessage, sleep } = require('~/server/utils'); const { 
RunManager, waitForRun } = require('~/server/services/Runs'); const { processMessages } = require('~/server/services/Threads'); +const { createOnProgress } = require('~/server/utils'); const { TextStream } = require('~/app/clients'); -const { logger } = require('~/config'); /** * Sorts, processes, and flattens messages to a single string. @@ -64,7 +66,7 @@ async function createOnTextProgress({ }; logger.debug('Content data:', contentData); - sendMessage(openai.res, contentData); + sendEvent(openai.res, contentData); }; } diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index ac13172128..8c7cbf7d92 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -1,28 +1,29 @@ const bcrypt = require('bcryptjs'); +const jwt = require('jsonwebtoken'); const { webcrypto } = require('node:crypto'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { SystemRoles, errorsToString } = require('librechat-data-provider'); const { findUser, - countUsers, createUser, updateUser, - getUserById, - generateToken, - deleteUserById, -} = require('~/models/userMethods'); -const { - createToken, findToken, - deleteTokens, + countUsers, + getUserById, findSession, + createToken, + deleteTokens, deleteSession, createSession, + generateToken, + deleteUserById, generateRefreshToken, } = require('~/models'); -const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils'); const { isEmailDomainAllowed } = require('~/server/services/domains'); +const { checkEmailConfig, sendEmail } = require('~/server/utils'); +const { getBalanceConfig } = require('~/server/services/Config'); const { registerSchema } = require('~/strategies/validators'); -const { logger } = require('~/config'); const domains = { client: process.env.DOMAIN_CLIENT, @@ -146,6 +147,7 @@ const verifyEmail = async (req) => { } const updatedUser = await updateUser(emailVerificationData.userId, { 
emailVerified: true }); + if (!updatedUser) { logger.warn(`[verifyEmail] [User update failed] [Email: ${decodedEmail}]`); return new Error('Failed to update user verification status'); @@ -155,6 +157,7 @@ const verifyEmail = async (req) => { logger.info(`[verifyEmail] Email verification successful [Email: ${decodedEmail}]`); return { message: 'Email verification was successful', status: 'success' }; }; + /** * Register a new user. * @param {MongoUser} user @@ -216,7 +219,9 @@ const registerUser = async (user, additionalData = {}) => { const emailEnabled = checkEmailConfig(); const disableTTL = isEnabled(process.env.ALLOW_UNVERIFIED_EMAIL_LOGIN); - const newUser = await createUser(newUserData, disableTTL, true); + const balanceConfig = await getBalanceConfig(); + + const newUser = await createUser(newUserData, balanceConfig, disableTTL, true); newUserId = newUser._id; if (emailEnabled && !newUser.emailVerified) { await sendVerificationEmail({ @@ -389,6 +394,7 @@ const setAuthTokens = async (userId, res, sessionId = null) => { throw error; } }; + /** * @function setOpenIDAuthTokens * Set OpenID Authentication Tokens @@ -405,7 +411,9 @@ const setOpenIDAuthTokens = (tokenset, res) => { return; } const { REFRESH_TOKEN_EXPIRY } = process.env ?? {}; - const expiryInMilliseconds = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default + const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY + ? 
eval(REFRESH_TOKEN_EXPIRY) + : 1000 * 60 * 60 * 24 * 7; // 7 days default const expirationDate = new Date(Date.now() + expiryInMilliseconds); if (tokenset == null) { logger.error('[setOpenIDAuthTokens] No tokenset found in request'); @@ -492,6 +500,18 @@ const resendVerificationEmail = async (req) => { }; } }; +/** + * Generate a short-lived JWT token + * @param {String} userId - The ID of the user + * @param {String} [expireIn='5m'] - The expiration time for the token (default is 5 minutes) + * @returns {String} - The generated JWT token + */ +const generateShortLivedToken = (userId, expireIn = '5m') => { + return jwt.sign({ id: userId }, process.env.JWT_SECRET, { + expiresIn: expireIn, + algorithm: 'HS256', + }); +}; module.exports = { logoutUser, @@ -499,7 +519,8 @@ module.exports = { registerUser, setAuthTokens, resetPassword, + setOpenIDAuthTokens, requestPasswordReset, resendVerificationEmail, - setOpenIDAuthTokens, + generateShortLivedToken, }; diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index 1f38b70a62..d8277dd67f 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -1,5 +1,6 @@ +const { isUserProvided } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); -const { isUserProvided, generateConfig } = require('~/server/utils'); +const { generateConfig } = require('~/server/utils/handleText'); const { OPENAI_API_KEY: openAIApiKey, diff --git a/api/server/services/Config/getCachedTools.js b/api/server/services/Config/getCachedTools.js new file mode 100644 index 0000000000..b3a4f0c869 --- /dev/null +++ b/api/server/services/Config/getCachedTools.js @@ -0,0 +1,258 @@ +const { CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); + +/** + * Cache key generators for different tool access patterns + * These will support future permission-based caching + 
*/ +const ToolCacheKeys = { + /** Global tools available to all users */ + GLOBAL: 'tools:global', + /** Tools available to a specific user */ + USER: (userId) => `tools:user:${userId}`, + /** Tools available to a specific role */ + ROLE: (roleId) => `tools:role:${roleId}`, + /** Tools available to a specific group */ + GROUP: (groupId) => `tools:group:${groupId}`, + /** Combined effective tools for a user (computed from all sources) */ + EFFECTIVE: (userId) => `tools:effective:${userId}`, +}; + +/** + * Retrieves available tools from cache + * @function getCachedTools + * @param {Object} options - Options for retrieving tools + * @param {string} [options.userId] - User ID for user-specific tools + * @param {string[]} [options.roleIds] - Role IDs for role-based tools + * @param {string[]} [options.groupIds] - Group IDs for group-based tools + * @param {boolean} [options.includeGlobal=true] - Whether to include global tools + * @returns {Promise} The available tools object or null if not cached + */ +async function getCachedTools(options = {}) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const { userId, roleIds = [], groupIds = [], includeGlobal = true } = options; + + // For now, return global tools (current behavior) + // This will be expanded to merge tools from different sources + if (!userId && includeGlobal) { + return await cache.get(ToolCacheKeys.GLOBAL); + } + + // Future implementation will merge tools from multiple sources + // based on user permissions, roles, and groups + if (userId) { + // Check if we have pre-computed effective tools for this user + const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId)); + if (effectiveTools) { + return effectiveTools; + } + + // Otherwise, compute from individual sources + const toolSources = []; + + if (includeGlobal) { + const globalTools = await cache.get(ToolCacheKeys.GLOBAL); + if (globalTools) { + toolSources.push(globalTools); + } + } + + // User-specific tools + const userTools = 
await cache.get(ToolCacheKeys.USER(userId)); + if (userTools) { + toolSources.push(userTools); + } + + // Role-based tools + for (const roleId of roleIds) { + const roleTools = await cache.get(ToolCacheKeys.ROLE(roleId)); + if (roleTools) { + toolSources.push(roleTools); + } + } + + // Group-based tools + for (const groupId of groupIds) { + const groupTools = await cache.get(ToolCacheKeys.GROUP(groupId)); + if (groupTools) { + toolSources.push(groupTools); + } + } + + // Merge all tool sources (for now, simple merge - future will handle conflicts) + if (toolSources.length > 0) { + return mergeToolSources(toolSources); + } + } + + return null; +} + +/** + * Sets available tools in cache + * @function setCachedTools + * @param {Object} tools - The tools object to cache + * @param {Object} options - Options for caching tools + * @param {string} [options.userId] - User ID for user-specific tools + * @param {string} [options.roleId] - Role ID for role-based tools + * @param {string} [options.groupId] - Group ID for group-based tools + * @param {boolean} [options.isGlobal=false] - Whether these are global tools + * @param {number} [options.ttl] - Time to live in milliseconds + * @returns {Promise} Whether the operation was successful + */ +async function setCachedTools(tools, options = {}) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const { userId, roleId, groupId, isGlobal = false, ttl } = options; + + let cacheKey; + if (isGlobal || (!userId && !roleId && !groupId)) { + cacheKey = ToolCacheKeys.GLOBAL; + } else if (userId) { + cacheKey = ToolCacheKeys.USER(userId); + } else if (roleId) { + cacheKey = ToolCacheKeys.ROLE(roleId); + } else if (groupId) { + cacheKey = ToolCacheKeys.GROUP(groupId); + } + + if (!cacheKey) { + throw new Error('Invalid cache key options provided'); + } + + return await cache.set(cacheKey, tools, ttl); +} + +/** + * Invalidates cached tools + * @function invalidateCachedTools + * @param {Object} options - Options for invalidating 
tools + * @param {string} [options.userId] - User ID to invalidate + * @param {string} [options.roleId] - Role ID to invalidate + * @param {string} [options.groupId] - Group ID to invalidate + * @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools + * @param {boolean} [options.invalidateEffective=true] - Whether to invalidate effective tools + * @returns {Promise} + */ +async function invalidateCachedTools(options = {}) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const { userId, roleId, groupId, invalidateGlobal = false, invalidateEffective = true } = options; + + const keysToDelete = []; + + if (invalidateGlobal) { + keysToDelete.push(ToolCacheKeys.GLOBAL); + } + + if (userId) { + keysToDelete.push(ToolCacheKeys.USER(userId)); + if (invalidateEffective) { + keysToDelete.push(ToolCacheKeys.EFFECTIVE(userId)); + } + } + + if (roleId) { + keysToDelete.push(ToolCacheKeys.ROLE(roleId)); + // TODO: In future, invalidate all users with this role + } + + if (groupId) { + keysToDelete.push(ToolCacheKeys.GROUP(groupId)); + // TODO: In future, invalidate all users in this group + } + + await Promise.all(keysToDelete.map((key) => cache.delete(key))); +} + +/** + * Computes and caches effective tools for a user + * @function computeEffectiveTools + * @param {string} userId - The user ID + * @param {Object} context - Context containing user's roles and groups + * @param {string[]} [context.roleIds=[]] - User's role IDs + * @param {string[]} [context.groupIds=[]] - User's group IDs + * @param {number} [ttl] - Time to live for the computed result + * @returns {Promise} The computed effective tools + */ +async function computeEffectiveTools(userId, context = {}, ttl) { + const { roleIds = [], groupIds = [] } = context; + + // Get all tool sources + const tools = await getCachedTools({ + userId, + roleIds, + groupIds, + includeGlobal: true, + }); + + if (tools) { + // Cache the computed result + const cache = 
getLogStores(CacheKeys.CONFIG_STORE); + await cache.set(ToolCacheKeys.EFFECTIVE(userId), tools, ttl); + } + + return tools; +} + +/** + * Merges multiple tool sources into a single tools object + * @function mergeToolSources + * @param {Object[]} sources - Array of tool objects to merge + * @returns {Object} Merged tools object + */ +function mergeToolSources(sources) { + // For now, simple merge that combines all tools + // Future implementation will handle: + // - Permission precedence (deny > allow) + // - Tool property conflicts + // - Metadata merging + const merged = {}; + + for (const source of sources) { + if (!source || typeof source !== 'object') { + continue; + } + + for (const [toolId, toolConfig] of Object.entries(source)) { + // Simple last-write-wins for now + // Future: merge based on permission levels + merged[toolId] = toolConfig; + } + } + + return merged; +} + +/** + * Middleware-friendly function to get tools for a request + * @function getToolsForRequest + * @param {Object} req - Express request object + * @returns {Promise} Available tools for the request + */ +async function getToolsForRequest(req) { + const userId = req.user?.id; + + // For now, return global tools if no user + if (!userId) { + return getCachedTools({ includeGlobal: true }); + } + + // Future: Extract roles and groups from req.user + const roleIds = req.user?.roles || []; + const groupIds = req.user?.groups || []; + + return getCachedTools({ + userId, + roleIds, + groupIds, + includeGlobal: true, + }); +} + +module.exports = { + ToolCacheKeys, + getCachedTools, + setCachedTools, + getToolsForRequest, + invalidateCachedTools, + computeEffectiveTools, +}; diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js index 74828789fc..f3fb6f26b4 100644 --- a/api/server/services/Config/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -1,6 +1,10 @@ +const { logger } = require('@librechat/data-schemas'); +const 
{ getUserMCPAuthMap } = require('@librechat/api'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { normalizeEndpointName, isEnabled } = require('~/server/utils'); const loadCustomConfig = require('./loadCustomConfig'); +const { getCachedTools } = require('./getCachedTools'); +const { findPluginAuthsByKeys } = require('~/models'); const getLogStores = require('~/cache/getLogStores'); /** @@ -36,6 +40,7 @@ async function getBalanceConfig() { /** * * @param {string | EModelEndpoint} endpoint + * @returns {Promise} */ const getCustomEndpointConfig = async (endpoint) => { const customConfig = await getCustomConfig(); @@ -50,4 +55,46 @@ const getCustomEndpointConfig = async (endpoint) => { ); }; -module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig }; +async function createGetMCPAuthMap() { + const customConfig = await getCustomConfig(); + const mcpServers = customConfig?.mcpServers; + const hasCustomUserVars = Object.values(mcpServers ?? 
{}).some((server) => server.customUserVars); + if (!hasCustomUserVars) { + return; + } + + /** + * @param {Object} params + * @param {GenericTool[]} [params.tools] + * @param {string} params.userId + * @returns {Promise> | undefined>} + */ + return async function ({ tools, userId }) { + try { + if (!tools || tools.length === 0) { + return; + } + const appTools = await getCachedTools({ + userId, + }); + return await getUserMCPAuthMap({ + tools, + userId, + appTools, + findPluginAuthsByKeys, + }); + } catch (err) { + logger.error( + `[api/server/services/Config/getCustomConfig.js #createGetMCPAuthMap] Error getting custom user vars for agent`, + err, + ); + } + }; +} + +module.exports = { + getCustomConfig, + getBalanceConfig, + createGetMCPAuthMap, + getCustomEndpointConfig, +}; diff --git a/api/server/services/Config/index.js b/api/server/services/Config/index.js index 9d668da958..ad25e57998 100644 --- a/api/server/services/Config/index.js +++ b/api/server/services/Config/index.js @@ -1,4 +1,5 @@ const { config } = require('./EndpointService'); +const getCachedTools = require('./getCachedTools'); const getCustomConfig = require('./getCustomConfig'); const loadCustomConfig = require('./loadCustomConfig'); const loadConfigModels = require('./loadConfigModels'); @@ -14,6 +15,7 @@ module.exports = { loadDefaultModels, loadOverrideConfig, loadAsyncEndpoints, + ...getCachedTools, ...getCustomConfig, ...getEndpointsConfig, }; diff --git a/api/server/services/Config/loadAsyncEndpoints.js b/api/server/services/Config/loadAsyncEndpoints.js index 0282146cd1..56f693c779 100644 --- a/api/server/services/Config/loadAsyncEndpoints.js +++ b/api/server/services/Config/loadAsyncEndpoints.js @@ -1,3 +1,5 @@ +const path = require('path'); +const { loadServiceKey } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); const { isUserProvided } = require('~/server/utils'); const { config } = require('./EndpointService'); @@ -11,9 +13,13 @@ const { openAIApiKey, 
azureOpenAIApiKey, useAzurePlugins, userProvidedOpenAI, go async function loadAsyncEndpoints(req) { let i = 0; let serviceKey, googleUserProvides; + const serviceKeyPath = + process.env.GOOGLE_SERVICE_KEY_FILE_PATH || + path.join(__dirname, '../../..', 'data', 'auth.json'); + try { - serviceKey = require('~/data/auth.json'); - } catch (e) { + serviceKey = await loadServiceKey(serviceKeyPath); + } catch { if (i === 0) { i++; } @@ -32,14 +38,14 @@ async function loadAsyncEndpoints(req) { const gptPlugins = useAzure || openAIApiKey || azureOpenAIApiKey ? { - availableAgents: ['classic', 'functions'], - userProvide: useAzure ? false : userProvidedOpenAI, - userProvideURL: useAzure - ? false - : config[EModelEndpoint.openAI]?.userProvideURL || + availableAgents: ['classic', 'functions'], + userProvide: useAzure ? false : userProvidedOpenAI, + userProvideURL: useAzure + ? false + : config[EModelEndpoint.openAI]?.userProvideURL || config[EModelEndpoint.azureOpenAI]?.userProvideURL, - azure: useAzurePlugins || useAzure, - } + azure: useAzurePlugins || useAzure, + } : false; return { google, gptPlugins }; diff --git a/api/server/services/Config/loadCustomConfig.js b/api/server/services/Config/loadCustomConfig.js index 18f3a44748..393281daf2 100644 --- a/api/server/services/Config/loadCustomConfig.js +++ b/api/server/services/Config/loadCustomConfig.js @@ -1,18 +1,18 @@ const path = require('path'); -const { - CacheKeys, - configSchema, - EImageOutputType, - validateSettingDefinitions, - agentParamSettings, - paramSettings, -} = require('librechat-data-provider'); -const getLogStores = require('~/cache/getLogStores'); -const loadYaml = require('~/utils/loadYaml'); -const { logger } = require('~/config'); const axios = require('axios'); const yaml = require('js-yaml'); const keyBy = require('lodash/keyBy'); +const { loadYaml } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const { + CacheKeys, + configSchema, + paramSettings, + 
EImageOutputType, + agentParamSettings, + validateSettingDefinitions, +} = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); const projectRoot = path.resolve(__dirname, '..', '..', '..', '..'); const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml'); diff --git a/api/server/services/Config/loadCustomConfig.spec.js b/api/server/services/Config/loadCustomConfig.spec.js index ed698e57f1..9b905181c5 100644 --- a/api/server/services/Config/loadCustomConfig.spec.js +++ b/api/server/services/Config/loadCustomConfig.spec.js @@ -1,6 +1,9 @@ jest.mock('axios'); jest.mock('~/cache/getLogStores'); -jest.mock('~/utils/loadYaml'); +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + loadYaml: jest.fn(), +})); jest.mock('librechat-data-provider', () => { const actual = jest.requireActual('librechat-data-provider'); return { @@ -30,11 +33,22 @@ jest.mock('librechat-data-provider', () => { }; }); +jest.mock('@librechat/data-schemas', () => { + return { + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + }; +}); + const axios = require('axios'); +const { loadYaml } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const loadCustomConfig = require('./loadCustomConfig'); const getLogStores = require('~/cache/getLogStores'); -const loadYaml = require('~/utils/loadYaml'); -const { logger } = require('~/config'); describe('loadCustomConfig', () => { const mockSet = jest.fn(); diff --git a/api/server/services/Endpoints/agents/agent.js b/api/server/services/Endpoints/agents/agent.js new file mode 100644 index 0000000000..00c6baada3 --- /dev/null +++ b/api/server/services/Endpoints/agents/agent.js @@ -0,0 +1,189 @@ +const { Providers } = require('@librechat/agents'); +const { + primeResources, + extractLibreChatParams, + optionalChainWithEmptyCheck, +} = require('@librechat/api'); +const { + ErrorTypes, + EModelEndpoint, + 
EToolResources, + replaceSpecialVars, + providerEndpointMap, +} = require('librechat-data-provider'); +const { getProviderConfig } = require('~/server/services/Endpoints'); +const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); +const { processFiles } = require('~/server/services/Files/process'); +const { getFiles, getToolFilesByIds } = require('~/models/File'); +const { getConvoFiles } = require('~/models/Conversation'); +const { getModelMaxTokens } = require('~/utils'); + +/** + * @param {object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {Agent} params.agent + * @param {string | null} [params.conversationId] + * @param {Array} [params.requestFiles] + * @param {typeof import('~/server/services/ToolService').loadAgentTools | undefined} [params.loadTools] + * @param {TEndpointOption} [params.endpointOption] + * @param {Set} [params.allowedProviders] + * @param {boolean} [params.isInitialAgent] + * @returns {Promise, toolContextMap: Record, maxContextTokens: number }>} + */ +const initializeAgent = async ({ + req, + res, + agent, + loadTools, + requestFiles, + conversationId, + endpointOption, + allowedProviders, + isInitialAgent = false, +}) => { + if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) { + throw new Error( + `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`, + ); + } + let currentFiles; + + const _modelOptions = structuredClone( + Object.assign( + { model: agent.model }, + agent.model_parameters ?? { model: agent.model }, + isInitialAgent === true ? endpointOption?.model_parameters : {}, + ), + ); + + const { resendFiles, maxContextTokens, modelOptions } = extractLibreChatParams(_modelOptions); + + if (isInitialAgent && conversationId != null && resendFiles) { + const fileIds = (await getConvoFiles(conversationId)) ?? 
[]; + /** @type {Set} */ + const toolResourceSet = new Set(); + for (const tool of agent.tools) { + if (EToolResources[tool]) { + toolResourceSet.add(EToolResources[tool]); + } + } + const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet); + if (requestFiles.length || toolFiles.length) { + currentFiles = await processFiles(requestFiles.concat(toolFiles)); + } + } else if (isInitialAgent && requestFiles.length) { + currentFiles = await processFiles(requestFiles); + } + + const { attachments, tool_resources } = await primeResources({ + req, + getFiles, + attachments: currentFiles, + tool_resources: agent.tool_resources, + requestFileSet: new Set(requestFiles?.map((file) => file.file_id)), + }); + + const provider = agent.provider; + const { tools: structuredTools, toolContextMap } = + (await loadTools?.({ + req, + res, + provider, + agentId: agent.id, + tools: agent.tools, + model: agent.model, + tool_resources, + })) ?? {}; + + agent.endpoint = provider; + const { getOptions, overrideProvider } = await getProviderConfig(provider); + if (overrideProvider) { + agent.provider = overrideProvider; + } + + const _endpointOption = + isInitialAgent === true + ? Object.assign({}, endpointOption, { model_parameters: modelOptions }) + : { model_parameters: modelOptions }; + + const options = await getOptions({ + req, + res, + optionsOnly: true, + overrideEndpoint: provider, + overrideModel: agent.model, + endpointOption: _endpointOption, + }); + + const tokensModel = + agent.provider === EModelEndpoint.azureOpenAI ? 
agent.model : modelOptions.model; + const maxTokens = optionalChainWithEmptyCheck( + modelOptions.maxOutputTokens, + modelOptions.maxTokens, + 0, + ); + const agentMaxContextTokens = optionalChainWithEmptyCheck( + maxContextTokens, + getModelMaxTokens(tokensModel, providerEndpointMap[provider]), + 4096, + ); + + if ( + agent.endpoint === EModelEndpoint.azureOpenAI && + options.llmConfig?.azureOpenAIApiInstanceName == null + ) { + agent.provider = Providers.OPENAI; + } + + if (options.provider != null) { + agent.provider = options.provider; + } + + /** @type {import('@librechat/agents').GenericTool[]} */ + let tools = options.tools?.length ? options.tools : structuredTools; + if ( + (agent.provider === Providers.GOOGLE || agent.provider === Providers.VERTEXAI) && + options.tools?.length && + structuredTools?.length + ) { + throw new Error(`{ "type": "${ErrorTypes.GOOGLE_TOOL_CONFLICT}"}`); + } else if ( + (agent.provider === Providers.OPENAI || agent.provider === Providers.AZURE) && + options.tools?.length && + structuredTools?.length + ) { + tools = structuredTools.concat(options.tools); + } + + /** @type {import('@librechat/agents').ClientOptions} */ + agent.model_parameters = { ...options.llmConfig }; + if (options.configOptions) { + agent.model_parameters.configuration = options.configOptions; + } + + if (agent.instructions && agent.instructions !== '') { + agent.instructions = replaceSpecialVars({ + text: agent.instructions, + user: req.user, + }); + } + + if (typeof agent.artifacts === 'string' && agent.artifacts !== '') { + agent.additional_instructions = generateArtifactsPrompt({ + endpoint: agent.provider, + artifacts: agent.artifacts, + }); + } + + return { + ...agent, + attachments, + resendFiles, + toolContextMap, + tools, + maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9, + }; +}; + +module.exports = { initializeAgent }; diff --git a/api/server/services/Endpoints/agents/build.js b/api/server/services/Endpoints/agents/build.js index 
77ebbc58dc..143dde9459 100644 --- a/api/server/services/Endpoints/agents/build.js +++ b/api/server/services/Endpoints/agents/build.js @@ -1,10 +1,9 @@ -const { isAgentsEndpoint, Constants } = require('librechat-data-provider'); +const { isAgentsEndpoint, removeNullishValues, Constants } = require('librechat-data-provider'); const { loadAgent } = require('~/models/Agent'); const { logger } = require('~/config'); const buildOptions = (req, endpoint, parsedBody, endpointType) => { - const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } = - parsedBody; + const { spec, iconURL, agent_id, instructions, ...model_parameters } = parsedBody; const agentPromise = loadAgent({ req, agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID, @@ -15,19 +14,16 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => { return undefined; }); - const endpointOption = { + return removeNullishValues({ spec, iconURL, endpoint, agent_id, endpointType, instructions, - maxContextTokens, model_parameters, agent: agentPromise, - }; - - return endpointOption; + }); }; module.exports = { buildOptions }; diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index c9e363e815..94af3bdd3b 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -1,294 +1,47 @@ -const { createContentAggregator, Providers } = require('@librechat/agents'); +const { logger } = require('@librechat/data-schemas'); +const { createContentAggregator } = require('@librechat/agents'); const { Constants, - ErrorTypes, EModelEndpoint, - EToolResources, + isAgentsEndpoint, getResponseSender, - AgentCapabilities, - replaceSpecialVars, - providerEndpointMap, } = require('librechat-data-provider'); const { - getDefaultHandlers, createToolEndCallback, + getDefaultHandlers, } = require('~/server/controllers/agents/callbacks'); -const 
initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); -const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); -const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const initCustom = require('~/server/services/Endpoints/custom/initialize'); -const initGoogle = require('~/server/services/Endpoints/google/initialize'); -const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); +const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { getCustomEndpointConfig } = require('~/server/services/Config'); -const { processFiles } = require('~/server/services/Files/process'); const { loadAgentTools } = require('~/server/services/ToolService'); const AgentClient = require('~/server/controllers/agents/client'); -const { getConvoFiles } = require('~/models/Conversation'); -const { getToolFilesByIds } = require('~/models/File'); -const { getModelMaxTokens } = require('~/utils'); const { getAgent } = require('~/models/Agent'); -const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); -const providerConfigMap = { - [Providers.XAI]: initCustom, - [Providers.OLLAMA]: initCustom, - [Providers.DEEPSEEK]: initCustom, - [Providers.OPENROUTER]: initCustom, - [EModelEndpoint.openAI]: initOpenAI, - [EModelEndpoint.google]: initGoogle, - [EModelEndpoint.azureOpenAI]: initOpenAI, - [EModelEndpoint.anthropic]: initAnthropic, - [EModelEndpoint.bedrock]: getBedrockOptions, -}; - -/** - * @param {Object} params - * @param {ServerRequest} params.req - * @param {Promise> | undefined} [params.attachments] - * @param {Set} params.requestFileSet - * @param {AgentToolResources | undefined} [params.tool_resources] - * @returns {Promise<{ attachments: Array | undefined, tool_resources: AgentToolResources | undefined }>} - */ -const primeResources = async ({ - req, - attachments: _attachments, - tool_resources: _tool_resources, - requestFileSet, -}) => 
{ - try { - /** @type {Array | undefined} */ - let attachments; - const tool_resources = _tool_resources ?? {}; - const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes( - AgentCapabilities.ocr, - ); - if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) { - const context = await getFiles( - { - file_id: { $in: tool_resources.ocr.file_ids }, - }, - {}, - {}, - ); - attachments = (attachments ?? []).concat(context); +function createToolLoader() { + /** + * @param {object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {string} params.agentId + * @param {string[]} params.tools + * @param {string} params.provider + * @param {string} params.model + * @param {AgentToolResources} params.tool_resources + * @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record } | undefined>} + */ + return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) { + const agent = { id: agentId, tools, provider, model }; + try { + return await loadAgentTools({ + req, + res, + agent, + tool_resources, + }); + } catch (error) { + logger.error('Error loading tools for agent ' + agentId, error); } - if (!_attachments) { - return { attachments, tool_resources }; - } - /** @type {Array | undefined} */ - const files = await _attachments; - if (!attachments) { - /** @type {Array} */ - attachments = []; - } - - for (const file of files) { - if (!file) { - continue; - } - if (file.metadata?.fileIdentifier) { - const execute_code = tool_resources[EToolResources.execute_code] ?? {}; - if (!execute_code.files) { - tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] }; - } - tool_resources[EToolResources.execute_code].files.push(file); - } else if (file.embedded === true) { - const file_search = tool_resources[EToolResources.file_search] ?? 
{}; - if (!file_search.files) { - tool_resources[EToolResources.file_search] = { ...file_search, files: [] }; - } - tool_resources[EToolResources.file_search].files.push(file); - } else if ( - requestFileSet.has(file.file_id) && - file.type.startsWith('image') && - file.height && - file.width - ) { - const image_edit = tool_resources[EToolResources.image_edit] ?? {}; - if (!image_edit.files) { - tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] }; - } - tool_resources[EToolResources.image_edit].files.push(file); - } - - attachments.push(file); - } - return { attachments, tool_resources }; - } catch (error) { - logger.error('Error priming resources', error); - return { attachments: _attachments, tool_resources: _tool_resources }; - } -}; - -/** - * @param {...string | number} values - * @returns {string | number | undefined} - */ -function optionalChainWithEmptyCheck(...values) { - for (const value of values) { - if (value !== undefined && value !== null && value !== '') { - return value; - } - } - return values[values.length - 1]; -} - -/** - * @param {object} params - * @param {ServerRequest} params.req - * @param {ServerResponse} params.res - * @param {Agent} params.agent - * @param {Set} [params.allowedProviders] - * @param {object} [params.endpointOption] - * @param {boolean} [params.isInitialAgent] - * @returns {Promise} - */ -const initializeAgentOptions = async ({ - req, - res, - agent, - endpointOption, - allowedProviders, - isInitialAgent = false, -}) => { - if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) { - throw new Error( - `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`, - ); - } - let currentFiles; - /** @type {Array} */ - const requestFiles = req.body.files ?? []; - if ( - isInitialAgent && - req.body.conversationId != null && - (agent.model_parameters?.resendFiles ?? true) === true - ) { - const fileIds = (await getConvoFiles(req.body.conversationId)) ?? 
[]; - /** @type {Set} */ - const toolResourceSet = new Set(); - for (const tool of agent.tools) { - if (EToolResources[tool]) { - toolResourceSet.add(EToolResources[tool]); - } - } - const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet); - if (requestFiles.length || toolFiles.length) { - currentFiles = await processFiles(requestFiles.concat(toolFiles)); - } - } else if (isInitialAgent && requestFiles.length) { - currentFiles = await processFiles(requestFiles); - } - - const { attachments, tool_resources } = await primeResources({ - req, - attachments: currentFiles, - tool_resources: agent.tool_resources, - requestFileSet: new Set(requestFiles.map((file) => file.file_id)), - }); - - const provider = agent.provider; - const { tools, toolContextMap } = await loadAgentTools({ - req, - res, - agent: { - id: agent.id, - tools: agent.tools, - provider, - model: agent.model, - }, - tool_resources, - }); - - agent.endpoint = provider; - let getOptions = providerConfigMap[provider]; - if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { - agent.provider = provider.toLowerCase(); - getOptions = providerConfigMap[agent.provider]; - } else if (!getOptions) { - const customEndpointConfig = await getCustomEndpointConfig(provider); - if (!customEndpointConfig) { - throw new Error(`Provider ${provider} not supported`); - } - getOptions = initCustom; - agent.provider = Providers.OPENAI; - } - const model_parameters = Object.assign( - {}, - agent.model_parameters ?? { model: agent.model }, - isInitialAgent === true ? endpointOption?.model_parameters : {}, - ); - const _endpointOption = - isInitialAgent === true - ? 
Object.assign({}, endpointOption, { model_parameters }) - : { model_parameters }; - - const options = await getOptions({ - req, - res, - optionsOnly: true, - overrideEndpoint: provider, - overrideModel: agent.model, - endpointOption: _endpointOption, - }); - - if ( - agent.endpoint === EModelEndpoint.azureOpenAI && - options.llmConfig?.azureOpenAIApiInstanceName == null - ) { - agent.provider = Providers.OPENAI; - } - - if (options.provider != null) { - agent.provider = options.provider; - } - - /** @type {import('@librechat/agents').ClientOptions} */ - agent.model_parameters = Object.assign(model_parameters, options.llmConfig); - if (options.configOptions) { - agent.model_parameters.configuration = options.configOptions; - } - - if (!agent.model_parameters.model) { - agent.model_parameters.model = agent.model; - } - - if (agent.instructions && agent.instructions !== '') { - agent.instructions = replaceSpecialVars({ - text: agent.instructions, - user: req.user, - }); - } - - if (typeof agent.artifacts === 'string' && agent.artifacts !== '') { - agent.additional_instructions = generateArtifactsPrompt({ - endpoint: agent.provider, - artifacts: agent.artifacts, - }); - } - - const tokensModel = - agent.provider === EModelEndpoint.azureOpenAI ? 
agent.model : agent.model_parameters.model; - const maxTokens = optionalChainWithEmptyCheck( - agent.model_parameters.maxOutputTokens, - agent.model_parameters.maxTokens, - 0, - ); - const maxContextTokens = optionalChainWithEmptyCheck( - agent.model_parameters.maxContextTokens, - agent.max_context_tokens, - getModelMaxTokens(tokensModel, providerEndpointMap[provider]), - 4096, - ); - return { - ...agent, - tools, - attachments, - toolContextMap, - maxContextTokens: (maxContextTokens - maxTokens) * 0.9, }; -}; +} const initializeClient = async ({ req, res, endpointOption }) => { if (!endpointOption) { @@ -313,8 +66,8 @@ const initializeClient = async ({ req, res, endpointOption }) => { throw new Error('No agent promise provided'); } - // Initialize primary agent const primaryAgent = await endpointOption.agent; + delete endpointOption.agent; if (!primaryAgent) { throw new Error('Agent not found'); } @@ -323,10 +76,18 @@ const initializeClient = async ({ req, res, endpointOption }) => { /** @type {Set} */ const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders); - // Handle primary agent - const primaryConfig = await initializeAgentOptions({ + const loadTools = createToolLoader(); + /** @type {Array} */ + const requestFiles = req.body.files ?? 
[]; + /** @type {string} */ + const conversationId = req.body.conversationId; + + const primaryConfig = await initializeAgent({ + req, + res, + loadTools, + requestFiles, + conversationId, agent: primaryAgent, endpointOption, allowedProviders, @@ -340,10 +101,13 @@ const initializeClient = async ({ req, res, endpointOption }) => { if (!agent) { throw new Error(`Agent ${agentId} not found`); } - const config = await initializeAgentOptions({ + const config = await initializeAgent({ req, res, agent, + loadTools, + requestFiles, + conversationId, endpointOption, allowedProviders, }); @@ -351,11 +115,25 @@ } } + let endpointConfig = req.app.locals[primaryConfig.endpoint]; + if (!isAgentsEndpoint(primaryConfig.endpoint) && !endpointConfig) { + try { + endpointConfig = await getCustomEndpointConfig(primaryConfig.endpoint); + } catch (err) { + logger.error( + '[api/server/services/Endpoints/agents/initialize.js #initializeClient] Error getting custom endpoint config', + err, + ); + } + } + + const sender = primaryAgent.name ?? getResponseSender({ ...endpointOption, model: endpointOption.model_parameters.model, + modelDisplayLabel: endpointConfig?.modelDisplayLabel, + modelLabel: endpointOption.model_parameters.modelLabel, }); const client = new AgentClient({ @@ -373,8 +151,8 @@ const initializeClient = async ({ req, res, endpointOption }) => { iconURL: endpointOption.iconURL, attachments: primaryConfig.attachments, endpointType: endpointOption.endpointType, + resendFiles: primaryConfig.resendFiles ?? true, maxContextTokens: primaryConfig.maxContextTokens, - resendFiles: primaryConfig.model_parameters?.resendFiles ?? true, endpoint: primaryConfig.id === Constants.EPHEMERAL_AGENT_ID ? 
primaryConfig.endpoint diff --git a/api/server/services/Endpoints/agents/title.js b/api/server/services/Endpoints/agents/title.js index ab171bc79d..2e5f00ecd0 100644 --- a/api/server/services/Endpoints/agents/title.js +++ b/api/server/services/Endpoints/agents/title.js @@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => { let timeoutId; try { const timeoutPromise = new Promise((_, reject) => { - timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000); + timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 45000); }).catch((error) => { logger.error('Title error:', error); }); diff --git a/api/server/services/Endpoints/anthropic/initialize.js b/api/server/services/Endpoints/anthropic/initialize.js index d4c6dd1795..4546fc634c 100644 --- a/api/server/services/Endpoints/anthropic/initialize.js +++ b/api/server/services/Endpoints/anthropic/initialize.js @@ -41,7 +41,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio { reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null, proxy: PROXY ?? null, - modelOptions: endpointOption.model_parameters, + modelOptions: endpointOption?.model_parameters ?? 
{}, }, clientOptions, ); diff --git a/api/server/services/Endpoints/anthropic/llm.js b/api/server/services/Endpoints/anthropic/llm.js index 9f20b8e61d..a14960ccd5 100644 --- a/api/server/services/Endpoints/anthropic/llm.js +++ b/api/server/services/Endpoints/anthropic/llm.js @@ -1,4 +1,4 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); +const { ProxyAgent } = require('undici'); const { anthropicSettings, removeNullishValues } = require('librechat-data-provider'); const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers'); @@ -67,11 +67,15 @@ function getLLMConfig(apiKey, options = {}) { } if (options.proxy) { - requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy); + const proxyAgent = new ProxyAgent(options.proxy); + requestOptions.clientOptions.fetchOptions = { + dispatcher: proxyAgent, + }; } if (options.reverseProxyUrl) { requestOptions.clientOptions.baseURL = options.reverseProxyUrl; + requestOptions.anthropicApiUrl = options.reverseProxyUrl; } return { diff --git a/api/server/services/Endpoints/anthropic/llm.spec.js b/api/server/services/Endpoints/anthropic/llm.spec.js index 9c453efb92..cd29975e0a 100644 --- a/api/server/services/Endpoints/anthropic/llm.spec.js +++ b/api/server/services/Endpoints/anthropic/llm.spec.js @@ -1,11 +1,45 @@ -const { anthropicSettings } = require('librechat-data-provider'); +const { anthropicSettings, removeNullishValues } = require('librechat-data-provider'); const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm'); +const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers'); jest.mock('https-proxy-agent', () => ({ HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })), })); +jest.mock('./helpers', () => ({ + checkPromptCacheSupport: jest.fn(), + getClaudeHeaders: jest.fn(), + configureReasoning: jest.fn((requestOptions) => requestOptions), +})); + +jest.mock('librechat-data-provider', () 
=> ({ + anthropicSettings: { + model: { default: 'claude-3-opus-20240229' }, + maxOutputTokens: { default: 4096, reset: jest.fn(() => 4096) }, + thinking: { default: false }, + promptCache: { default: false }, + thinkingBudget: { default: null }, + }, + removeNullishValues: jest.fn((obj) => { + const result = {}; + for (const key in obj) { + if (obj[key] !== null && obj[key] !== undefined) { + result[key] = obj[key]; + } + } + return result; + }), +})); + describe('getLLMConfig', () => { + beforeEach(() => { + jest.clearAllMocks(); + checkPromptCacheSupport.mockReturnValue(false); + getClaudeHeaders.mockReturnValue(undefined); + configureReasoning.mockImplementation((requestOptions) => requestOptions); + anthropicSettings.maxOutputTokens.reset.mockReturnValue(4096); + }); + it('should create a basic configuration with default values', () => { const result = getLLMConfig('test-api-key', { modelOptions: {} }); @@ -21,8 +55,12 @@ describe('getLLMConfig', () => { proxy: 'http://proxy:8080', }); - expect(result.llmConfig.clientOptions).toHaveProperty('httpAgent'); - expect(result.llmConfig.clientOptions.httpAgent).toHaveProperty('proxy', 'http://proxy:8080'); + expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions'); + expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher'); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined(); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe( + 'ProxyAgent', + ); }); it('should include reverse proxy URL when provided', () => { @@ -32,6 +70,7 @@ describe('getLLMConfig', () => { }); expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy'); + expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'http://reverse-proxy'); }); it('should include topK and topP for non-Claude-3.7 models', () => { @@ -61,6 +100,11 @@ describe('getLLMConfig', () => { }); it('should NOT include topK and topP for Claude-3-7 
models (hyphen notation)', () => { + configureReasoning.mockImplementation((requestOptions) => { + requestOptions.thinking = { type: 'enabled' }; + return requestOptions; + }); + const result = getLLMConfig('test-api-key', { modelOptions: { model: 'claude-3-7-sonnet', @@ -74,6 +118,11 @@ describe('getLLMConfig', () => { }); it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => { + configureReasoning.mockImplementation((requestOptions) => { + requestOptions.thinking = { type: 'enabled' }; + return requestOptions; + }); + const result = getLLMConfig('test-api-key', { modelOptions: { model: 'claude-3.7-sonnet', @@ -150,4 +199,160 @@ describe('getLLMConfig', () => { expect(result3.llmConfig).toHaveProperty('topK', 10); expect(result3.llmConfig).toHaveProperty('topP', 0.9); }); + + describe('Edge cases', () => { + it('should handle missing apiKey', () => { + const result = getLLMConfig(undefined, { modelOptions: {} }); + expect(result.llmConfig).not.toHaveProperty('apiKey'); + }); + + it('should handle empty modelOptions', () => { + expect(() => { + getLLMConfig('test-api-key', {}); + }).toThrow("Cannot read properties of undefined (reading 'thinking')"); + }); + + it('should handle no options parameter', () => { + expect(() => { + getLLMConfig('test-api-key'); + }).toThrow("Cannot read properties of undefined (reading 'thinking')"); + }); + + it('should handle temperature, stop sequences, and stream settings', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: { + temperature: 0.7, + stop: ['\n\n', 'END'], + stream: false, + }, + }); + + expect(result.llmConfig).toHaveProperty('temperature', 0.7); + expect(result.llmConfig).toHaveProperty('stopSequences', ['\n\n', 'END']); + expect(result.llmConfig).toHaveProperty('stream', false); + }); + + it('should handle maxOutputTokens when explicitly set to falsy value', () => { + anthropicSettings.maxOutputTokens.reset.mockReturnValue(8192); + const result = 
getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-opus', + maxOutputTokens: null, + }, + }); + + expect(anthropicSettings.maxOutputTokens.reset).toHaveBeenCalledWith('claude-3-opus'); + expect(result.llmConfig).toHaveProperty('maxTokens', 8192); + }); + + it('should handle both proxy and reverseProxyUrl', () => { + const result = getLLMConfig('test-api-key', { + modelOptions: {}, + proxy: 'http://proxy:8080', + reverseProxyUrl: 'https://reverse-proxy.com', + }); + + expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions'); + expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher'); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined(); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe( + 'ProxyAgent', + ); + expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com'); + expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'https://reverse-proxy.com'); + }); + + it('should handle prompt cache with supported model', () => { + checkPromptCacheSupport.mockReturnValue(true); + getClaudeHeaders.mockReturnValue({ 'anthropic-beta': 'prompt-caching-2024-07-31' }); + + const result = getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-5-sonnet', + promptCache: true, + }, + }); + + expect(checkPromptCacheSupport).toHaveBeenCalledWith('claude-3-5-sonnet'); + expect(getClaudeHeaders).toHaveBeenCalledWith('claude-3-5-sonnet', true); + expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({ + 'anthropic-beta': 'prompt-caching-2024-07-31', + }); + }); + + it('should handle thinking and thinkingBudget options', () => { + configureReasoning.mockImplementation((requestOptions, systemOptions) => { + if (systemOptions.thinking) { + requestOptions.thinking = { type: 'enabled' }; + } + if (systemOptions.thinkingBudget) { + requestOptions.thinking = { + ...requestOptions.thinking, + budget_tokens: 
systemOptions.thinkingBudget, + }; + } + return requestOptions; + }); + + getLLMConfig('test-api-key', { + modelOptions: { + model: 'claude-3-7-sonnet', + thinking: true, + thinkingBudget: 5000, + }, + }); + + expect(configureReasoning).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + thinking: true, + promptCache: false, + thinkingBudget: 5000, + }), + ); + }); + + it('should remove system options from modelOptions', () => { + const modelOptions = { + model: 'claude-3-opus', + thinking: true, + promptCache: true, + thinkingBudget: 1000, + temperature: 0.5, + }; + + getLLMConfig('test-api-key', { modelOptions }); + + expect(modelOptions).not.toHaveProperty('thinking'); + expect(modelOptions).not.toHaveProperty('promptCache'); + expect(modelOptions).not.toHaveProperty('thinkingBudget'); + expect(modelOptions).toHaveProperty('temperature', 0.5); + }); + + it('should handle all nullish values removal', () => { + removeNullishValues.mockImplementation((obj) => { + const cleaned = {}; + Object.entries(obj).forEach(([key, value]) => { + if (value !== null && value !== undefined) { + cleaned[key] = value; + } + }); + return cleaned; + }); + + const result = getLLMConfig('test-api-key', { + modelOptions: { + temperature: null, + topP: undefined, + topK: 0, + stop: [], + }, + }); + + expect(result.llmConfig).not.toHaveProperty('temperature'); + expect(result.llmConfig).not.toHaveProperty('topP'); + expect(result.llmConfig).toHaveProperty('topK', 0); + expect(result.llmConfig).toHaveProperty('stopSequences', []); + }); + }); }); diff --git a/api/server/services/Endpoints/azureAssistants/initialize.js b/api/server/services/Endpoints/azureAssistants/initialize.js index fc8024af07..132c123e7e 100644 --- a/api/server/services/Endpoints/azureAssistants/initialize.js +++ b/api/server/services/Endpoints/azureAssistants/initialize.js @@ -1,19 +1,13 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); -const { - 
ErrorTypes, - EModelEndpoint, - resolveHeaders, - mapModelToAzureConfig, -} = require('librechat-data-provider'); +const { constructAzureURL, isUserProvided, resolveHeaders } = require('@librechat/api'); +const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider'); const { getUserKeyValues, getUserKeyExpiry, checkUserKeyExpiry, } = require('~/server/services/UserService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); -const { isUserProvided } = require('~/server/utils'); -const { constructAzureURL } = require('~/utils'); class Files { constructor(client) { @@ -115,11 +109,14 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie apiKey = azureOptions.azureOpenAIApiKey; opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion }; - opts.defaultHeaders = resolveHeaders({ - ...headers, - 'api-key': apiKey, - 'OpenAI-Beta': `assistants=${version}`, - }); + opts.defaultHeaders = resolveHeaders( + { + ...headers, + 'api-key': apiKey, + 'OpenAI-Beta': `assistants=${version}`, + }, + req.user, + ); opts.model = azureOptions.azureOpenAIApiDeploymentName; if (initAppClient) { diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js index da332060e9..a31d6e10c4 100644 --- a/api/server/services/Endpoints/bedrock/options.js +++ b/api/server/services/Endpoints/bedrock/options.js @@ -1,4 +1,5 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); +const { createHandleLLMNewToken } = require('@librechat/api'); const { AuthType, Constants, @@ -8,7 +9,6 @@ const { removeNullishValues, } = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); const getOptions = async ({ req, overrideModel, endpointOption }) => { const { @@ -64,7 +64,7 @@ const getOptions = async ({ req, overrideModel, 
endpointOption }) => { /** @type {BedrockClientOptions} */ const requestOptions = { - model: overrideModel ?? endpointOption.model, + model: overrideModel ?? endpointOption?.model, region: BEDROCK_AWS_DEFAULT_REGION, }; @@ -76,7 +76,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => { const llmConfig = bedrockOutputParser( bedrockInputParser.parse( - removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)), + removeNullishValues(Object.assign(requestOptions, endpointOption?.model_parameters ?? {})), ), ); diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js index 39def8d0d5..4fcbe76ea6 100644 --- a/api/server/services/Endpoints/custom/initialize.js +++ b/api/server/services/Endpoints/custom/initialize.js @@ -6,10 +6,9 @@ const { extractEnvVariable, } = require('librechat-data-provider'); const { Providers } = require('@librechat/agents'); +const { getOpenAIConfig, createHandleLLMNewToken, resolveHeaders } = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm'); const { getCustomEndpointConfig } = require('~/server/services/Config'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); const { fetchModels } = require('~/server/services/ModelService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); const { isUserProvided } = require('~/server/utils'); @@ -29,12 +28,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey); const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL); - let resolvedHeaders = {}; - if (endpointConfig.headers && typeof endpointConfig.headers === 'object') { - Object.keys(endpointConfig.headers).forEach((key) => { - resolvedHeaders[key] = 
extractEnvVariable(endpointConfig.headers[key]); - }); - } + let resolvedHeaders = resolveHeaders(endpointConfig.headers, req.user); if (CUSTOM_API_KEY.match(envVarRegex)) { throw new Error(`Missing API Key for ${endpoint}.`); @@ -135,7 +129,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid }; if (optionsOnly) { - const modelOptions = endpointOption.model_parameters; + const modelOptions = endpointOption?.model_parameters ?? {}; if (endpoint !== Providers.OLLAMA) { clientOptions = Object.assign( { @@ -144,7 +138,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid clientOptions, ); clientOptions.modelOptions.user = req.user.id; - const options = getLLMConfig(apiKey, clientOptions, endpoint); + const options = getOpenAIConfig(apiKey, clientOptions, endpoint); if (!customOptions.streamRate) { return options; } diff --git a/api/server/services/Endpoints/custom/initialize.spec.js b/api/server/services/Endpoints/custom/initialize.spec.js new file mode 100644 index 0000000000..7e28995127 --- /dev/null +++ b/api/server/services/Endpoints/custom/initialize.spec.js @@ -0,0 +1,93 @@ +const initializeClient = require('./initialize'); + +jest.mock('@librechat/api', () => ({ + resolveHeaders: jest.fn(), + getOpenAIConfig: jest.fn(), + createHandleLLMNewToken: jest.fn(), +})); + +jest.mock('librechat-data-provider', () => ({ + CacheKeys: { TOKEN_CONFIG: 'token_config' }, + ErrorTypes: { NO_USER_KEY: 'NO_USER_KEY', NO_BASE_URL: 'NO_BASE_URL' }, + envVarRegex: /\$\{([^}]+)\}/, + FetchTokenConfig: {}, + extractEnvVariable: jest.fn((value) => value), +})); + +jest.mock('@librechat/agents', () => ({ + Providers: { OLLAMA: 'ollama' }, +})); + +jest.mock('~/server/services/UserService', () => ({ + getUserKeyValues: jest.fn(), + checkUserKeyExpiry: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + getCustomEndpointConfig: jest.fn().mockResolvedValue({ + apiKey: 'test-key', + baseURL: 
'https://test.com', + headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' }, + models: { default: ['test-model'] }, + }), +})); + +jest.mock('~/server/services/ModelService', () => ({ + fetchModels: jest.fn(), +})); + +jest.mock('~/app/clients/OpenAIClient', () => { + return jest.fn().mockImplementation(() => ({ + options: {}, + })); +}); + +jest.mock('~/server/utils', () => ({ + isUserProvided: jest.fn().mockReturnValue(false), +})); + +jest.mock('~/cache/getLogStores', () => + jest.fn().mockReturnValue({ + get: jest.fn(), + }), +); + +describe('custom/initializeClient', () => { + const mockRequest = { + body: { endpoint: 'test-endpoint' }, + user: { id: 'user-123', email: 'test@example.com' }, + app: { locals: {} }, + }; + const mockResponse = {}; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('calls resolveHeaders with headers and user', async () => { + const { resolveHeaders } = require('@librechat/api'); + await initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true }); + expect(resolveHeaders).toHaveBeenCalledWith( + { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' }, + { id: 'user-123', email: 'test@example.com' }, + ); + }); + + it('throws if endpoint config is missing', async () => { + const { getCustomEndpointConfig } = require('~/server/services/Config'); + getCustomEndpointConfig.mockResolvedValueOnce(null); + await expect( + initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true }), + ).rejects.toThrow('Config not found for the test-endpoint custom endpoint.'); + }); + + it('throws if user is missing', async () => { + await expect( + initializeClient({ + req: { ...mockRequest, user: undefined }, + res: mockResponse, + optionsOnly: true, + }), + ).rejects.toThrow("Cannot read properties of undefined (reading 'id')"); + }); +}); diff --git a/api/server/services/Endpoints/google/initialize.js b/api/server/services/Endpoints/google/initialize.js 
index b7419a8a87..871feda604 100644 --- a/api/server/services/Endpoints/google/initialize.js +++ b/api/server/services/Endpoints/google/initialize.js @@ -1,7 +1,7 @@ +const path = require('path'); const { EModelEndpoint, AuthKeys } = require('librechat-data-provider'); +const { getGoogleConfig, isEnabled, loadServiceKey } = require('@librechat/api'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = require('~/server/services/Endpoints/google/llm'); -const { isEnabled } = require('~/server/utils'); const { GoogleClient } = require('~/app'); const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => { @@ -16,18 +16,25 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio } let serviceKey = {}; + try { - serviceKey = require('~/data/auth.json'); - } catch (e) { + const serviceKeyPath = + process.env.GOOGLE_SERVICE_KEY_FILE_PATH || + path.join(__dirname, '../../../..', 'data', 'auth.json'); + serviceKey = await loadServiceKey(serviceKeyPath); + if (!serviceKey) { + serviceKey = {}; + } + } catch (_e) { // Do nothing } const credentials = isUserProvided ? userKey : { - [AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey, - [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, - }; + [AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey, + [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, + }; let clientOptions = {}; @@ -58,14 +65,14 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio if (optionsOnly) { clientOptions = Object.assign( { - modelOptions: endpointOption.model_parameters, + modelOptions: endpointOption?.model_parameters ?? 
{}, }, clientOptions, ); if (overrideModel) { clientOptions.modelOptions.model = overrideModel; } - return getLLMConfig(credentials, clientOptions); + return getGoogleConfig(credentials, clientOptions); } const client = new GoogleClient(credentials, clientOptions); diff --git a/api/server/services/Endpoints/gptPlugins/build.js b/api/server/services/Endpoints/gptPlugins/build.js deleted file mode 100644 index 0d1ec097ad..0000000000 --- a/api/server/services/Endpoints/gptPlugins/build.js +++ /dev/null @@ -1,41 +0,0 @@ -const { removeNullishValues } = require('librechat-data-provider'); -const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); - -const buildOptions = (endpoint, parsedBody) => { - const { - modelLabel, - chatGptLabel, - promptPrefix, - agentOptions, - tools = [], - iconURL, - greeting, - spec, - maxContextTokens, - artifacts, - ...modelOptions - } = parsedBody; - const endpointOption = removeNullishValues({ - endpoint, - tools: tools - .map((tool) => tool?.pluginKey ?? 
tool) - .filter((toolName) => typeof toolName === 'string'), - modelLabel, - chatGptLabel, - promptPrefix, - agentOptions, - iconURL, - greeting, - spec, - maxContextTokens, - modelOptions, - }); - - if (typeof artifacts === 'string') { - endpointOption.artifactsPrompt = generateArtifactsPrompt({ endpoint, artifacts }); - } - - return endpointOption; -}; - -module.exports = buildOptions; diff --git a/api/server/services/Endpoints/gptPlugins/index.js b/api/server/services/Endpoints/gptPlugins/index.js deleted file mode 100644 index 202cb0e4d7..0000000000 --- a/api/server/services/Endpoints/gptPlugins/index.js +++ /dev/null @@ -1,7 +0,0 @@ -const buildOptions = require('./build'); -const initializeClient = require('./initialize'); - -module.exports = { - buildOptions, - initializeClient, -}; diff --git a/api/server/services/Endpoints/gptPlugins/initialize.js b/api/server/services/Endpoints/gptPlugins/initialize.js deleted file mode 100644 index 7bfb43f004..0000000000 --- a/api/server/services/Endpoints/gptPlugins/initialize.js +++ /dev/null @@ -1,135 +0,0 @@ -const { - EModelEndpoint, - mapModelToAzureConfig, - resolveHeaders, -} = require('librechat-data-provider'); -const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { isEnabled, isUserProvided } = require('~/server/utils'); -const { getAzureCredentials } = require('~/utils'); -const { PluginsClient } = require('~/app'); - -const initializeClient = async ({ req, res, endpointOption }) => { - const { - PROXY, - OPENAI_API_KEY, - AZURE_API_KEY, - PLUGINS_USE_AZURE, - OPENAI_REVERSE_PROXY, - AZURE_OPENAI_BASEURL, - OPENAI_SUMMARIZE, - DEBUG_PLUGINS, - } = process.env; - - const { key: expiresAt, model: modelName } = req.body; - const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null; - - let useAzure = isEnabled(PLUGINS_USE_AZURE); - let endpoint = useAzure ? 
EModelEndpoint.azureOpenAI : EModelEndpoint.openAI; - - /** @type {false | TAzureConfig} */ - const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI]; - useAzure = useAzure || azureConfig?.plugins; - - if (useAzure && endpoint !== EModelEndpoint.azureOpenAI) { - endpoint = EModelEndpoint.azureOpenAI; - } - - const credentials = { - [EModelEndpoint.openAI]: OPENAI_API_KEY, - [EModelEndpoint.azureOpenAI]: AZURE_API_KEY, - }; - - const baseURLOptions = { - [EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY, - [EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL, - }; - - const userProvidesKey = isUserProvided(credentials[endpoint]); - const userProvidesURL = isUserProvided(baseURLOptions[endpoint]); - - let userValues = null; - if (expiresAt && (userProvidesKey || userProvidesURL)) { - checkUserKeyExpiry(expiresAt, endpoint); - userValues = await getUserKeyValues({ userId: req.user.id, name: endpoint }); - } - - let apiKey = userProvidesKey ? userValues?.apiKey : credentials[endpoint]; - let baseURL = userProvidesURL ? userValues?.baseURL : baseURLOptions[endpoint]; - - const clientOptions = { - contextStrategy, - debug: isEnabled(DEBUG_PLUGINS), - reverseProxyUrl: baseURL ? baseURL : null, - proxy: PROXY ?? null, - req, - res, - ...endpointOption, - }; - - if (useAzure && azureConfig) { - const { modelGroupMap, groupMap } = azureConfig; - const { - azureOptions, - baseURL, - headers = {}, - serverless, - } = mapModelToAzureConfig({ - modelName, - modelGroupMap, - groupMap, - }); - - clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; - clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) }); - - clientOptions.titleConvo = azureConfig.titleConvo; - clientOptions.titleModel = azureConfig.titleModel; - clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion'; - - const azureRate = modelName.includes('gpt-4') ? 30 : 17; - clientOptions.streamRate = azureConfig.streamRate ?? 
azureRate; - - const groupName = modelGroupMap[modelName].group; - clientOptions.addParams = azureConfig.groupMap[groupName].addParams; - clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; - clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; - - apiKey = azureOptions.azureOpenAIApiKey; - clientOptions.azure = !serverless && azureOptions; - if (serverless === true) { - clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion - ? { 'api-version': azureOptions.azureOpenAIApiVersion } - : undefined; - clientOptions.headers['api-key'] = apiKey; - } - } else if (useAzure || (apiKey && apiKey.includes('{"azure') && !clientOptions.azure)) { - clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials(); - apiKey = clientOptions.azure.azureOpenAIApiKey; - } - - /** @type {undefined | TBaseEndpoint} */ - const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins]; - - if (!useAzure && pluginsConfig) { - clientOptions.streamRate = pluginsConfig.streamRate; - } - - /** @type {undefined | TBaseEndpoint} */ - const allConfig = req.app.locals.all; - if (allConfig) { - clientOptions.streamRate = allConfig.streamRate; - } - - if (!apiKey) { - throw new Error(`${endpoint} API key not provided. 
Please provide it again.`); - } - - const client = new PluginsClient(apiKey, clientOptions); - return { - client, - azure: clientOptions.azure, - openAIApiKey: apiKey, - }; -}; - -module.exports = initializeClient; diff --git a/api/server/services/Endpoints/gptPlugins/initialize.spec.js b/api/server/services/Endpoints/gptPlugins/initialize.spec.js deleted file mode 100644 index 02199c9397..0000000000 --- a/api/server/services/Endpoints/gptPlugins/initialize.spec.js +++ /dev/null @@ -1,410 +0,0 @@ -// gptPlugins/initializeClient.spec.js -jest.mock('~/cache/getLogStores'); -const { EModelEndpoint, ErrorTypes, validateAzureGroups } = require('librechat-data-provider'); -const { getUserKey, getUserKeyValues } = require('~/server/services/UserService'); -const initializeClient = require('./initialize'); -const { PluginsClient } = require('~/app'); - -// Mock getUserKey since it's the only function we want to mock -jest.mock('~/server/services/UserService', () => ({ - getUserKey: jest.fn(), - getUserKeyValues: jest.fn(), - checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry, -})); - -describe('gptPlugins/initializeClient', () => { - // Set up environment variables - const originalEnvironment = process.env; - const app = { - locals: {}, - }; - - const validAzureConfigs = [ - { - group: 'librechat-westus', - apiKey: 'WESTUS_API_KEY', - instanceName: 'librechat-westus', - version: '2023-12-01-preview', - models: { - 'gpt-4-vision-preview': { - deploymentName: 'gpt-4-vision-preview', - version: '2024-02-15-preview', - }, - 'gpt-3.5-turbo': { - deploymentName: 'gpt-35-turbo', - }, - 'gpt-3.5-turbo-1106': { - deploymentName: 'gpt-35-turbo-1106', - }, - 'gpt-4': { - deploymentName: 'gpt-4', - }, - 'gpt-4-1106-preview': { - deploymentName: 'gpt-4-1106-preview', - }, - }, - }, - { - group: 'librechat-eastus', - apiKey: 'EASTUS_API_KEY', - instanceName: 'librechat-eastus', - deploymentName: 'gpt-4-turbo', - version: '2024-02-15-preview', - 
models: { - 'gpt-4-turbo': true, - }, - baseURL: 'https://eastus.example.com', - additionalHeaders: { - 'x-api-key': 'x-api-key-value', - }, - }, - { - group: 'mistral-inference', - apiKey: 'AZURE_MISTRAL_API_KEY', - baseURL: - 'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions', - serverless: true, - models: { - 'mistral-large': true, - }, - }, - { - group: 'llama-70b-chat', - apiKey: 'AZURE_LLAMA2_70B_API_KEY', - baseURL: - 'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions', - serverless: true, - models: { - 'llama-70b-chat': true, - }, - }, - ]; - - const { modelNames, modelGroupMap, groupMap } = validateAzureGroups(validAzureConfigs); - - beforeEach(() => { - jest.resetModules(); // Clears the cache - process.env = { ...originalEnvironment }; // Make a copy - }); - - afterAll(() => { - process.env = originalEnvironment; // Restore original env vars - }); - - test('should initialize PluginsClient with OpenAI API key and default options', async () => { - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - process.env.PLUGINS_USE_AZURE = 'false'; - process.env.DEBUG_PLUGINS = 'false'; - process.env.OPENAI_SUMMARIZE = 'false'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client, openAIApiKey } = await initializeClient({ req, res, endpointOption }); - - expect(openAIApiKey).toBe('test-openai-api-key'); - expect(client).toBeInstanceOf(PluginsClient); - }); - - test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => { - process.env.AZURE_API_KEY = 'test-azure-api-key'; - (process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_VERSION = 'some-value'), - 
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'), - (process.env.PLUGINS_USE_AZURE = 'true'); - process.env.DEBUG_PLUGINS = 'false'; - process.env.OPENAI_SUMMARIZE = 'false'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'test-model' } }; - - const { client, azure } = await initializeClient({ req, res, endpointOption }); - - expect(azure.azureOpenAIApiKey).toBe('test-azure-api-key'); - expect(client).toBeInstanceOf(PluginsClient); - }); - - test('should use the debug option when DEBUG_PLUGINS is enabled', async () => { - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - process.env.DEBUG_PLUGINS = 'true'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client } = await initializeClient({ req, res, endpointOption }); - - expect(client.options.debug).toBe(true); - }); - - test('should set contextStrategy to summarize when OPENAI_SUMMARIZE is enabled', async () => { - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - process.env.OPENAI_SUMMARIZE = 'true'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client } = await initializeClient({ req, res, endpointOption }); - - expect(client.options.contextStrategy).toBe('summarize'); - }); - - // ... additional tests for reverseProxyUrl, proxy, user-provided keys, etc. 
- - test('should throw an error if no API keys are provided in the environment', async () => { - // Clear the environment variables for API keys - delete process.env.OPENAI_API_KEY; - delete process.env.AZURE_API_KEY; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - `${EModelEndpoint.openAI} API key not provided.`, - ); - }); - - // Additional tests for gptPlugins/initializeClient.spec.js - - // ... (previous test setup code) - - test('should handle user-provided OpenAI keys and check expiry', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'false'; - - const futureDate = new Date(Date.now() + 10000).toISOString(); - const req = { - body: { key: futureDate }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - getUserKeyValues.mockResolvedValue({ apiKey: 'test-user-provided-openai-api-key' }); - - const { openAIApiKey } = await initializeClient({ req, res, endpointOption }); - - expect(openAIApiKey).toBe('test-user-provided-openai-api-key'); - }); - - test('should handle user-provided Azure keys and check expiry', async () => { - process.env.AZURE_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'true'; - - const futureDate = new Date(Date.now() + 10000).toISOString(); - const req = { - body: { key: futureDate }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'test-model' } }; - - getUserKeyValues.mockResolvedValue({ - apiKey: JSON.stringify({ - azureOpenAIApiKey: 'test-user-provided-azure-api-key', - azureOpenAIApiDeploymentName: 'test-deployment', - }), - }); - - const { azure } = await initializeClient({ req, res, endpointOption }); - - 
expect(azure.azureOpenAIApiKey).toBe('test-user-provided-azure-api-key'); - }); - - test('should throw an error if the user-provided key has expired', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'FALSE'; - const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired - const req = { - body: { key: expiresAt }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /expired_user_key/, - ); - }); - - test('should throw an error if the user-provided Azure key is invalid JSON', async () => { - process.env.AZURE_API_KEY = 'user_provided'; - process.env.PLUGINS_USE_AZURE = 'true'; - - const req = { - body: { key: new Date(Date.now() + 10000).toISOString() }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - // Simulate an invalid JSON string returned from getUserKey - getUserKey.mockResolvedValue('invalid-json'); - getUserKeyValues.mockImplementation(() => { - let userValues = getUserKey(); - try { - userValues = JSON.parse(userValues); - } catch (e) { - throw new Error( - JSON.stringify({ - type: ErrorTypes.INVALID_USER_KEY, - }), - ); - } - return userValues; - }); - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /invalid_user_key/, - ); - }); - - test('should correctly handle the presence of a reverse proxy', async () => { - process.env.OPENAI_REVERSE_PROXY = 'http://reverse.proxy'; - process.env.PROXY = 'http://proxy'; - process.env.OPENAI_API_KEY = 'test-openai-api-key'; - - const req = { - body: { key: null }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = { modelOptions: { model: 'default-model' } }; - - const { client } = await initializeClient({ req, res, endpointOption }); - - 
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy'); - expect(client.options.proxy).toBe('http://proxy'); - }); - - test('should throw an error when user-provided values are not valid JSON', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - const req = { - body: { key: new Date(Date.now() + 10000).toISOString(), endpoint: 'openAI' }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = {}; - - // Mock getUserKey to return a non-JSON string - getUserKey.mockResolvedValue('not-a-json'); - getUserKeyValues.mockImplementation(() => { - let userValues = getUserKey(); - try { - userValues = JSON.parse(userValues); - } catch (e) { - throw new Error( - JSON.stringify({ - type: ErrorTypes.INVALID_USER_KEY, - }), - ); - } - return userValues; - }); - - await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow( - /invalid_user_key/, - ); - }); - - test('should initialize client correctly for Azure OpenAI with valid configuration', async () => { - const req = { - body: { - key: null, - endpoint: EModelEndpoint.gptPlugins, - model: modelNames[0], - }, - user: { id: '123' }, - app: { - locals: { - [EModelEndpoint.azureOpenAI]: { - plugins: true, - modelNames, - modelGroupMap, - groupMap, - }, - }, - }, - }; - const res = {}; - const endpointOption = {}; - - const client = await initializeClient({ req, res, endpointOption }); - expect(client.client.options.azure).toBeDefined(); - }); - - test('should initialize client with default options when certain env vars are not set', async () => { - delete process.env.OPENAI_SUMMARIZE; - process.env.OPENAI_API_KEY = 'some-api-key'; - - const req = { - body: { key: null, endpoint: EModelEndpoint.gptPlugins }, - user: { id: '123' }, - app, - }; - const res = {}; - const endpointOption = {}; - - const client = await initializeClient({ req, res, endpointOption }); - expect(client.client.options.contextStrategy).toBe(null); - }); - - test('should correctly use 
user-provided apiKey and baseURL when provided', async () => { - process.env.OPENAI_API_KEY = 'user_provided'; - process.env.OPENAI_REVERSE_PROXY = 'user_provided'; - const req = { - body: { - key: new Date(Date.now() + 10000).toISOString(), - endpoint: 'openAI', - }, - user: { - id: '123', - }, - app, - }; - const res = {}; - const endpointOption = {}; - - getUserKeyValues.mockResolvedValue({ - apiKey: 'test', - baseURL: 'https://user-provided-url.com', - }); - - const result = await initializeClient({ req, res, endpointOption }); - - expect(result.openAIApiKey).toBe('test'); - expect(result.client.options.reverseProxyUrl).toBe('https://user-provided-url.com'); - }); -}); diff --git a/api/server/services/Endpoints/index.js b/api/server/services/Endpoints/index.js new file mode 100644 index 0000000000..8171789418 --- /dev/null +++ b/api/server/services/Endpoints/index.js @@ -0,0 +1,75 @@ +const { Providers } = require('@librechat/agents'); +const { EModelEndpoint } = require('librechat-data-provider'); +const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); +const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); +const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); +const initCustom = require('~/server/services/Endpoints/custom/initialize'); +const initGoogle = require('~/server/services/Endpoints/google/initialize'); +const { getCustomEndpointConfig } = require('~/server/services/Config'); + +/** Check if the provider is a known custom provider + * @param {string | undefined} [provider] - The provider string + * @returns {boolean} - True if the provider is a known custom provider, false otherwise + */ +function isKnownCustomProvider(provider) { + return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes( + provider || '', + ); +} + +const providerConfigMap = { + [Providers.XAI]: initCustom, + [Providers.OLLAMA]: initCustom, + [Providers.DEEPSEEK]: 
initCustom, + [Providers.OPENROUTER]: initCustom, + [EModelEndpoint.openAI]: initOpenAI, + [EModelEndpoint.google]: initGoogle, + [EModelEndpoint.azureOpenAI]: initOpenAI, + [EModelEndpoint.anthropic]: initAnthropic, + [EModelEndpoint.bedrock]: getBedrockOptions, +}; + +/** + * Get the provider configuration and override endpoint based on the provider string + * @param {string} provider - The provider string + * @returns {Promise<{ + * getOptions: Function, + * overrideProvider?: string, + * customEndpointConfig?: TEndpoint + * }>} + */ +async function getProviderConfig(provider) { + let getOptions = providerConfigMap[provider]; + let overrideProvider; + /** @type {TEndpoint | undefined} */ + let customEndpointConfig; + + if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { + overrideProvider = provider.toLowerCase(); + getOptions = providerConfigMap[overrideProvider]; + } else if (!getOptions) { + customEndpointConfig = await getCustomEndpointConfig(provider); + if (!customEndpointConfig) { + throw new Error(`Provider ${provider} not supported`); + } + getOptions = initCustom; + overrideProvider = Providers.OPENAI; + } + + if (isKnownCustomProvider(overrideProvider)) { + customEndpointConfig = await getCustomEndpointConfig(provider); + if (!customEndpointConfig) { + throw new Error(`Provider ${provider} not supported`); + } + } + + return { + getOptions, + overrideProvider, + customEndpointConfig, + }; +} + +module.exports = { + getProviderConfig, +}; diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index 714ed5a1e6..e86596181a 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -1,15 +1,14 @@ +const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider'); const { - ErrorTypes, - EModelEndpoint, + isEnabled, resolveHeaders, - mapModelToAzureConfig, -} = 
require('librechat-data-provider'); + isUserProvided, + getOpenAIConfig, + getAzureCredentials, + createHandleLLMNewToken, +} = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); -const { isEnabled, isUserProvided } = require('~/server/utils'); const OpenAIClient = require('~/app/clients/OpenAIClient'); -const { getAzureCredentials } = require('~/utils'); const initializeClient = async ({ req, @@ -81,7 +80,10 @@ const initializeClient = async ({ }); clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; - clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) }); + clientOptions.headers = resolveHeaders( + { ...headers, ...(clientOptions.headers ?? {}) }, + req.user, + ); clientOptions.titleConvo = azureConfig.titleConvo; clientOptions.titleModel = azureConfig.titleModel; @@ -136,11 +138,11 @@ const initializeClient = async ({ } if (optionsOnly) { - const modelOptions = endpointOption.model_parameters; + const modelOptions = endpointOption?.model_parameters ?? 
{}; modelOptions.model = modelName; clientOptions = Object.assign({ modelOptions }, clientOptions); clientOptions.modelOptions.user = req.user.id; - const options = getLLMConfig(apiKey, clientOptions); + const options = getOpenAIConfig(apiKey, clientOptions); const streamRate = clientOptions.streamRate; if (!streamRate) { return options; diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js deleted file mode 100644 index c1fd090b28..0000000000 --- a/api/server/services/Endpoints/openAI/llm.js +++ /dev/null @@ -1,170 +0,0 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); -const { KnownEndpoints } = require('librechat-data-provider'); -const { sanitizeModelName, constructAzureURL } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); - -/** - * Generates configuration options for creating a language model (LLM) instance. - * @param {string} apiKey - The API key for authentication. - * @param {Object} options - Additional options for configuring the LLM. - * @param {Object} [options.modelOptions] - Model-specific options. - * @param {string} [options.modelOptions.model] - The name of the model to use. - * @param {string} [options.modelOptions.user] - The user ID - * @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2). - * @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1). - * @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2). - * @param {number} [options.modelOptions.presence_penalty] - Encourages discussing new topics (-2 to 2). - * @param {number} [options.modelOptions.max_tokens] - The maximum number of tokens to generate. - * @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens. - * @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used. 
- * @param {boolean} [options.useOpenRouter] - Flag to use OpenRouter API. - * @param {Object} [options.headers] - Additional headers for API requests. - * @param {string} [options.proxy] - Proxy server URL. - * @param {Object} [options.azure] - Azure-specific configurations. - * @param {boolean} [options.streaming] - Whether to use streaming mode. - * @param {Object} [options.addParams] - Additional parameters to add to the model options. - * @param {string[]} [options.dropParams] - Parameters to remove from the model options. - * @param {string|null} [endpoint=null] - The endpoint name - * @returns {Object} Configuration options for creating an LLM instance. - */ -function getLLMConfig(apiKey, options = {}, endpoint = null) { - let { - modelOptions = {}, - reverseProxyUrl, - defaultQuery, - headers, - proxy, - azure, - streaming = true, - addParams, - dropParams, - } = options; - - /** @type {OpenAIClientOptions} */ - let llmConfig = { - streaming, - }; - - Object.assign(llmConfig, modelOptions); - - if (addParams && typeof addParams === 'object') { - Object.assign(llmConfig, addParams); - } - /** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */ - if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) { - const searchExcludeParams = [ - 'frequency_penalty', - 'presence_penalty', - 'temperature', - 'top_p', - 'top_k', - 'stop', - 'logit_bias', - 'seed', - 'response_format', - 'n', - 'logprobs', - 'user', - ]; - - dropParams = dropParams || []; - dropParams = [...new Set([...dropParams, ...searchExcludeParams])]; - } - - if (dropParams && Array.isArray(dropParams)) { - dropParams.forEach((param) => { - if (llmConfig[param]) { - llmConfig[param] = undefined; - } - }); - } - - let useOpenRouter; - /** @type {OpenAIClientOptions['configuration']} */ - const configOptions = {}; - if ( - (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) || - (endpoint && 
endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) - ) { - useOpenRouter = true; - llmConfig.include_reasoning = true; - configOptions.baseURL = reverseProxyUrl; - configOptions.defaultHeaders = Object.assign( - { - 'HTTP-Referer': 'https://librechat.ai', - 'X-Title': 'LibreChat', - }, - headers, - ); - } else if (reverseProxyUrl) { - configOptions.baseURL = reverseProxyUrl; - if (headers) { - configOptions.defaultHeaders = headers; - } - } - - if (defaultQuery) { - configOptions.defaultQuery = defaultQuery; - } - - if (proxy) { - const proxyAgent = new HttpsProxyAgent(proxy); - Object.assign(configOptions, { - httpAgent: proxyAgent, - httpsAgent: proxyAgent, - }); - } - - if (azure) { - const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME); - azure.azureOpenAIApiDeploymentName = useModelName - ? sanitizeModelName(llmConfig.model) - : azure.azureOpenAIApiDeploymentName; - - if (process.env.AZURE_OPENAI_DEFAULT_MODEL) { - llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL; - } - - if (configOptions.baseURL) { - const azureURL = constructAzureURL({ - baseURL: configOptions.baseURL, - azureOptions: azure, - }); - azure.azureOpenAIBasePath = azureURL.split(`/${azure.azureOpenAIApiDeploymentName}`)[0]; - } - - Object.assign(llmConfig, azure); - llmConfig.model = llmConfig.azureOpenAIApiDeploymentName; - } else { - llmConfig.apiKey = apiKey; - // Object.assign(llmConfig, { - // configuration: { apiKey }, - // }); - } - - if (process.env.OPENAI_ORGANIZATION && this.azure) { - llmConfig.organization = process.env.OPENAI_ORGANIZATION; - } - - if (useOpenRouter && llmConfig.reasoning_effort != null) { - llmConfig.reasoning = { - effort: llmConfig.reasoning_effort, - }; - delete llmConfig.reasoning_effort; - } - - if (llmConfig?.['max_tokens'] != null) { - /** @type {number} */ - llmConfig.maxTokens = llmConfig['max_tokens']; - delete llmConfig['max_tokens']; - } - - return { - /** @type {OpenAIClientOptions} */ - llmConfig, - /** @type 
{OpenAIClientOptions['configuration']} */ - configOptions, - }; -} - -module.exports = { getLLMConfig }; diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index d6c8cc4146..49a800336b 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -2,9 +2,9 @@ const axios = require('axios'); const fs = require('fs').promises; const FormData = require('form-data'); const { Readable } = require('stream'); +const { genAzureEndpoint } = require('@librechat/api'); const { extractEnvVariable, STTProviders } = require('librechat-data-provider'); const { getCustomConfig } = require('~/server/services/Config'); -const { genAzureEndpoint } = require('~/utils'); const { logger } = require('~/config'); /** diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index cd718fdfc1..34d8202156 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -1,8 +1,8 @@ const axios = require('axios'); +const { genAzureEndpoint } = require('@librechat/api'); const { extractEnvVariable, TTSProviders } = require('librechat-data-provider'); const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio'); const { getCustomConfig } = require('~/server/services/Config'); -const { genAzureEndpoint } = require('~/utils'); const { logger } = require('~/config'); /** diff --git a/api/server/services/Files/Azure/images.js b/api/server/services/Files/Azure/images.js index a83b700af3..80d5e76290 100644 --- a/api/server/services/Files/Azure/images.js +++ b/api/server/services/Files/Azure/images.js @@ -1,10 +1,9 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); +const { logger } = require('@librechat/data-schemas'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = 
require('~/models/userMethods'); -const { updateFile } = require('~/models/File'); -const { logger } = require('~/config'); +const { updateUser, updateFile } = require('~/models'); const { saveBufferToAzure } = require('./crud'); /** @@ -92,24 +91,44 @@ async function prepareAzureImageURL(req, file) { * @param {Buffer} params.buffer - The avatar image buffer. * @param {string} params.userId - The user's id. * @param {string} params.manual - Flag to indicate manual update. + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @param {string} [params.basePath='images'] - The base folder within the container. * @param {string} [params.containerName] - The Azure Blob container name. * @returns {Promise} The URL of the avatar. */ -async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) { +async function processAzureAvatar({ + buffer, + userId, + manual, + agentId, + basePath = 'images', + containerName, +}) { try { + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 'gif' : 'png'; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? 
`agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; + const downloadURL = await saveBufferToAzure({ userId, buffer, - fileName: 'avatar.png', + fileName, basePath, containerName, }); const isManual = manual === 'true'; const url = `${downloadURL}?manual=${isManual}`; - if (isManual) { + + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } + return url; } catch (error) { logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error); diff --git a/api/server/services/Files/Code/crud.js b/api/server/services/Files/Code/crud.js index caea9ab30a..c696eae0c4 100644 --- a/api/server/services/Files/Code/crud.js +++ b/api/server/services/Files/Code/crud.js @@ -1,7 +1,6 @@ const FormData = require('form-data'); const { getCodeBaseURL } = require('@librechat/agents'); -const { createAxiosInstance } = require('~/config'); -const { logAxiosError } = require('~/utils'); +const { createAxiosInstance, logAxiosError } = require('@librechat/api'); const axios = createAxiosInstance(); diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index c92e628589..cf65154983 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -1,6 +1,8 @@ const path = require('path'); const { v4 } = require('uuid'); const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { getCodeBaseURL } = require('@librechat/agents'); const { Tools, @@ -12,8 +14,6 @@ const { const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); const { createFile, getFiles, updateFile } = require('~/models/File'); -const { logAxiosError } = require('~/utils'); -const { logger } = 
require('~/config'); /** * Process OpenAI image files, convert to target format, save and return file metadata. diff --git a/api/server/services/Files/Firebase/images.js b/api/server/services/Files/Firebase/images.js index 7345f30df1..8b0866b5d0 100644 --- a/api/server/services/Files/Firebase/images.js +++ b/api/server/services/Files/Firebase/images.js @@ -1,11 +1,10 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); +const { logger } = require('@librechat/data-schemas'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = require('~/models/userMethods'); +const { updateUser, updateFile } = require('~/models'); const { saveBufferToFirebase } = require('./crud'); -const { updateFile } = require('~/models/File'); -const { logger } = require('~/config'); /** * Converts an image file to the target format. The function first resizes the image based on the specified @@ -83,22 +82,32 @@ async function prepareImageURL(req, file) { * @param {Buffer} params.buffer - The Buffer containing the avatar image. * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processFirebaseAvatar({ buffer, userId, manual }) { +async function processFirebaseAvatar({ buffer, userId, manual, agentId }) { try { + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 'gif' : 'png'; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? 
`agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; + const downloadURL = await saveBufferToFirebase({ userId, buffer, - fileName: 'avatar.png', + fileName, }); const isManual = manual === 'true'; - const url = `${downloadURL}?manual=${isManual}`; - if (isManual) { + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index 783230f2f6..455d4e0c4f 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -1,10 +1,11 @@ const fs = require('fs'); const path = require('path'); const axios = require('axios'); +const { logger } = require('@librechat/data-schemas'); const { EModelEndpoint } = require('librechat-data-provider'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); const { getBufferMetadata } = require('~/server/utils'); const paths = require('~/config/paths'); -const { logger } = require('~/config'); /** * Saves a file to a specified output path with a new filename. 
@@ -201,8 +202,12 @@ const unlinkFile = async (filepath) => { */ const deleteLocalFile = async (req, file) => { const { publicPath, uploads } = req.app.locals.paths; + + /** Filepath stripped of query parameters (e.g., ?manual=true) */ + const cleanFilepath = file.filepath.split('?')[0]; + if (file.embedded && process.env.RAG_API_URL) { - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); axios.delete(`${process.env.RAG_API_URL}/documents`, { headers: { Authorization: `Bearer ${jwtToken}`, @@ -213,32 +218,32 @@ const deleteLocalFile = async (req, file) => { }); } - if (file.filepath.startsWith(`/uploads/${req.user.id}`)) { + if (cleanFilepath.startsWith(`/uploads/${req.user.id}`)) { const userUploadDir = path.join(uploads, req.user.id); - const basePath = file.filepath.split(`/uploads/${req.user.id}/`)[1]; + const basePath = cleanFilepath.split(`/uploads/${req.user.id}/`)[1]; if (!basePath) { - throw new Error(`Invalid file path: ${file.filepath}`); + throw new Error(`Invalid file path: ${cleanFilepath}`); } const filepath = path.join(userUploadDir, basePath); const rel = path.relative(userUploadDir, filepath); if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { - throw new Error(`Invalid file path: ${file.filepath}`); + throw new Error(`Invalid file path: ${cleanFilepath}`); } await unlinkFile(filepath); return; } - const parts = file.filepath.split(path.sep); + const parts = cleanFilepath.split(path.sep); const subfolder = parts[1]; if (!subfolder && parts[0] === EModelEndpoint.agents) { logger.warn(`Agent File ${file.file_id} is missing filepath, may have been deleted already`); return; } - const filepath = path.join(publicPath, file.filepath); + const filepath = path.join(publicPath, cleanFilepath); if (!isValidPath(req, publicPath, subfolder, filepath)) { throw new Error('Invalid file path'); diff --git a/api/server/services/Files/Local/images.js 
b/api/server/services/Files/Local/images.js index 1305505381..ea3af87c70 100644 --- a/api/server/services/Files/Local/images.js +++ b/api/server/services/Files/Local/images.js @@ -2,8 +2,7 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = require('~/models/userMethods'); -const { updateFile } = require('~/models/File'); +const { updateUser, updateFile } = require('~/models'); /** * Converts an image file to the target format. The function first resizes the image based on the specified @@ -113,10 +112,11 @@ async function prepareImagesLocal(req, file) { * @param {Buffer} params.buffer - The Buffer containing the avatar image. * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processLocalAvatar({ buffer, userId, manual }) { +async function processLocalAvatar({ buffer, userId, manual, agentId }) { const userDir = path.resolve( __dirname, '..', @@ -130,7 +130,14 @@ async function processLocalAvatar({ buffer, userId, manual }) { userId, ); - const fileName = `avatar-${new Date().getTime()}.png`; + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 'gif' : 'png'; + + const timestamp = new Date().getTime(); + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? 
`agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; const urlRoute = `/images/${userId}/${fileName}`; const avatarPath = path.join(userDir, fileName); @@ -140,7 +147,8 @@ async function processLocalAvatar({ buffer, userId, manual }) { const isManual = manual === 'true'; let url = `${urlRoute}?manual=${isManual}`; - if (isManual) { + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } diff --git a/api/server/services/Files/MistralOCR/crud.js b/api/server/services/Files/MistralOCR/crud.js deleted file mode 100644 index cc01d803b0..0000000000 --- a/api/server/services/Files/MistralOCR/crud.js +++ /dev/null @@ -1,237 +0,0 @@ -// ~/server/services/Files/MistralOCR/crud.js -const fs = require('fs'); -const path = require('path'); -const FormData = require('form-data'); -const { - FileSources, - envVarRegex, - extractEnvVariable, - extractVariableName, -} = require('librechat-data-provider'); -const { loadAuthValues } = require('~/server/services/Tools/credentials'); -const { logger, createAxiosInstance } = require('~/config'); -const { logAxiosError } = require('~/utils/axios'); - -const axios = createAxiosInstance(); - -/** - * Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory - * - * @param {Object} params Upload parameters - * @param {string} params.filePath The path to the file on disk - * @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath) - * @param {string} params.apiKey Mistral API key - * @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL - * @returns {Promise} The response from Mistral API - */ -async function uploadDocumentToMistral({ - filePath, - fileName = '', - apiKey, - baseURL = 'https://api.mistral.ai/v1', -}) { - const form = new FormData(); - form.append('purpose', 'ocr'); - const actualFileName = 
fileName || path.basename(filePath); - const fileStream = fs.createReadStream(filePath); - form.append('file', fileStream, { filename: actualFileName }); - - return axios - .post(`${baseURL}/files`, form, { - headers: { - Authorization: `Bearer ${apiKey}`, - ...form.getHeaders(), - }, - maxBodyLength: Infinity, - maxContentLength: Infinity, - }) - .then((res) => res.data) - .catch((error) => { - throw error; - }); -} - -async function getSignedUrl({ - apiKey, - fileId, - expiry = 24, - baseURL = 'https://api.mistral.ai/v1', -}) { - return axios - .get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, { - headers: { - Authorization: `Bearer ${apiKey}`, - }, - }) - .then((res) => res.data) - .catch((error) => { - logger.error('Error fetching signed URL:', error.message); - throw error; - }); -} - -/** - * @param {Object} params - * @param {string} params.apiKey - * @param {string} params.url - The document or image URL - * @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url' - * @param {string} [params.model] - * @param {string} [params.baseURL] - * @returns {Promise} - */ -async function performOCR({ - apiKey, - url, - documentType = 'document_url', - model = 'mistral-ocr-latest', - baseURL = 'https://api.mistral.ai/v1', -}) { - const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url'; - return axios - .post( - `${baseURL}/ocr`, - { - model, - include_image_base64: false, - document: { - type: documentType, - [documentKey]: url, - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${apiKey}`, - }, - }, - ) - .then((res) => res.data) - .catch((error) => { - logger.error('Error performing OCR:', error.message); - throw error; - }); -} - -/** - * Uploads a file to the Mistral OCR API and processes the OCR result. - * - * @param {Object} params - The params object. - * @param {ServerRequest} params.req - The request object from Express. 
It should have a `user` property with an `id` - * representing the user - * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should - * have a `mimetype` property that tells us the file type - * @param {string} params.file_id - The file ID. - * @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency. - * @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used), - * along with the `filename` and `bytes` properties. - */ -const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { - try { - /** @type {TCustomConfig['ocr']} */ - const ocrConfig = req.app.locals?.ocr; - - const apiKeyConfig = ocrConfig.apiKey || ''; - const baseURLConfig = ocrConfig.baseURL || ''; - - const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig); - const isBaseURLEnvVar = envVarRegex.test(baseURLConfig); - - const isApiKeyEmpty = !apiKeyConfig.trim(); - const isBaseURLEmpty = !baseURLConfig.trim(); - - let apiKey, baseURL; - - if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) { - const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY'; - const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL'; - - const authValues = await loadAuthValues({ - userId: req.user.id, - authFields: [baseURLVarName, apiKeyVarName], - optional: new Set([baseURLVarName]), - }); - - apiKey = authValues[apiKeyVarName]; - baseURL = authValues[baseURLVarName]; - } else { - apiKey = apiKeyConfig; - baseURL = baseURLConfig; - } - - const mistralFile = await uploadDocumentToMistral({ - filePath: file.path, - fileName: file.originalname, - apiKey, - baseURL, - }); - - const modelConfig = ocrConfig.mistralModel || ''; - const model = envVarRegex.test(modelConfig) - ? 
extractEnvVariable(modelConfig) - : modelConfig.trim() || 'mistral-ocr-latest'; - - const signedUrlResponse = await getSignedUrl({ - apiKey, - baseURL, - fileId: mistralFile.id, - }); - - const mimetype = (file.mimetype || '').toLowerCase(); - const originalname = file.originalname || ''; - const isImage = - mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname); - const documentType = isImage ? 'image_url' : 'document_url'; - - const ocrResult = await performOCR({ - apiKey, - baseURL, - model, - url: signedUrlResponse.url, - documentType, - }); - - let aggregatedText = ''; - const images = []; - ocrResult.pages.forEach((page, index) => { - if (ocrResult.pages.length > 1) { - aggregatedText += `# PAGE ${index + 1}\n`; - } - - aggregatedText += page.markdown + '\n\n'; - - if (page.images && page.images.length > 0) { - page.images.forEach((image) => { - if (image.image_base64) { - images.push(image.image_base64); - } - }); - } - }); - - return { - filename: file.originalname, - bytes: aggregatedText.length * 4, - filepath: FileSources.mistral_ocr, - text: aggregatedText, - images, - }; - } catch (error) { - let message = 'Error uploading document to Mistral OCR API'; - const detail = error?.response?.data?.detail; - if (detail && detail !== '') { - message = detail; - } - - const responseMessage = error?.response?.data?.message; - throw new Error( - `${logAxiosError({ error, message })}${responseMessage && responseMessage !== '' ? 
` - ${responseMessage}` : ''}`, - ); - } -}; - -module.exports = { - uploadDocumentToMistral, - uploadMistralOCR, - getSignedUrl, - performOCR, -}; diff --git a/api/server/services/Files/MistralOCR/crud.spec.js b/api/server/services/Files/MistralOCR/crud.spec.js deleted file mode 100644 index 8cc63cade2..0000000000 --- a/api/server/services/Files/MistralOCR/crud.spec.js +++ /dev/null @@ -1,846 +0,0 @@ -const fs = require('fs'); - -const mockAxios = { - interceptors: { - request: { use: jest.fn(), eject: jest.fn() }, - response: { use: jest.fn(), eject: jest.fn() }, - }, - create: jest.fn().mockReturnValue({ - defaults: { - proxy: null, - }, - get: jest.fn().mockResolvedValue({ data: {} }), - post: jest.fn().mockResolvedValue({ data: {} }), - put: jest.fn().mockResolvedValue({ data: {} }), - delete: jest.fn().mockResolvedValue({ data: {} }), - }), - get: jest.fn().mockResolvedValue({ data: {} }), - post: jest.fn().mockResolvedValue({ data: {} }), - put: jest.fn().mockResolvedValue({ data: {} }), - delete: jest.fn().mockResolvedValue({ data: {} }), - reset: jest.fn().mockImplementation(function () { - this.get.mockClear(); - this.post.mockClear(); - this.put.mockClear(); - this.delete.mockClear(); - this.create.mockClear(); - }), -}; - -jest.mock('axios', () => mockAxios); -jest.mock('fs'); -jest.mock('~/config', () => ({ - logger: { - error: jest.fn(), - }, - createAxiosInstance: () => mockAxios, -})); -jest.mock('~/server/services/Tools/credentials', () => ({ - loadAuthValues: jest.fn(), -})); - -const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud'); - -describe('MistralOCR Service', () => { - afterEach(() => { - mockAxios.reset(); - jest.clearAllMocks(); - }); - - describe('uploadDocumentToMistral', () => { - beforeEach(() => { - // Create a more complete mock for file streams that FormData can work with - const mockReadStream = { - on: jest.fn().mockImplementation(function (event, handler) { - // Simulate immediate 
'end' event to make FormData complete processing - if (event === 'end') { - handler(); - } - return this; - }), - pipe: jest.fn().mockImplementation(function () { - return this; - }), - pause: jest.fn(), - resume: jest.fn(), - emit: jest.fn(), - once: jest.fn(), - destroy: jest.fn(), - }; - - fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); - - // Mock FormData's append to avoid actual stream processing - jest.mock('form-data', () => { - const mockFormData = function () { - return { - append: jest.fn(), - getHeaders: jest - .fn() - .mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }), - getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')), - getLength: jest.fn().mockReturnValue(100), - }; - }; - return mockFormData; - }); - }); - - it('should upload a document to Mistral API using file streaming', async () => { - const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await uploadDocumentToMistral({ - filePath: '/path/to/test.pdf', - fileName: 'test.pdf', - apiKey: 'test-api-key', - }); - - // Check that createReadStream was called with the correct file path - expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf'); - - // Since we're mocking FormData, we'll just check that axios was called correctly - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/files', - expect.anything(), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer test-api-key', - }), - maxBodyLength: Infinity, - maxContentLength: Infinity, - }), - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors during document upload', async () => { - const errorMessage = 'API error'; - mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - uploadDocumentToMistral({ - filePath: '/path/to/test.pdf', - fileName: 'test.pdf', - apiKey: 
'test-api-key', - }), - ).rejects.toThrow(errorMessage); - }); - }); - - describe('getSignedUrl', () => { - it('should fetch signed URL from Mistral API', async () => { - const mockResponse = { data: { url: 'https://document-url.com' } }; - mockAxios.get.mockResolvedValueOnce(mockResponse); - - const result = await getSignedUrl({ - fileId: 'file-123', - apiKey: 'test-api-key', - }); - - expect(mockAxios.get).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/files/file-123/url?expiry=24', - { - headers: { - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors when fetching signed URL', async () => { - const errorMessage = 'API error'; - mockAxios.get.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - getSignedUrl({ - fileId: 'file-123', - apiKey: 'test-api-key', - }), - ).rejects.toThrow(); - - const { logger } = require('~/config'); - expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage); - }); - }); - - describe('performOCR', () => { - it('should perform OCR using Mistral API (document_url)', async () => { - const mockResponse = { - data: { - pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], - }, - }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await performOCR({ - apiKey: 'test-api-key', - url: 'https://document-url.com', - model: 'mistral-ocr-latest', - documentType: 'document_url', - }); - - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - { - model: 'mistral-ocr-latest', - include_image_base64: false, - document: { - type: 'document_url', - document_url: 'https://document-url.com', - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should perform OCR using Mistral API (image_url)', async () => { - const mockResponse = { - 
data: { - pages: [{ markdown: 'Image OCR content' }], - }, - }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await performOCR({ - apiKey: 'test-api-key', - url: 'https://image-url.com/image.png', - model: 'mistral-ocr-latest', - documentType: 'image_url', - }); - - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - { - model: 'mistral-ocr-latest', - include_image_base64: false, - document: { - type: 'image_url', - image_url: 'https://image-url.com/image.png', - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors during OCR processing', async () => { - const errorMessage = 'OCR processing error'; - mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - performOCR({ - apiKey: 'test-api-key', - url: 'https://document-url.com', - }), - ).rejects.toThrow(); - - const { logger } = require('~/config'); - expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage); - }); - }); - - describe('uploadMistralOCR', () => { - beforeEach(() => { - const mockReadStream = { - on: jest.fn().mockImplementation(function (event, handler) { - if (event === 'end') { - handler(); - } - return this; - }), - pipe: jest.fn().mockImplementation(function () { - return this; - }), - pause: jest.fn(), - resume: jest.fn(), - emit: jest.fn(), - once: jest.fn(), - destroy: jest.fn(), - }; - - fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); - }); - - it('should process OCR for a file with standard configuration', async () => { - // Setup mocks - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', - }); - - // Mock file upload response - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', 
purpose: 'ocr' }, - }); - - // Mock signed URL response - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - - // Mock OCR response with text and images - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [ - { - markdown: 'Page 1 content', - images: [{ image_base64: 'base64image1' }], - }, - { - markdown: 'Page 2 content', - images: [{ image_base64: 'base64image2' }], - }, - ], - }, - }); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Use environment variable syntax to ensure loadAuthValues is called - apiKey: '${OCR_API_KEY}', - baseURL: '${OCR_BASEURL}', - mistralModel: 'mistral-medium', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - mimetype: 'application/pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Verify OCR result - expect(result).toEqual({ - filename: 'document.pdf', - bytes: expect.any(Number), - filepath: 'mistral_ocr', - text: expect.stringContaining('# PAGE 1'), - images: ['base64image1', 'base64image2'], - }); - }); - - it('should process OCR for an image file and use image_url type', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', - }); - - // Mock file upload response - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-456', purpose: 'ocr' }, - }); - - // Mock signed URL response - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com/image.png' }, - }); - - // Mock OCR response for image - 
mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [ - { - markdown: 'Image OCR result', - images: [{ image_base64: 'imgbase64' }], - }, - ], - }, - }); - - const req = { - user: { id: 'user456' }, - app: { - locals: { - ocr: { - apiKey: '${OCR_API_KEY}', - baseURL: '${OCR_BASEURL}', - mistralModel: 'mistral-medium', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/image.png', - originalname: 'image.png', - mimetype: 'image/png', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file456', - entity_id: 'entity456', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png'); - - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user456', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Check that the OCR API was called with image_url type - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - expect.objectContaining({ - document: expect.objectContaining({ - type: 'image_url', - image_url: 'https://signed-url.com/image.png', - }), - }), - expect.any(Object), - ); - - expect(result).toEqual({ - filename: 'image.png', - bytes: expect.any(Number), - filepath: 'mistral_ocr', - text: expect.stringContaining('Image OCR result'), - images: ['imgbase64'], - }); - }); - - it('should process variable references in configuration', async () => { - // Setup mocks with environment variables - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - CUSTOM_API_KEY: 'custom-api-key', - CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1', - }); - - // Mock API responses - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', purpose: 'ocr' }, - }); - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [{ markdown: 'Content from custom API' }], - }, - }); - - const req = { - user: { 
id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: '${CUSTOM_API_KEY}', - baseURL: '${CUSTOM_BASEURL}', - mistralModel: '${CUSTOM_MODEL}', - }, - }, - }, - }; - - // Set environment variable for model - process.env.CUSTOM_MODEL = 'mistral-large'; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify that custom environment variables were extracted and used - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'], - optional: expect.any(Set), - }); - - // Check that mistral-large was used in the OCR API call - expect(mockAxios.post).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - model: 'mistral-large', - }), - expect.anything(), - ); - - expect(result.text).toEqual('Content from custom API\n\n'); - }); - - it('should fall back to default values when variables are not properly formatted', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'default-api-key', - OCR_BASEURL: undefined, // Testing optional parameter - }); - - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', purpose: 'ocr' }, - }); - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [{ markdown: 'Default API result' }], - }, - }); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Use environment variable syntax to ensure loadAuthValues is called - apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name - baseURL: '${OCR_BASEURL}', // Using valid env var format - mistralModel: 'mistral-ocr-latest', // Plain string value 
- }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Should use the default values - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'INVALID_FORMAT'], - optional: expect.any(Set), - }); - - // Should use the default model when not using environment variable format - expect(mockAxios.post).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - model: 'mistral-ocr-latest', - }), - expect.anything(), - ); - }); - - it('should handle API errors during OCR process', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - }); - - // Mock file upload to fail - mockAxios.post.mockRejectedValueOnce(new Error('Upload failed')); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: 'OCR_API_KEY', - baseURL: 'OCR_BASEURL', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - await expect( - uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }), - ).rejects.toThrow('Error uploading document to Mistral OCR API'); - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - }); - - it('should handle single page documents without page numbering', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included - }); - - // Clear all previous mocks - mockAxios.post.mockClear(); - mockAxios.get.mockClear(); - - // 1. 
First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Single page content' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: 'OCR_API_KEY', - baseURL: 'OCR_BASEURL', - mistralModel: 'mistral-ocr-latest', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'single-page.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify that single page documents don't include page numbering - expect(result.text).not.toContain('# PAGE'); - expect(result.text).toEqual('Single page content\n\n'); - }); - - it('should use literal values in configuration when provided directly', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - // We'll still mock this but it should not be used for literal values - loadAuthValues.mockResolvedValue({}); - - // Clear all previous mocks - mockAxios.post.mockClear(); - mockAxios.get.mockClear(); - - // 1. First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. 
Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Processed with literal config values' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Direct values that should be used as-is, without variable substitution - apiKey: 'actual-api-key-value', - baseURL: 'https://direct-api-url.mistral.ai/v1', - mistralModel: 'mistral-direct-model', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'direct-values.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify the correct URL was used with the direct baseURL value - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://direct-api-url.mistral.ai/v1/files', - expect.any(Object), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer actual-api-key-value', - }), - }), - ); - - // Check the OCR call was made with the direct model value - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://direct-api-url.mistral.ai/v1/ocr', - expect.objectContaining({ - model: 'mistral-direct-model', - }), - expect.any(Object), - ); - - // Verify the result - expect(result.text).toEqual('Processed with literal config values\n\n'); - - // Verify loadAuthValues was never called since we used direct values - expect(loadAuthValues).not.toHaveBeenCalled(); - }); - - it('should handle empty configuration values and use defaults', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - // Set up the mock values to be returned by loadAuthValues - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'default-from-env-key', - OCR_BASEURL: 'https://default-from-env.mistral.ai/v1', - }); - - // Clear all previous mocks - mockAxios.post.mockClear(); - 
mockAxios.get.mockClear(); - - // 1. First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Content from default configuration' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Empty string values - should fall back to defaults - apiKey: '', - baseURL: '', - mistralModel: '', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'empty-config.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify loadAuthValues was called with the default variable names - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Verify the API calls used the default values from loadAuthValues - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://default-from-env.mistral.ai/v1/files', - expect.any(Object), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer default-from-env-key', - }), - }), - ); - - // Verify the OCR model defaulted to mistral-ocr-latest - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://default-from-env.mistral.ai/v1/ocr', - expect.objectContaining({ - model: 'mistral-ocr-latest', - }), - expect.any(Object), - ); - - // Check result - expect(result.text).toEqual('Content from default configuration\n\n'); - }); - }); -}); diff --git a/api/server/services/Files/MistralOCR/index.js 
b/api/server/services/Files/MistralOCR/index.js deleted file mode 100644 index a6223d1ee5..0000000000 --- a/api/server/services/Files/MistralOCR/index.js +++ /dev/null @@ -1,5 +0,0 @@ -const crud = require('./crud'); - -module.exports = { - ...crud, -}; diff --git a/api/server/services/Files/S3/images.js b/api/server/services/Files/S3/images.js index 378212cb5e..688d5eb68b 100644 --- a/api/server/services/Files/S3/images.js +++ b/api/server/services/Files/S3/images.js @@ -1,11 +1,10 @@ const fs = require('fs'); const path = require('path'); const sharp = require('sharp'); +const { logger } = require('@librechat/data-schemas'); const { resizeImageBuffer } = require('../images/resize'); -const { updateUser } = require('~/models/userMethods'); +const { updateUser, updateFile } = require('~/models'); const { saveBufferToS3 } = require('./crud'); -const { updateFile } = require('~/models/File'); -const { logger } = require('~/config'); const defaultBasePath = 'images'; @@ -95,15 +94,28 @@ async function prepareImageURLS3(req, file) { * @param {Buffer} params.buffer - Avatar image buffer. * @param {string} params.userId - User's unique identifier. * @param {string} params.manual - 'true' or 'false' flag for manual update. + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @param {string} [params.basePath='images'] - Base path in the bucket. * @returns {Promise} Signed URL of the uploaded avatar. */ -async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) { +async function processS3Avatar({ buffer, userId, manual, agentId, basePath = defaultBasePath }) { try { - const downloadURL = await saveBufferToS3({ userId, buffer, fileName: 'avatar.png', basePath }); - if (manual === 'true') { + const metadata = await sharp(buffer).metadata(); + const extension = metadata.format === 'gif' ? 
'gif' : 'png'; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; + + const downloadURL = await saveBufferToS3({ userId, buffer, fileName, basePath }); + + // Only update user record if this is a user avatar (manual === 'true') + if (manual === 'true' && !agentId) { await updateUser(userId, { avatar: downloadURL }); } + return downloadURL; } catch (error) { logger.error('[processS3Avatar] Error processing S3 avatar:', error.message); diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index 37a1e81487..d7018f7669 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -1,9 +1,10 @@ const fs = require('fs'); const axios = require('axios'); const FormData = require('form-data'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { FileSources } = require('librechat-data-provider'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); +const { generateShortLivedToken } = require('~/server/services/AuthService'); /** * Deletes a file from the vector database. 
This function takes a file object, constructs the full path, and @@ -23,7 +24,8 @@ const deleteVectors = async (req, file) => { return; } try { - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); + return await axios.delete(`${process.env.RAG_API_URL}/documents`, { headers: { Authorization: `Bearer ${jwtToken}`, @@ -70,7 +72,7 @@ async function uploadVectors({ req, file, file_id, entity_id }) { } try { - const jwtToken = req.headers.authorization.split(' ')[1]; + const jwtToken = generateShortLivedToken(req.user.id); const formData = new FormData(); formData.append('file_id', file_id); formData.append('file', fs.createReadStream(file.path)); diff --git a/api/server/services/Files/images/avatar.js b/api/server/services/Files/images/avatar.js index 3c1068a453..8e81dea26c 100644 --- a/api/server/services/Files/images/avatar.js +++ b/api/server/services/Files/images/avatar.js @@ -44,8 +44,25 @@ async function resizeAvatar({ userId, input, desiredFormat = EImageOutputType.PN throw new Error('Invalid input type. 
Expected URL, Buffer, or File.'); } - const { width, height } = await sharp(imageBuffer).metadata(); + const metadata = await sharp(imageBuffer).metadata(); + const { width, height } = metadata; const minSize = Math.min(width, height); + + if (metadata.format === 'gif') { + const resizedBuffer = await sharp(imageBuffer, { animated: true }) + .extract({ + left: Math.floor((width - minSize) / 2), + top: Math.floor((height - minSize) / 2), + width: minSize, + height: minSize, + }) + .resize(250, 250) + .gif() + .toBuffer(); + + return resizedBuffer; + } + const squaredBuffer = await sharp(imageBuffer) .extract({ left: Math.floor((width - minSize) / 2), diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 154941fd89..e87654b378 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -1,4 +1,5 @@ const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); const { FileSources, VisionModes, @@ -7,8 +8,6 @@ const { EModelEndpoint, } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Converts a readable stream to a base64 encoded string. 
diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 94b1bc4dad..38ccdafdd7 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -55,7 +55,9 @@ const processFiles = async (files, fileIds) => { } if (!fileIds) { - return await Promise.all(promises); + const results = await Promise.all(promises); + // Filter out null results from failed updateFileUsage calls + return results.filter((result) => result != null); } for (let file_id of fileIds) { @@ -67,7 +69,9 @@ const processFiles = async (files, fileIds) => { } // TODO: calculate token cost when image is first uploaded - return await Promise.all(promises); + const results = await Promise.all(promises); + // Filter out null results from failed updateFileUsage calls + return results.filter((result) => result != null); }; /** @@ -522,7 +526,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { throw new Error('OCR capability is not enabled for Agents'); } - const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions( + const { handleFileUpload: uploadOCR } = getStrategyFunctions( req.app.locals?.ocr?.strategy ?? 
FileSources.mistral_ocr, ); const { file_id, temp_file_id } = metadata; @@ -534,7 +538,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { images, filename, filepath: ocrFileURL, - } = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath }); + } = await uploadOCR({ req, file, loadAuthValues }); const fileInfo = removeNullishValues({ text, diff --git a/api/server/services/Files/processFiles.test.js b/api/server/services/Files/processFiles.test.js new file mode 100644 index 0000000000..8665d33665 --- /dev/null +++ b/api/server/services/Files/processFiles.test.js @@ -0,0 +1,208 @@ +// Mock the updateFileUsage function before importing the actual processFiles +jest.mock('~/models/File', () => ({ + updateFileUsage: jest.fn(), +})); + +// Mock winston and logger configuration to avoid dependency issues +jest.mock('~/config', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, +})); + +// Mock all other dependencies that might cause issues +jest.mock('librechat-data-provider', () => ({ + isUUID: { parse: jest.fn() }, + megabyte: 1024 * 1024, + FileContext: { message_attachment: 'message_attachment' }, + FileSources: { local: 'local' }, + EModelEndpoint: { assistants: 'assistants' }, + EToolResources: { file_search: 'file_search' }, + mergeFileConfig: jest.fn(), + removeNullishValues: jest.fn((obj) => obj), + isAssistantsEndpoint: jest.fn(), +})); + +jest.mock('~/server/services/Files/images', () => ({ + convertImage: jest.fn(), + resizeAndConvert: jest.fn(), + resizeImageBuffer: jest.fn(), +})); + +jest.mock('~/server/controllers/assistants/v2', () => ({ + addResourceFileId: jest.fn(), + deleteResourceFileId: jest.fn(), +})); + +jest.mock('~/models/Agent', () => ({ + addAgentResourceFile: jest.fn(), + removeAgentResourceFiles: jest.fn(), +})); + +jest.mock('~/server/controllers/assistants/helpers', () => ({ + getOpenAIClient: jest.fn(), +})); + 
+jest.mock('~/server/services/Tools/credentials', () => ({ + loadAuthValues: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + checkCapability: jest.fn(), +})); + +jest.mock('~/server/utils/queue', () => ({ + LB_QueueAsyncCall: jest.fn(), +})); + +jest.mock('./strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +jest.mock('~/server/utils', () => ({ + determineFileType: jest.fn(), +})); + +// Import the actual processFiles function after all mocks are set up +const { processFiles } = require('./process'); +const { updateFileUsage } = require('~/models/File'); + +describe('processFiles', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('null filtering functionality', () => { + it('should filter out null results from updateFileUsage when files do not exist', async () => { + const mockFiles = [ + { file_id: 'existing-file-1' }, + { file_id: 'non-existent-file' }, + { file_id: 'existing-file-2' }, + ]; + + // Mock updateFileUsage to return null for non-existent files + updateFileUsage.mockImplementation(({ file_id }) => { + if (file_id === 'non-existent-file') { + return Promise.resolve(null); // Simulate file not found in the database + } + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + expect(updateFileUsage).toHaveBeenCalledTimes(3); + expect(result).toEqual([ + { file_id: 'existing-file-1', usage: 1 }, + { file_id: 'existing-file-2', usage: 1 }, + ]); + + // Critical test - ensure no null values in result + expect(result).not.toContain(null); + expect(result).not.toContain(undefined); + expect(result.length).toBe(2); // Only valid files should be returned + }); + + it('should return empty array when all updateFileUsage calls return null', async () => { + const mockFiles = [{ file_id: 'non-existent-1' }, { file_id: 'non-existent-2' }]; + + // All updateFileUsage calls return null + updateFileUsage.mockResolvedValue(null); + + const result = await 
processFiles(mockFiles); + + expect(updateFileUsage).toHaveBeenCalledTimes(2); + expect(result).toEqual([]); + expect(result).not.toContain(null); + expect(result.length).toBe(0); + }); + + it('should work correctly when all files exist', async () => { + const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }]; + + updateFileUsage.mockImplementation(({ file_id }) => { + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + expect(result).toEqual([ + { file_id: 'file-1', usage: 1 }, + { file_id: 'file-2', usage: 1 }, + ]); + expect(result).not.toContain(null); + expect(result.length).toBe(2); + }); + + it('should handle fileIds parameter and filter nulls correctly', async () => { + const mockFiles = [{ file_id: 'file-1' }]; + const mockFileIds = ['file-2', 'non-existent-file']; + + updateFileUsage.mockImplementation(({ file_id }) => { + if (file_id === 'non-existent-file') { + return Promise.resolve(null); + } + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles, mockFileIds); + + expect(result).toEqual([ + { file_id: 'file-1', usage: 1 }, + { file_id: 'file-2', usage: 1 }, + ]); + expect(result).not.toContain(null); + expect(result).not.toContain(undefined); + expect(result.length).toBe(2); + }); + + it('should handle duplicate file_ids correctly', async () => { + const mockFiles = [ + { file_id: 'duplicate-file' }, + { file_id: 'duplicate-file' }, // Duplicate should be ignored + { file_id: 'unique-file' }, + ]; + + updateFileUsage.mockImplementation(({ file_id }) => { + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + // Should only call updateFileUsage twice (duplicate ignored) + expect(updateFileUsage).toHaveBeenCalledTimes(2); + expect(result).toEqual([ + { file_id: 'duplicate-file', usage: 1 }, + { file_id: 'unique-file', usage: 1 }, + ]); + expect(result.length).toBe(2); + }); + }); + + 
describe('edge cases', () => { + it('should handle empty files array', async () => { + const result = await processFiles([]); + expect(result).toEqual([]); + expect(updateFileUsage).not.toHaveBeenCalled(); + }); + + it('should handle mixed null and undefined returns from updateFileUsage', async () => { + const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }, { file_id: 'file-3' }]; + + updateFileUsage.mockImplementation(({ file_id }) => { + if (file_id === 'file-1') return Promise.resolve(null); + if (file_id === 'file-2') return Promise.resolve(undefined); + return Promise.resolve({ file_id, usage: 1 }); + }); + + const result = await processFiles(mockFiles); + + expect(result).toEqual([{ file_id: 'file-3', usage: 1 }]); + expect(result).not.toContain(null); + expect(result).not.toContain(undefined); + expect(result.length).toBe(1); + }); + }); +}); diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index c6cfe77069..4f8067142b 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -1,4 +1,9 @@ const { FileSources } = require('librechat-data-provider'); +const { + uploadMistralOCR, + uploadAzureMistralOCR, + uploadGoogleVertexMistralOCR, +} = require('@librechat/api'); const { getFirebaseURL, prepareImageURL, @@ -46,7 +51,6 @@ const { const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI'); const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code'); const { uploadVectors, deleteVectors } = require('./VectorDB'); -const { uploadMistralOCR } = require('./MistralOCR'); /** * Firebase Storage Strategy Functions @@ -202,6 +206,46 @@ const mistralOCRStrategy = () => ({ handleFileUpload: uploadMistralOCR, }); +const azureMistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} 
*/ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadAzureMistralOCR, +}); + +const vertexMistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadGoogleVertexMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { @@ -222,6 +266,10 @@ const getStrategyFunctions = (fileSource) => { return codeOutputStrategy(); } else if (fileSource === FileSources.mistral_ocr) { return mistralOCRStrategy(); + } else if (fileSource === FileSources.azure_mistral_ocr) { + return azureMistralOCRStrategy(); + } else if (fileSource === FileSources.vertexai_mistral_ocr) { + return vertexMistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index b9baef462e..527fe2d514 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,27 +1,111 @@ const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); -const { normalizeServerName } = require('librechat-mcp'); -const { 
Constants: AgentConstants, Providers } = require('@librechat/agents'); +const { logger } = require('@librechat/data-schemas'); +const { Time, CacheKeys, StepTypes } = require('librechat-data-provider'); +const { sendEvent, normalizeServerName, MCPOAuthHandler } = require('@librechat/api'); +const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents'); const { Constants, ContentTypes, isAssistantsEndpoint, convertJsonSchemaToZod, } = require('librechat-data-provider'); -const { logger, getMCPManager } = require('~/config'); +const { getMCPManager, getFlowStateManager } = require('~/config'); +const { findToken, createToken, updateToken } = require('~/models'); +const { getCachedTools } = require('./Config'); +const { getLogStores } = require('~/cache'); + +/** + * @param {object} params + * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {string} params.stepId - The ID of the step in the flow. + * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. + * @param {string} params.loginFlowId - The ID of the login flow. + * @param {FlowStateManager} params.flowManager - The flow manager instance. + */ +function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) { + /** + * Creates a function to handle OAuth login requests. + * @param {string} authURL - The URL to redirect the user for OAuth authentication. + * @returns {Promise} Returns true to indicate the event was sent successfully. 
+ */ + return async function (authURL) { + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const data = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + auth: authURL, + expires_at: Date.now() + Time.TWO_MINUTES, + }, + }; + /** Used to ensure the handler (use of `sendEvent`) is only invoked once */ + await flowManager.createFlowWithHandler( + loginFlowId, + 'oauth_login', + async () => { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); + logger.debug('Sent OAuth login request to client'); + return true; + }, + signal, + ); + }; +} + +/** + * @param {object} params + * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {string} params.stepId - The ID of the step in the flow. + * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. + * @param {string} params.loginFlowId - The ID of the login flow. + * @param {FlowStateManager} params.flowManager - The flow manager instance. + */ +function createOAuthEnd({ res, stepId, toolCall }) { + return async function () { + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const data = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall }], + }, + }; + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); + logger.debug('Sent OAuth login success to client'); + }; +} + +/** + * @param {object} params + * @param {string} params.userId - The ID of the user. + * @param {string} params.serverName - The name of the server. + * @param {string} params.toolName - The name of the tool. + * @param {FlowStateManager} params.flowManager - The flow manager instance. 
+ */ +function createAbortHandler({ userId, serverName, toolName, flowManager }) { + return function () { + logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`); + const flowId = MCPOAuthHandler.generateFlowId(userId, serverName); + flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted')); + }; +} /** * Creates a general tool for an entire action set. * * @param {Object} params - The parameters for loading action sets. * @param {ServerRequest} params.req - The Express request object, containing user/request info. + * @param {ServerResponse} params.res - The Express response object for sending events. * @param {string} params.toolKey - The toolKey for the tool. * @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool. * @param {string} params.model - The model for the tool. * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ -async function createMCPTool({ req, toolKey, provider: _provider }) { - const toolDefinition = req.app.locals.availableTools[toolKey]?.function; +async function createMCPTool({ req, res, toolKey, provider: _provider }) { + const availableTools = await getCachedTools({ includeGlobal: true }); + const toolDefinition = availableTools?.[toolKey]?.function; if (!toolDefinition) { logger.error(`Tool ${toolKey} not found in available tools`); return null; @@ -50,19 +134,61 @@ async function createMCPTool({ req, toolKey, provider: _provider }) { /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise} */ const _call = async (toolArguments, config) => { + const userId = config?.configurable?.user?.id || config?.configurable?.user_id; + /** @type {ReturnType} */ + let abortHandler = null; + /** @type {AbortSignal} */ + let derivedSignal = null; + try { - const derivedSignal = config?.signal ? 
AbortSignal.any([config.signal]) : undefined; - const mcpManager = getMCPManager(config?.configurable?.user_id); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; + const mcpManager = getMCPManager(userId); const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); + + const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; + const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; + const oauthStart = createOAuthStart({ + res, + stepId, + toolCall, + loginFlowId, + flowManager, + signal: derivedSignal, + }); + const oauthEnd = createOAuthEnd({ + res, + stepId, + toolCall, + }); + + if (derivedSignal) { + abortHandler = createAbortHandler({ userId, serverName, toolName, flowManager }); + derivedSignal.addEventListener('abort', abortHandler, { once: true }); + } + + const customUserVars = + config?.configurable?.userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; + const result = await mcpManager.callTool({ serverName, toolName, provider, toolArguments, options: { - userId: config?.configurable?.user_id, signal: derivedSignal, }, + user: config?.configurable?.user, + customUserVars, + flowManager, + tokenMethods: { + findToken, + createToken, + updateToken, + }, + oauthStart, + oauthEnd, }); if (isAssistantsEndpoint(provider) && Array.isArray(result)) { @@ -74,12 +200,31 @@ async function createMCPTool({ req, toolKey, provider: _provider }) { return result; } catch (error) { logger.error( - `[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`, + `[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`, error, ); + + /** OAuth error, provide a helpful message */ + const isOAuthError = + error.message?.includes('401') || + error.message?.includes('OAuth') || + 
error.message?.includes('authentication') || + error.message?.includes('Non-200 status code (401)'); + + if (isOAuthError) { + throw new Error( + `OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`, + ); + } + throw new Error( `"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`, ); + } finally { + // Clean up abort handler to prevent memory leaks + if (abortHandler && derivedSignal) { + derivedSignal.removeEventListener('abort', abortHandler); + } } }; diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index a1ccd7643b..0db13ec318 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -1,12 +1,13 @@ const axios = require('axios'); const { Providers } = require('@librechat/agents'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider'); -const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils'); +const { inputSchema, extractBaseURL, processModelData } = require('~/utils'); const { OllamaClient } = require('~/app/clients/OllamaClient'); const { isUserProvided } = require('~/server/utils'); const getLogStores = require('~/cache/getLogStores'); -const { logger } = require('~/config'); /** * Splits a string by commas and trims each resulting value. 
diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index fb4481f840..33ab9a7aaf 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -1,6 +1,6 @@ const axios = require('axios'); +const { logger } = require('@librechat/data-schemas'); const { EModelEndpoint, defaultModels } = require('librechat-data-provider'); -const { logger } = require('~/config'); const { fetchModels, @@ -28,7 +28,8 @@ jest.mock('~/cache/getLogStores', () => set: jest.fn().mockResolvedValue(true), })), ); -jest.mock('~/config', () => ({ +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), logger: { error: jest.fn(), }, diff --git a/api/server/services/PluginService.js b/api/server/services/PluginService.js index 03e90bce41..af42e0471c 100644 --- a/api/server/services/PluginService.js +++ b/api/server/services/PluginService.js @@ -1,6 +1,6 @@ -const PluginAuth = require('~/models/schema/pluginAuthSchema'); -const { encrypt, decrypt } = require('~/server/utils/'); -const { logger } = require('~/config'); +const { logger } = require('@librechat/data-schemas'); +const { encrypt, decrypt } = require('@librechat/api'); +const { findOnePluginAuth, updatePluginAuth, deletePluginAuth } = require('~/models'); /** * Asynchronously retrieves and decrypts the authentication value for a user's plugin, based on a specified authentication field. 
@@ -25,7 +25,7 @@ const { logger } = require('~/config'); */ const getUserPluginAuthValue = async (userId, authField, throwError = true) => { try { - const pluginAuth = await PluginAuth.findOne({ userId, authField }).lean(); + const pluginAuth = await findOnePluginAuth({ userId, authField }); if (!pluginAuth) { throw new Error(`No plugin auth ${authField} found for user ${userId}`); } @@ -79,23 +79,12 @@ const getUserPluginAuthValue = async (userId, authField, throwError = true) => { const updateUserPluginAuth = async (userId, authField, pluginKey, value) => { try { const encryptedValue = await encrypt(value); - const pluginAuth = await PluginAuth.findOne({ userId, authField }).lean(); - if (pluginAuth) { - return await PluginAuth.findOneAndUpdate( - { userId, authField }, - { $set: { value: encryptedValue } }, - { new: true, upsert: true }, - ).lean(); - } else { - const newPluginAuth = await new PluginAuth({ - userId, - authField, - value: encryptedValue, - pluginKey, - }); - await newPluginAuth.save(); - return newPluginAuth.toObject(); - } + return await updatePluginAuth({ + userId, + authField, + pluginKey, + value: encryptedValue, + }); } catch (err) { logger.error('[updateUserPluginAuth]', err); return err; @@ -105,26 +94,25 @@ const updateUserPluginAuth = async (userId, authField, pluginKey, value) => { /** * @async * @param {string} userId - * @param {string} authField - * @param {boolean} [all] + * @param {string | null} authField - The specific authField to delete, or null if `all` is true. + * @param {boolean} [all=false] - Whether to delete all auths for the user (or for a specific pluginKey if provided). + * @param {string} [pluginKey] - Optional. If `all` is true and `pluginKey` is provided, delete all auths for this user and pluginKey. 
* @returns {Promise} * @throws {Error} */ -const deleteUserPluginAuth = async (userId, authField, all = false) => { - if (all) { - try { - const response = await PluginAuth.deleteMany({ userId }); - return response; - } catch (err) { - logger.error('[deleteUserPluginAuth]', err); - return err; - } - } - +const deleteUserPluginAuth = async (userId, authField, all = false, pluginKey) => { try { - return await PluginAuth.deleteOne({ userId, authField }); + return await deletePluginAuth({ + userId, + authField, + pluginKey, + all, + }); } catch (err) { - logger.error('[deleteUserPluginAuth]', err); + logger.error( + `[deleteUserPluginAuth] Error deleting ${all ? 'all' : 'single'} auth(s) for userId: ${userId}${pluginKey ? ` and pluginKey: ${pluginKey}` : ''}`, + err, + ); return err; } }; diff --git a/api/server/services/Runs/StreamRunManager.js b/api/server/services/Runs/StreamRunManager.js index 4bab7326bb..4f6994e0cb 100644 --- a/api/server/services/Runs/StreamRunManager.js +++ b/api/server/services/Runs/StreamRunManager.js @@ -1,3 +1,6 @@ +const { sleep } = require('@librechat/agents'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Constants, StepTypes, @@ -8,9 +11,8 @@ const { } = require('librechat-data-provider'); const { retrieveAndProcessFile } = require('~/server/services/Files/process'); const { processRequiredActions } = require('~/server/services/ToolService'); -const { createOnProgress, sendMessage, sleep } = require('~/server/utils'); const { processMessages } = require('~/server/services/Threads'); -const { logger } = require('~/config'); +const { createOnProgress } = require('~/server/utils'); /** * Implements the StreamRunManager functionality for managing the streaming @@ -126,7 +128,7 @@ class StreamRunManager { conversationId: this.finalMessage.conversationId, }; - sendMessage(this.res, contentData); + sendEvent(this.res, contentData); } /* <------------------ Misc. 
Helpers ------------------> */ @@ -302,7 +304,7 @@ class StreamRunManager { for (const d of delta[key]) { if (typeof d === 'object' && !Object.prototype.hasOwnProperty.call(d, 'index')) { - logger.warn('Expected an object with an \'index\' for array updates but got:', d); + logger.warn("Expected an object with an 'index' for array updates but got:", d); continue; } diff --git a/api/server/services/Runs/methods.js b/api/server/services/Runs/methods.js index 3c18e9969b..167b9cc2ba 100644 --- a/api/server/services/Runs/methods.js +++ b/api/server/services/Runs/methods.js @@ -1,6 +1,6 @@ const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); -const { logAxiosError } = require('~/utils'); /** * @typedef {Object} RetrieveOptions diff --git a/api/server/services/TokenService.js b/api/server/services/TokenService.js deleted file mode 100644 index 3dd2e79ffa..0000000000 --- a/api/server/services/TokenService.js +++ /dev/null @@ -1,172 +0,0 @@ -const axios = require('axios'); -const { handleOAuthToken } = require('~/models/Token'); -const { decryptV2 } = require('~/server/utils/crypto'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); - -/** - * Processes the access tokens and stores them in the database. 
- * @param {object} tokenData - * @param {string} tokenData.access_token - * @param {number} tokenData.expires_in - * @param {string} [tokenData.refresh_token] - * @param {number} [tokenData.refresh_token_expires_in] - * @param {object} metadata - * @param {string} metadata.userId - * @param {string} metadata.identifier - * @returns {Promise} - */ -async function processAccessTokens(tokenData, { userId, identifier }) { - const { access_token, expires_in = 3600, refresh_token, refresh_token_expires_in } = tokenData; - if (!access_token) { - logger.error('Access token not found: ', tokenData); - throw new Error('Access token not found'); - } - await handleOAuthToken({ - identifier, - token: access_token, - expiresIn: expires_in, - userId, - }); - - if (refresh_token != null) { - logger.debug('Processing refresh token'); - await handleOAuthToken({ - token: refresh_token, - type: 'oauth_refresh', - userId, - identifier: `${identifier}:refresh`, - expiresIn: refresh_token_expires_in ?? null, - }); - } - logger.debug('Access tokens processed'); -} - -/** - * Refreshes the access token using the refresh token. - * @param {object} fields - * @param {string} fields.userId - The ID of the user. - * @param {string} fields.client_url - The URL of the OAuth provider. - * @param {string} fields.identifier - The identifier for the token. - * @param {string} fields.refresh_token - The refresh token to use. - * @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider. - * @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider. 
- * @returns {Promise<{ - * access_token: string, - * expires_in: number, - * refresh_token?: string, - * refresh_token_expires_in?: number, - * }>} - */ -const refreshAccessToken = async ({ - userId, - client_url, - identifier, - refresh_token, - encrypted_oauth_client_id, - encrypted_oauth_client_secret, -}) => { - try { - const oauth_client_id = await decryptV2(encrypted_oauth_client_id); - const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret); - const params = new URLSearchParams({ - client_id: oauth_client_id, - client_secret: oauth_client_secret, - grant_type: 'refresh_token', - refresh_token, - }); - - const response = await axios({ - method: 'POST', - url: client_url, - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, - data: params.toString(), - }); - await processAccessTokens(response.data, { - userId, - identifier, - }); - logger.debug(`Access token refreshed successfully for ${identifier}`); - return response.data; - } catch (error) { - const message = 'Error refreshing OAuth tokens'; - throw new Error( - logAxiosError({ - message, - error, - }), - ); - } -}; - -/** - * Handles the OAuth callback and exchanges the authorization code for tokens. - * @param {object} fields - * @param {string} fields.code - The authorization code returned by the provider. - * @param {string} fields.userId - The ID of the user. - * @param {string} fields.identifier - The identifier for the token. - * @param {string} fields.client_url - The URL of the OAuth provider. - * @param {string} fields.redirect_uri - The redirect URI for the OAuth provider. - * @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider. - * @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider. 
- * @returns {Promise<{ - * access_token: string, - * expires_in: number, - * refresh_token?: string, - * refresh_token_expires_in?: number, - * }>} - */ -const getAccessToken = async ({ - code, - userId, - identifier, - client_url, - redirect_uri, - encrypted_oauth_client_id, - encrypted_oauth_client_secret, -}) => { - const oauth_client_id = await decryptV2(encrypted_oauth_client_id); - const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret); - const params = new URLSearchParams({ - code, - client_id: oauth_client_id, - client_secret: oauth_client_secret, - grant_type: 'authorization_code', - redirect_uri, - }); - - try { - const response = await axios({ - method: 'POST', - url: client_url, - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, - data: params.toString(), - }); - - await processAccessTokens(response.data, { - userId, - identifier, - }); - logger.debug(`Access tokens successfully created for ${identifier}`); - return response.data; - } catch (error) { - const message = 'Error exchanging OAuth code'; - throw new Error( - logAxiosError({ - message, - error, - }), - ); - } -}; - -module.exports = { - getAccessToken, - refreshAccessToken, -}; diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 8dd2fbf865..f1567a3783 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -1,5 +1,7 @@ const fs = require('fs'); const path = require('path'); +const { sleep } = require('@librechat/agents'); +const { logger } = require('@librechat/data-schemas'); const { zodToJsonSchema } = require('zod-to-json-schema'); const { Calculator } = require('@langchain/community/tools/calculator'); const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools'); @@ -31,14 +33,12 @@ const { toolkits, } = require('~/app/clients/tools'); const { processFileURL, uploadImageBuffer } = 
require('~/server/services/Files/process'); +const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { isActionDomainAllowed } = require('~/server/services/domains'); -const { getEndpointsConfig } = require('~/server/services/Config'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); -const { sleep } = require('~/server/utils'); -const { logger } = require('~/config'); /** * @param {string} toolName @@ -226,7 +226,7 @@ async function processRequiredActions(client, requiredActions) { `[required actions] user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, requiredActions, ); - const toolDefinitions = client.req.app.locals.availableTools; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); const seenToolkits = new Set(); const tools = requiredActions .map((action) => { @@ -500,6 +500,8 @@ async function processRequiredActions(client, requiredActions) { async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) { if (!agent.tools || agent.tools.length === 0) { return {}; + } else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) { + return {}; } const endpointsConfig = await getEndpointsConfig(req); @@ -551,6 +553,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) tools: _agentTools, options: { req, + res, openAIApiKey, tool_resources, processFileURL, diff --git a/api/server/services/UserService.js b/api/server/services/UserService.js index 91d772477b..7cf2f832a3 100644 --- a/api/server/services/UserService.js +++ b/api/server/services/UserService.js @@ -1,7 +1,8 @@ +const { logger } = require('@librechat/data-schemas'); +const { encrypt, decrypt } = 
require('@librechat/api'); const { ErrorTypes } = require('librechat-data-provider'); -const { encrypt, decrypt } = require('~/server/utils'); -const { updateUser, Key } = require('~/models'); -const { logger } = require('~/config'); +const { updateUser } = require('~/models'); +const { Key } = require('~/db/models'); /** * Updates the plugins for a user based on the action specified (install/uninstall). @@ -69,6 +70,7 @@ const getUserKeyValues = async ({ userId, name }) => { try { userValues = JSON.parse(userValues); } catch (e) { + logger.error('[getUserKeyValues]', e); throw new Error( JSON.stringify({ type: ErrorTypes.INVALID_USER_KEY, diff --git a/api/server/services/initializeMCP.js b/api/server/services/initializeMCP.js new file mode 100644 index 0000000000..98b87d156e --- /dev/null +++ b/api/server/services/initializeMCP.js @@ -0,0 +1,53 @@ +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys } = require('librechat-data-provider'); +const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); +const { getMCPManager, getFlowStateManager } = require('~/config'); +const { getCachedTools, setCachedTools } = require('./Config'); +const { getLogStores } = require('~/cache'); + +/** + * Initialize MCP servers + * @param {import('express').Application} app - Express app instance + */ +async function initializeMCP(app) { + const mcpServers = app.locals.mcpConfig; + if (!mcpServers) { + return; + } + + logger.info('Initializing MCP servers...'); + const mcpManager = getMCPManager(); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = flowsCache ? 
getFlowStateManager(flowsCache) : null; + + try { + await mcpManager.initializeMCP({ + mcpServers, + flowManager, + tokenMethods: { + findToken, + updateToken, + createToken, + deleteTokens, + }, + }); + + delete app.locals.mcpConfig; + const availableTools = await getCachedTools(); + + if (!availableTools) { + logger.warn('No available tools found in cache during MCP initialization'); + return; + } + + const toolsCopy = { ...availableTools }; + await mcpManager.mapAvailableTools(toolsCopy, flowManager); + await setCachedTools(toolsCopy, { isGlobal: true }); + + logger.info('MCP servers initialized successfully'); + } catch (error) { + logger.error('Failed to initialize MCP servers:', error); + } +} + +module.exports = initializeMCP; diff --git a/api/server/services/signPayload.js b/api/server/services/signPayload.js deleted file mode 100644 index a7bb0c64fc..0000000000 --- a/api/server/services/signPayload.js +++ /dev/null @@ -1,26 +0,0 @@ -const jwt = require('jsonwebtoken'); - -/** - * Signs a given payload using either the `jose` library (for Bun runtime) or `jsonwebtoken`. - * - * @async - * @function - * @param {Object} options - The options for signing the payload. - * @param {Object} options.payload - The payload to be signed. - * @param {string} options.secret - The secret key used for signing. - * @param {number} options.expirationTime - The expiration time in seconds. - * @returns {Promise} Returns a promise that resolves to the signed JWT. - * @throws {Error} Throws an error if there's an issue during signing. 
- * - * @example - * const signedPayload = await signPayload({ - * payload: { userId: 123 }, - * secret: 'my-secret-key', - * expirationTime: 3600 - * }); - */ -async function signPayload({ payload, secret, expirationTime }) { - return jwt.sign(payload, secret, { expiresIn: expirationTime }); -} - -module.exports = signPayload; diff --git a/api/server/services/start/agents.js b/api/server/services/start/agents.js deleted file mode 100644 index 10653f3fb6..0000000000 --- a/api/server/services/start/agents.js +++ /dev/null @@ -1,14 +0,0 @@ -const { EModelEndpoint, agentsEndpointSChema } = require('librechat-data-provider'); - -/** - * Sets up the Agents configuration from the config (`librechat.yaml`) file. - * @param {TCustomConfig} config - The loaded custom configuration. - * @returns {Partial} The Agents endpoint configuration. - */ -function agentsConfigSetup(config) { - const agentsConfig = config.endpoints[EModelEndpoint.agents]; - const parsedConfig = agentsEndpointSChema.parse(agentsConfig); - return parsedConfig; -} - -module.exports = { agentsConfigSetup }; diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index 7578c036b2..5c08b1af2e 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -2,6 +2,7 @@ const { SystemRoles, Permissions, PermissionTypes, + isMemoryEnabled, removeNullishValues, } = require('librechat-data-provider'); const { updateAccessPermissions } = require('~/models/Role'); @@ -20,6 +21,14 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol const hasModelSpecs = config?.modelSpecs?.list?.length > 0; const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0; + const memoryConfig = config?.memory; + const memoryEnabled = isMemoryEnabled(memoryConfig); + /** Only disable memories if memory config is present but disabled/invalid */ + const shouldDisableMemories = memoryConfig && !memoryEnabled; + /** Check 
if personalization is enabled (defaults to true if memory is configured and enabled) */ + const isPersonalizationEnabled = + memoryConfig && memoryEnabled && memoryConfig.personalize !== false; + /** @type {TCustomConfig['interface']} */ const loadedInterface = removeNullishValues({ endpointsMenu: @@ -32,7 +41,9 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel, privacyPolicy: interfaceConfig?.privacyPolicy ?? defaults.privacyPolicy, termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService, + mcpServers: interfaceConfig?.mcpServers ?? defaults.mcpServers, bookmarks: interfaceConfig?.bookmarks ?? defaults.bookmarks, + memories: shouldDisableMemories ? false : (interfaceConfig?.memories ?? defaults.memories), prompts: interfaceConfig?.prompts ?? defaults.prompts, multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo, agents: interfaceConfig?.agents ?? defaults.agents, @@ -45,6 +56,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol await updateAccessPermissions(roleName, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + [PermissionTypes.MEMORIES]: { + [Permissions.USE]: loadedInterface.memories, + [Permissions.OPT_OUT]: isPersonalizationEnabled, + }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, @@ -54,6 +69,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol await updateAccessPermissions(SystemRoles.ADMIN, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + 
[PermissionTypes.MEMORIES]: { + [Permissions.USE]: loadedInterface.memories, + [Permissions.OPT_OUT]: isPersonalizationEnabled, + }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, diff --git a/api/server/services/start/interface.spec.js b/api/server/services/start/interface.spec.js index d0dcfaf55f..1a05c9cf12 100644 --- a/api/server/services/start/interface.spec.js +++ b/api/server/services/start/interface.spec.js @@ -12,6 +12,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: true, multiConvo: true, agents: true, temporaryChat: true, @@ -26,6 +27,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -39,6 +41,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: false, bookmarks: false, + memories: false, multiConvo: false, agents: false, temporaryChat: false, @@ -53,6 +56,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: false }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false }, @@ -70,6 +74,7 @@ describe('loadDefaultInterface', () => 
{ expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -83,6 +88,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: undefined, bookmarks: undefined, + memories: undefined, multiConvo: undefined, agents: undefined, temporaryChat: undefined, @@ -97,6 +103,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -110,6 +117,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: undefined, agents: true, temporaryChat: undefined, @@ -124,6 +132,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -138,6 +147,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: true, multiConvo: 
true, agents: true, temporaryChat: true, @@ -151,6 +161,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -168,6 +179,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -185,6 +197,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -202,6 +215,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, 
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -215,6 +229,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: true, agents: false, temporaryChat: true, @@ -228,6 +243,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -242,6 +258,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: false, multiConvo: false, agents: undefined, temporaryChat: undefined, @@ -255,6 +272,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -268,6 +286,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: true, agents: false, temporaryChat: true, @@ -281,6 +300,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { 
[Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js index d000c8fcfc..4ac86a5549 100644 --- a/api/server/services/twoFactorService.js +++ b/api/server/services/twoFactorService.js @@ -1,6 +1,6 @@ const { webcrypto } = require('node:crypto'); -const { decryptV3, decryptV2 } = require('../utils/crypto'); -const { hashBackupCode } = require('~/server/utils/crypto'); +const { hashBackupCode, decryptV3, decryptV2 } = require('@librechat/api'); +const { updateUser } = require('~/models'); // Base32 alphabet for TOTP secret encoding. const BASE32_ALPHABET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'; @@ -172,7 +172,6 @@ const verifyBackupCode = async ({ user, backupCode }) => { : codeObj, ); // Update the user record with the marked backup code. - const { updateUser } = require('~/models'); await updateUser(user._id, { backupCodes: updatedBackupCodes }); return true; } diff --git a/api/server/utils/emails/passwordReset.handlebars b/api/server/utils/emails/passwordReset.handlebars index 9076b92edb..6b735f53fd 100644 --- a/api/server/utils/emails/passwordReset.handlebars +++ b/api/server/utils/emails/passwordReset.handlebars @@ -22,17 +22,71 @@ diff --git a/api/server/utils/emails/requestPasswordReset.handlebars b/api/server/utils/emails/requestPasswordReset.handlebars index 2600b5a9d3..b7005254ba 100644 --- a/api/server/utils/emails/requestPasswordReset.handlebars +++ b/api/server/utils/emails/requestPasswordReset.handlebars @@ -22,18 +22,78 @@ diff --git a/api/server/utils/emails/verifyEmail.handlebars b/api/server/utils/emails/verifyEmail.handlebars index 63b52e79be..fa77575053 100644 --- a/api/server/utils/emails/verifyEmail.handlebars +++ b/api/server/utils/emails/verifyEmail.handlebars @@ -22,18 +22,75 @@ diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 86c17f1dda..36671c44ff 100644 --- 
a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -1,5 +1,3 @@ -const path = require('path'); -const crypto = require('crypto'); const { Capabilities, EModelEndpoint, @@ -9,9 +7,9 @@ const { defaultAssistantsVersion, defaultAgentCapabilities, } = require('librechat-data-provider'); +const { sendEvent } = require('@librechat/api'); const { Providers } = require('@librechat/agents'); const partialRight = require('lodash/partialRight'); -const { sendMessage } = require('./streamResponse'); /** Helper function to escape special characters in regex * @param {string} string - The string to escape. @@ -39,7 +37,7 @@ const createOnProgress = ( basePayload.text = basePayload.text + chunk; const payload = Object.assign({}, basePayload, rest); - sendMessage(res, payload); + sendEvent(res, payload); if (_onProgress) { _onProgress(payload); } @@ -52,7 +50,7 @@ const createOnProgress = ( const sendIntermediateMessage = (res, payload, extraTokens = '') => { basePayload.text = basePayload.text + extraTokens; const message = Object.assign({}, basePayload, payload); - sendMessage(res, message); + sendEvent(res, message); if (i === 0) { basePayload.initial = false; } @@ -218,38 +216,6 @@ function normalizeEndpointName(name = '') { return name.toLowerCase() === Providers.OLLAMA ? Providers.OLLAMA : name; } -/** - * Sanitize a filename by removing any directory components, replacing non-alphanumeric characters - * @param {string} inputName - * @returns {string} - */ -function sanitizeFilename(inputName) { - // Remove any directory components - let name = path.basename(inputName); - - // Replace any non-alphanumeric characters except for '.' 
and '-' - name = name.replace(/[^a-zA-Z0-9.-]/g, '_'); - - // Ensure the name doesn't start with a dot (hidden file in Unix-like systems) - if (name.startsWith('.') || name === '') { - name = '_' + name; - } - - // Limit the length of the filename - const MAX_LENGTH = 255; - if (name.length > MAX_LENGTH) { - const ext = path.extname(name); - const nameWithoutExt = path.basename(name, ext); - name = - nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) + - '-' + - crypto.randomBytes(3).toString('hex') + - ext; - } - - return name; -} - module.exports = { isEnabled, handleText, @@ -260,6 +226,5 @@ module.exports = { generateConfig, addSpaceIfNeeded, createOnProgress, - sanitizeFilename, normalizeEndpointName, }; diff --git a/api/server/utils/handleText.spec.js b/api/server/utils/handleText.spec.js deleted file mode 100644 index 8b1b6eef8d..0000000000 --- a/api/server/utils/handleText.spec.js +++ /dev/null @@ -1,99 +0,0 @@ -const { isEnabled, sanitizeFilename } = require('./handleText'); - -describe('isEnabled', () => { - test('should return true when input is "true"', () => { - expect(isEnabled('true')).toBe(true); - }); - - test('should return true when input is "TRUE"', () => { - expect(isEnabled('TRUE')).toBe(true); - }); - - test('should return true when input is true', () => { - expect(isEnabled(true)).toBe(true); - }); - - test('should return false when input is "false"', () => { - expect(isEnabled('false')).toBe(false); - }); - - test('should return false when input is false', () => { - expect(isEnabled(false)).toBe(false); - }); - - test('should return false when input is null', () => { - expect(isEnabled(null)).toBe(false); - }); - - test('should return false when input is undefined', () => { - expect(isEnabled()).toBe(false); - }); - - test('should return false when input is an empty string', () => { - expect(isEnabled('')).toBe(false); - }); - - test('should return false when input is a whitespace string', () => { - expect(isEnabled(' ')).toBe(false); - 
}); - - test('should return false when input is a number', () => { - expect(isEnabled(123)).toBe(false); - }); - - test('should return false when input is an object', () => { - expect(isEnabled({})).toBe(false); - }); - - test('should return false when input is an array', () => { - expect(isEnabled([])).toBe(false); - }); -}); - -jest.mock('crypto', () => ({ - randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')), -})); - -describe('sanitizeFilename', () => { - test('removes directory components (1/2)', () => { - expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt'); - }); - - test('removes directory components (2/2)', () => { - expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt'); - }); - - test('replaces non-alphanumeric characters', () => { - expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt'); - }); - - test('preserves dots and hyphens', () => { - expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt'); - }); - - test('prepends underscore to filenames starting with a dot', () => { - expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile'); - }); - - test('truncates long filenames', () => { - const longName = 'a'.repeat(300) + '.txt'; - const result = sanitizeFilename(longName); - expect(result.length).toBe(255); - expect(result).toMatch(/^a+-abc123\.txt$/); - }); - - test('handles filenames with no extension', () => { - const longName = 'a'.repeat(300); - const result = sanitizeFilename(longName); - expect(result.length).toBe(255); - expect(result).toMatch(/^a+-abc123$/); - }); - - test('handles empty input', () => { - expect(sanitizeFilename('')).toBe('_'); - }); - - test('handles input with only special characters', () => { - expect(sanitizeFilename('@#$%^&*')).toBe('_______'); - }); -}); diff --git a/api/server/utils/index.js b/api/server/utils/index.js index b79b42f00d..2672f4f2ea 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -1,35 +1,36 @@ 
-const streamResponse = require('./streamResponse'); const removePorts = require('./removePorts'); const countTokens = require('./countTokens'); const handleText = require('./handleText'); const sendEmail = require('./sendEmail'); -const cryptoUtils = require('./crypto'); const queue = require('./queue'); const files = require('./files'); -const math = require('./math'); /** * Check if email configuration is set * @returns {Boolean} */ function checkEmailConfig() { - return ( + // Check if Mailgun is configured + const hasMailgunConfig = + !!process.env.MAILGUN_API_KEY && !!process.env.MAILGUN_DOMAIN && !!process.env.EMAIL_FROM; + + // Check if SMTP is configured + const hasSMTPConfig = (!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) && !!process.env.EMAIL_USERNAME && !!process.env.EMAIL_PASSWORD && - !!process.env.EMAIL_FROM - ); + !!process.env.EMAIL_FROM; + + // Return true if either Mailgun or SMTP is properly configured + return hasMailgunConfig || hasSMTPConfig; } module.exports = { - ...streamResponse, checkEmailConfig, - ...cryptoUtils, ...handleText, countTokens, removePorts, sendEmail, ...files, ...queue, - math, }; diff --git a/api/server/utils/sendEmail.js b/api/server/utils/sendEmail.js index 59d75830f4..c0afd0eebe 100644 --- a/api/server/utils/sendEmail.js +++ b/api/server/utils/sendEmail.js @@ -1,9 +1,69 @@ const fs = require('fs'); const path = require('path'); +const axios = require('axios'); +const FormData = require('form-data'); const nodemailer = require('nodemailer'); const handlebars = require('handlebars'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { isEnabled } = require('~/server/utils/handleText'); -const logger = require('~/config/winston'); + +/** + * Sends an email using Mailgun API. + * + * @async + * @function sendEmailViaMailgun + * @param {Object} params - The parameters for sending the email. 
+ * @param {string} params.to - The recipient's email address. + * @param {string} params.from - The sender's email address. + * @param {string} params.subject - The subject of the email. + * @param {string} params.html - The HTML content of the email. + * @returns {Promise} - A promise that resolves to the response from Mailgun API. + */ +const sendEmailViaMailgun = async ({ to, from, subject, html }) => { + const mailgunApiKey = process.env.MAILGUN_API_KEY; + const mailgunDomain = process.env.MAILGUN_DOMAIN; + const mailgunHost = process.env.MAILGUN_HOST || 'https://api.mailgun.net'; + + if (!mailgunApiKey || !mailgunDomain) { + throw new Error('Mailgun API key and domain are required'); + } + + const formData = new FormData(); + formData.append('from', from); + formData.append('to', to); + formData.append('subject', subject); + formData.append('html', html); + formData.append('o:tracking-clicks', 'no'); + + try { + const response = await axios.post(`${mailgunHost}/v3/${mailgunDomain}/messages`, formData, { + headers: { + ...formData.getHeaders(), + Authorization: `Basic ${Buffer.from(`api:${mailgunApiKey}`).toString('base64')}`, + }, + }); + + return response.data; + } catch (error) { + throw new Error(logAxiosError({ error, message: 'Failed to send email via Mailgun' })); + } +}; + +/** + * Sends an email using SMTP via Nodemailer. + * + * @async + * @function sendEmailViaSMTP + * @param {Object} params - The parameters for sending the email. + * @param {Object} params.transporterOptions - The transporter configuration options. + * @param {Object} params.mailOptions - The email options. + * @returns {Promise} - A promise that resolves to the info object of the sent email. + */ +const sendEmailViaSMTP = async ({ transporterOptions, mailOptions }) => { + const transporter = nodemailer.createTransport(transporterOptions); + return await transporter.sendMail(mailOptions); +}; /** * Sends an email using the specified template, subject, and payload. 
@@ -34,6 +94,30 @@ const logger = require('~/config/winston'); */ const sendEmail = async ({ email, subject, payload, template, throwError = true }) => { try { + // Read and compile the email template + const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8'); + const compiledTemplate = handlebars.compile(source); + const html = compiledTemplate(payload); + + // Prepare common email data + const fromName = process.env.EMAIL_FROM_NAME || process.env.APP_TITLE; + const fromEmail = process.env.EMAIL_FROM; + const fromAddress = `"${fromName}" <${fromEmail}>`; + const toAddress = `"${payload.name}" <${email}>`; + + // Check if Mailgun is configured + if (process.env.MAILGUN_API_KEY && process.env.MAILGUN_DOMAIN) { + logger.debug('[sendEmail] Using Mailgun provider'); + return await sendEmailViaMailgun({ + from: fromAddress, + to: toAddress, + subject: subject, + html: html, + }); + } + + // Default to SMTP + logger.debug('[sendEmail] Using SMTP provider'); const transporterOptions = { // Use STARTTLS by default instead of obligatory TLS secure: process.env.EMAIL_ENCRYPTION === 'tls', @@ -62,30 +146,21 @@ const sendEmail = async ({ email, subject, payload, template, throwError = true transporterOptions.port = process.env.EMAIL_PORT ?? 25; } - const transporter = nodemailer.createTransport(transporterOptions); - - const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8'); - const compiledTemplate = handlebars.compile(source); - const options = () => { - return { - // Header address should contain name-addr - from: - `"${process.env.EMAIL_FROM_NAME || process.env.APP_TITLE}"` + - `<${process.env.EMAIL_FROM}>`, - to: `"${payload.name}" <${email}>`, - envelope: { - // Envelope from should contain addr-spec - // Mistake in the Nodemailer documentation? 
- from: process.env.EMAIL_FROM, - to: email, - }, - subject: subject, - html: compiledTemplate(payload), - }; + const mailOptions = { + // Header address should contain name-addr + from: fromAddress, + to: toAddress, + envelope: { + // Envelope from should contain addr-spec + // Mistake in the Nodemailer documentation? + from: fromEmail, + to: email, + }, + subject: subject, + html: html, }; - // Send email - return await transporter.sendMail(options()); + return await sendEmailViaSMTP({ transporterOptions, mailOptions }); } catch (error) { if (throwError) { throw error; diff --git a/api/strategies/appleStrategy.js b/api/strategies/appleStrategy.js index a45f10fc62..4dbac2e364 100644 --- a/api/strategies/appleStrategy.js +++ b/api/strategies/appleStrategy.js @@ -18,17 +18,13 @@ const getProfileDetails = ({ idToken, profile }) => { const decoded = jwt.decode(idToken); - logger.debug( - `Decoded Apple JWT: ${JSON.stringify(decoded, null, 2)}`, - ); + logger.debug(`Decoded Apple JWT: ${JSON.stringify(decoded, null, 2)}`); return { email: decoded.email, id: decoded.sub, avatarUrl: null, // Apple does not provide an avatar URL - username: decoded.email - ? decoded.email.split('@')[0].toLowerCase() - : `user_${decoded.sub}`, + username: decoded.email ? decoded.email.split('@')[0].toLowerCase() : `user_${decoded.sub}`, name: decoded.name ? 
`${decoded.name.firstName} ${decoded.name.lastName}` : profile.displayName || null, diff --git a/api/strategies/appleStrategy.test.js b/api/strategies/appleStrategy.test.js index c457e15fdc..65a961bd4d 100644 --- a/api/strategies/appleStrategy.test.js +++ b/api/strategies/appleStrategy.test.js @@ -1,22 +1,25 @@ -const mongoose = require('mongoose'); -const { MongoMemoryServer } = require('mongodb-memory-server'); const jwt = require('jsonwebtoken'); +const mongoose = require('mongoose'); +const { logger } = require('@librechat/data-schemas'); const { Strategy: AppleStrategy } = require('passport-apple'); -const socialLogin = require('./socialLogin'); -const User = require('~/models/User'); -const { logger } = require('~/config'); +const { MongoMemoryServer } = require('mongodb-memory-server'); const { createSocialUser, handleExistingUser } = require('./process'); const { isEnabled } = require('~/server/utils'); +const socialLogin = require('./socialLogin'); const { findUser } = require('~/models'); +const { User } = require('~/db/models'); -// Mocking external dependencies jest.mock('jsonwebtoken'); -jest.mock('~/config', () => ({ - logger: { - error: jest.fn(), - debug: jest.fn(), - }, -})); +jest.mock('@librechat/data-schemas', () => { + const actualModule = jest.requireActual('@librechat/data-schemas'); + return { + ...actualModule, + logger: { + error: jest.fn(), + debug: jest.fn(), + }, + }; +}); jest.mock('./process', () => ({ createSocialUser: jest.fn(), handleExistingUser: jest.fn(), @@ -64,7 +67,6 @@ describe('Apple Login Strategy', () => { // Define getProfileDetails within the test scope getProfileDetails = ({ idToken, profile }) => { - console.log('getProfileDetails called with idToken:', idToken); if (!idToken) { logger.error('idToken is missing'); throw new Error('idToken is missing'); @@ -84,9 +86,7 @@ describe('Apple Login Strategy', () => { email: decoded.email, id: decoded.sub, avatarUrl: null, // Apple does not provide an avatar URL - username: 
decoded.email - ? decoded.email.split('@')[0].toLowerCase() - : `user_${decoded.sub}`, + username: decoded.email ? decoded.email.split('@')[0].toLowerCase() : `user_${decoded.sub}`, name: decoded.name ? `${decoded.name.firstName} ${decoded.name.lastName}` : profile.displayName || null, @@ -96,8 +96,12 @@ describe('Apple Login Strategy', () => { // Mock isEnabled based on environment variable isEnabled.mockImplementation((flag) => { - if (flag === 'true') { return true; } - if (flag === 'false') { return false; } + if (flag === 'true') { + return true; + } + if (flag === 'false') { + return false; + } return false; }); @@ -154,9 +158,7 @@ describe('Apple Login Strategy', () => { }); expect(jwt.decode).toHaveBeenCalledWith('fake_id_token'); - expect(logger.debug).toHaveBeenCalledWith( - expect.stringContaining('Decoded Apple JWT'), - ); + expect(logger.debug).toHaveBeenCalledWith(expect.stringContaining('Decoded Apple JWT')); expect(profileDetails).toEqual({ email: 'john.doe@example.com', id: 'apple-sub-1234', @@ -209,7 +211,7 @@ describe('Apple Login Strategy', () => { beforeEach(() => { jwt.decode.mockReturnValue(decodedToken); - findUser.mockImplementation(({ email }) => User.findOne({ email })); + findUser.mockResolvedValue(null); }); it('should create a new user if one does not exist and registration is allowed', async () => { @@ -248,7 +250,7 @@ describe('Apple Login Strategy', () => { }); it('should handle existing user and update avatarUrl', async () => { - // Create an existing user + // Create an existing user without saving to database const existingUser = new User({ email: 'jane.doe@example.com', username: 'jane.doe', @@ -257,15 +259,15 @@ describe('Apple Login Strategy', () => { providerId: 'apple-sub-9012', avatarUrl: 'old_avatar.png', }); - await existingUser.save(); // Mock findUser to return the existing user findUser.mockResolvedValue(existingUser); - // Mock handleExistingUser to update avatarUrl + // Mock handleExistingUser to update avatarUrl 
without saving to database handleExistingUser.mockImplementation(async (user, avatarUrl) => { user.avatarUrl = avatarUrl; - await user.save(); + // Don't call save() to avoid database operations + return user; }); const mockVerifyCallback = jest.fn(); @@ -297,7 +299,7 @@ describe('Apple Login Strategy', () => { appleStrategyInstance._verify( fakeAccessToken, fakeRefreshToken, - null, // idToken is missing + null, // idToken is missing mockProfile, (err, user) => { mockVerifyCallback(err, user); diff --git a/api/strategies/jwtStrategy.js b/api/strategies/jwtStrategy.js index eb4b34fd85..6793873ee8 100644 --- a/api/strategies/jwtStrategy.js +++ b/api/strategies/jwtStrategy.js @@ -1,7 +1,7 @@ +const { logger } = require('@librechat/data-schemas'); const { SystemRoles } = require('librechat-data-provider'); const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt'); const { getUserById, updateUser } = require('~/models'); -const { logger } = require('~/config'); // JWT strategy const jwtLogin = () => diff --git a/api/strategies/ldapStrategy.js b/api/strategies/ldapStrategy.js index beb9b8c2fd..434534c264 100644 --- a/api/strategies/ldapStrategy.js +++ b/api/strategies/ldapStrategy.js @@ -1,10 +1,10 @@ const fs = require('fs'); const LdapStrategy = require('passport-ldapauth'); const { SystemRoles } = require('librechat-data-provider'); -const { findUser, createUser, updateUser } = require('~/models/userMethods'); -const { countUsers } = require('~/models/userMethods'); +const { logger } = require('@librechat/data-schemas'); +const { createUser, findUser, updateUser, countUsers } = require('~/models'); +const { getBalanceConfig } = require('~/server/services/Config'); const { isEnabled } = require('~/server/utils'); -const logger = require('~/utils/logger'); const { LDAP_URL, @@ -124,7 +124,8 @@ const ldapLogin = new LdapStrategy(ldapOptions, async (userinfo, done) => { name: fullName, role: isFirstRegisteredUser ? 
SystemRoles.ADMIN : SystemRoles.USER, }; - const userId = await createUser(user); + const balanceConfig = await getBalanceConfig(); + const userId = await createUser(user, balanceConfig); user._id = userId; } else { // Users registered in LDAP are assumed to have their user information managed in LDAP, diff --git a/api/strategies/localStrategy.js b/api/strategies/localStrategy.js index bffb4f845f..bc84e7c6b5 100644 --- a/api/strategies/localStrategy.js +++ b/api/strategies/localStrategy.js @@ -1,9 +1,9 @@ +const { logger } = require('@librechat/data-schemas'); const { errorsToString } = require('librechat-data-provider'); const { Strategy: PassportLocalStrategy } = require('passport-local'); -const { findUser, comparePassword, updateUser } = require('~/models'); const { isEnabled, checkEmailConfig } = require('~/server/utils'); +const { findUser, comparePassword, updateUser } = require('~/models'); const { loginSchema } = require('./validators'); -const logger = require('~/utils/logger'); // Unix timestamp for 2024-06-07 15:20:18 Eastern Time const verificationEnabledTimestamp = 1717788018; @@ -29,6 +29,12 @@ async function passportLogin(req, email, password, done) { return done(null, false, { message: 'Email does not exist.' }); } + if (!user.password) { + logError('Passport Local Strategy - User has no password', { email }); + logger.error(`[Login] [Login failed] [Username: ${email}] [Request-IP: ${req.ip}]`); + return done(null, false, { message: 'Email does not exist.' 
}); + } + const isMatch = await comparePassword(user, password); if (!isMatch) { logError('Passport Local Strategy - Password does not match', { isMatch }); diff --git a/api/strategies/openIdJwtStrategy.js b/api/strategies/openIdJwtStrategy.js index dae8d17bc6..cc90e20036 100644 --- a/api/strategies/openIdJwtStrategy.js +++ b/api/strategies/openIdJwtStrategy.js @@ -1,4 +1,5 @@ const { SystemRoles } = require('librechat-data-provider'); +const { HttpsProxyAgent } = require('https-proxy-agent'); const { Strategy: JwtStrategy, ExtractJwt } = require('passport-jwt'); const { updateUser, findUser } = require('~/models'); const { logger } = require('~/config'); @@ -13,17 +14,23 @@ const { isEnabled } = require('~/server/utils'); * The strategy extracts the JWT from the Authorization header as a Bearer token. * The JWT is then verified using the signing key, and the user is retrieved from the database. */ -const openIdJwtLogin = (openIdConfig) => - new JwtStrategy( +const openIdJwtLogin = (openIdConfig) => { + let jwksRsaOptions = { + cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true, + cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME + ? eval(process.env.OPENID_JWKS_URL_CACHE_TIME) + : 60000, + jwksUri: openIdConfig.serverMetadata().jwks_uri, + }; + + if (process.env.PROXY) { + jwksRsaOptions.requestAgent = new HttpsProxyAgent(process.env.PROXY); + } + + return new JwtStrategy( { jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(), - secretOrKeyProvider: jwksRsa.passportJwtSecret({ - cache: isEnabled(process.env.OPENID_JWKS_URL_CACHE_ENABLED) || true, - cacheMaxAge: process.env.OPENID_JWKS_URL_CACHE_TIME - ? 
eval(process.env.OPENID_JWKS_URL_CACHE_TIME) - : 60000, - jwksUri: openIdConfig.serverMetadata().jwks_uri, - }), + secretOrKeyProvider: jwksRsa.passportJwtSecret(jwksRsaOptions), }, async (payload, done) => { try { @@ -48,5 +55,6 @@ const openIdJwtLogin = (openIdConfig) => } }, ); +}; module.exports = openIdJwtLogin; diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index ea109358d7..563ac8257e 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -1,15 +1,16 @@ -const { CacheKeys } = require('librechat-data-provider'); +const undici = require('undici'); const fetch = require('node-fetch'); const passport = require('passport'); -const jwtDecode = require('jsonwebtoken/decode'); -const { HttpsProxyAgent } = require('https-proxy-agent'); const client = require('openid-client'); +const jwtDecode = require('jsonwebtoken/decode'); +const { CacheKeys } = require('librechat-data-provider'); +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { hashToken, logger } = require('@librechat/data-schemas'); const { Strategy: OpenIDStrategy } = require('openid-client/passport'); +const { isEnabled, safeStringify, logHeaders } = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { findUser, createUser, updateUser } = require('~/models/userMethods'); -const { hashToken } = require('~/server/utils/crypto'); -const { isEnabled } = require('~/server/utils'); -const { logger } = require('~/config'); +const { findUser, createUser, updateUser } = require('~/models'); +const { getBalanceConfig } = require('~/server/services/Config'); const getLogStores = require('~/cache/getLogStores'); /** @@ -17,6 +18,76 @@ const getLogStores = require('~/cache/getLogStores'); * @typedef {import('openid-client').Configuration} Configuration **/ +/** + * @param {string} url + * @param {client.CustomFetchOptions} options + */ +async function customFetch(url, 
options) { + const urlStr = url.toString(); + logger.debug(`[openidStrategy] Request to: ${urlStr}`); + const debugOpenId = isEnabled(process.env.DEBUG_OPENID_REQUESTS); + if (debugOpenId) { + logger.debug(`[openidStrategy] Request method: ${options.method || 'GET'}`); + logger.debug(`[openidStrategy] Request headers: ${logHeaders(options.headers)}`); + if (options.body) { + let bodyForLogging = ''; + if (options.body instanceof URLSearchParams) { + bodyForLogging = options.body.toString(); + } else if (typeof options.body === 'string') { + bodyForLogging = options.body; + } else { + bodyForLogging = safeStringify(options.body); + } + logger.debug(`[openidStrategy] Request body: ${bodyForLogging}`); + } + } + + try { + /** @type {undici.RequestInit} */ + let fetchOptions = options; + if (process.env.PROXY) { + logger.info(`[openidStrategy] proxy agent configured: ${process.env.PROXY}`); + fetchOptions = { + ...options, + dispatcher: new undici.ProxyAgent(process.env.PROXY), + }; + } + + const response = await undici.fetch(url, fetchOptions); + + if (debugOpenId) { + logger.debug(`[openidStrategy] Response status: ${response.status} ${response.statusText}`); + logger.debug(`[openidStrategy] Response headers: ${logHeaders(response.headers)}`); + } + + if (response.status === 200 && response.headers.has('www-authenticate')) { + const wwwAuth = response.headers.get('www-authenticate'); + logger.warn(`[openidStrategy] Non-standard WWW-Authenticate header found in successful response (200 OK): ${wwwAuth}. +This violates RFC 7235 and may cause issues with strict OAuth clients. 
Removing header for compatibility.`); + + /** Cloned response without the WWW-Authenticate header */ + const responseBody = await response.arrayBuffer(); + const newHeaders = new Headers(); + for (const [key, value] of response.headers.entries()) { + if (key.toLowerCase() !== 'www-authenticate') { + newHeaders.append(key, value); + } + } + + return new Response(responseBody, { + status: response.status, + statusText: response.statusText, + headers: newHeaders, + }); + } + + return response; + } catch (error) { + logger.error(`[openidStrategy] Fetch error: ${error.message}`); + throw error; + } +} + /** @typedef {Configuration | null} */ let openidConfig = null; @@ -47,7 +118,7 @@ class CustomOpenIDStrategy extends OpenIDStrategy { */ const exchangeAccessTokenIfNeeded = async (config, accessToken, sub, fromCache = false) => { const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS); - const onBehalfFlowRequired = isEnabled(process.env.OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED); + const onBehalfFlowRequired = isEnabled(process.env.OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED); if (onBehalfFlowRequired) { if (fromCache) { const cachedToken = await tokensCache.get(sub); @@ -59,7 +130,7 @@ const exchangeAccessTokenIfNeeded = async (config, accessToken, sub, fromCache = config, 'urn:ietf:params:oauth:grant-type:jwt-bearer', { - scope: process.env.OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE || 'user.read', + scope: process.env.OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE || 'user.read', assertion: accessToken, requested_token_use: 'on_behalf_of', }, @@ -208,14 +279,12 @@ async function setupOpenId() { new URL(process.env.OPENID_ISSUER), process.env.OPENID_CLIENT_ID, clientMetadata, + undefined, + { + [client.customFetch]: customFetch, + }, ); - if (process.env.PROXY) { - const proxyAgent = new HttpsProxyAgent(process.env.PROXY); - openidConfig[client.customFetch] = (...args) => { - return fetch(args[0], { ...args[1], agent: proxyAgent }); - }; - 
logger.info(`[openidStrategy] proxy agent added: ${process.env.PROXY}`); - } + const requiredRole = process.env.OPENID_REQUIRED_ROLE; const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; @@ -297,7 +366,10 @@ async function setupOpenId() { emailVerified: userinfo.email_verified || false, name: fullName, }; - user = await createUser(user, true, true); + + const balanceConfig = await getBalanceConfig(); + + user = await createUser(user, balanceConfig, true, true); } else { user.provider = 'openid'; user.openidId = userinfo.sub; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index e70dfa5529..1e6750384e 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -1,7 +1,7 @@ const fetch = require('node-fetch'); const jwtDecode = require('jsonwebtoken/decode'); -const { findUser, createUser, updateUser } = require('~/models/userMethods'); const { setupOpenId } = require('./openidStrategy'); +const { findUser, createUser, updateUser } = require('~/models'); // --- Mocks --- jest.mock('node-fetch'); @@ -11,24 +11,28 @@ jest.mock('~/server/services/Files/strategies', () => ({ saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), })), })); -jest.mock('~/models/userMethods', () => ({ +jest.mock('~/server/services/Config', () => ({ + getBalanceConfig: jest.fn(() => ({ + enabled: false, + })), +})); +jest.mock('~/models', () => ({ findUser: jest.fn(), createUser: jest.fn(), updateUser: jest.fn(), })); -jest.mock('~/server/utils/crypto', () => ({ - hashToken: jest.fn().mockResolvedValue('hashed-token'), -})); -jest.mock('~/server/utils', () => ({ +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), isEnabled: jest.fn(() => false), })); -jest.mock('~/config', () => ({ +jest.mock('@librechat/data-schemas', () => ({ + 
...jest.requireActual('@librechat/api'), logger: { info: jest.fn(), debug: jest.fn(), error: jest.fn(), - warn: jest.fn(), }, + hashToken: jest.fn().mockResolvedValue('hashed-token'), })); jest.mock('~/cache/getLogStores', () => jest.fn(() => ({ @@ -36,11 +40,6 @@ jest.mock('~/cache/getLogStores', () => set: jest.fn(), })), ); -jest.mock('librechat-data-provider', () => ({ - CacheKeys: { - OPENID_EXCHANGED_TOKENS: 'openid-exchanged-tokens', - }, -})); // Mock the openid-client module and all its dependencies jest.mock('openid-client', () => { @@ -174,6 +173,7 @@ describe('setupOpenId', () => { email: userinfo.email, name: `${userinfo.given_name} ${userinfo.family_name}`, }), + { enabled: false }, true, true, ); @@ -193,6 +193,7 @@ describe('setupOpenId', () => { expect(user.username).toBe(expectUsername); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ username: expectUsername }), + { enabled: false }, true, true, ); @@ -212,6 +213,7 @@ describe('setupOpenId', () => { expect(user.username).toBe(expectUsername); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ username: expectUsername }), + { enabled: false }, true, true, ); @@ -229,6 +231,7 @@ describe('setupOpenId', () => { expect(user.username).toBe(userinfo.sub); expect(createUser).toHaveBeenCalledWith( expect.objectContaining({ username: userinfo.sub }), + { enabled: false }, true, true, ); diff --git a/api/strategies/process.js b/api/strategies/process.js index e9a908ffd0..1f7e7c81d2 100644 --- a/api/strategies/process.js +++ b/api/strategies/process.js @@ -1,7 +1,8 @@ const { FileSources } = require('librechat-data-provider'); -const { createUser, updateUser, getUserById } = require('~/models/userMethods'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { resizeAvatar } = require('~/server/services/Files/images/avatar'); +const { updateUser, createUser, getUserById } = require('~/models'); +const { getBalanceConfig } = 
require('~/server/services/Config'); /** * Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter @@ -30,7 +31,7 @@ const handleExistingUser = async (oldUser, avatarUrl) => { input: avatarUrl, }); const { processAvatar } = getStrategyFunctions(fileStrategy); - updatedAvatar = await processAvatar({ buffer: resizedBuffer, userId }); + updatedAvatar = await processAvatar({ buffer: resizedBuffer, userId, manual: 'false' }); } if (updatedAvatar) { @@ -78,7 +79,8 @@ const createSocialUser = async ({ emailVerified, }; - const newUserId = await createUser(update); + const balanceConfig = await getBalanceConfig(); + const newUserId = await createUser(update, balanceConfig); const fileStrategy = process.env.CDN_PROVIDER; const isLocal = fileStrategy === FileSources.local; @@ -88,7 +90,11 @@ const createSocialUser = async ({ input: avatarUrl, }); const { processAvatar } = getStrategyFunctions(fileStrategy); - const avatar = await processAvatar({ buffer: resizedBuffer, userId: newUserId }); + const avatar = await processAvatar({ + buffer: resizedBuffer, + userId: newUserId, + manual: 'false', + }); await updateUser(newUserId, { avatar }); } diff --git a/api/strategies/samlStrategy.js b/api/strategies/samlStrategy.js index a0793f1c83..376434f733 100644 --- a/api/strategies/samlStrategy.js +++ b/api/strategies/samlStrategy.js @@ -2,11 +2,11 @@ const fs = require('fs'); const path = require('path'); const fetch = require('node-fetch'); const passport = require('passport'); +const { hashToken, logger } = require('@librechat/data-schemas'); const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); -const { findUser, createUser, updateUser } = require('~/models/userMethods'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { hashToken } = require('~/server/utils/crypto'); -const { logger } = require('~/config'); +const { findUser, createUser, updateUser } = require('~/models'); 
+const { getBalanceConfig } = require('~/server/services/Config'); const paths = require('~/config/paths'); let crypto; @@ -218,7 +218,8 @@ async function setupSaml() { emailVerified: true, name: fullName, }; - user = await createUser(user, true, true); + const balanceConfig = await getBalanceConfig(); + user = await createUser(user, balanceConfig, true, true); } else { user.provider = 'saml'; user.samlId = profile.nameID; diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js index cb007c75e4..fc8329a31a 100644 --- a/api/strategies/samlStrategy.spec.js +++ b/api/strategies/samlStrategy.spec.js @@ -1,39 +1,56 @@ -const fs = require('fs'); -const path = require('path'); -const fetch = require('node-fetch'); -const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); -const { findUser, createUser, updateUser } = require('~/models/userMethods'); -const { setupSaml, getCertificateContent } = require('./samlStrategy'); - // --- Mocks --- +jest.mock('tiktoken'); jest.mock('fs'); jest.mock('path'); jest.mock('node-fetch'); jest.mock('@node-saml/passport-saml'); -jest.mock('~/models/userMethods', () => ({ +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + hashToken: jest.fn().mockResolvedValue('hashed-token'), +})); +jest.mock('~/models', () => ({ findUser: jest.fn(), createUser: jest.fn(), updateUser: jest.fn(), })); +jest.mock('~/server/services/Config', () => ({ + config: { + registration: { + socialLogins: ['saml'], + }, + }, + getBalanceConfig: jest.fn().mockResolvedValue({ + tokenCredits: 1000, + startingBalance: 1000, + }), +})); +jest.mock('~/server/services/Config/EndpointService', () => ({ + config: {}, +})); jest.mock('~/server/services/Files/strategies', () => ({ getStrategyFunctions: jest.fn(() => ({ saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), })), })); -jest.mock('~/server/utils/crypto', () => ({ - hashToken: 
jest.fn().mockResolvedValue('hashed-token'), -})); -jest.mock('~/server/utils', () => ({ - isEnabled: jest.fn(() => false), -})); -jest.mock('~/config', () => ({ - logger: { - info: jest.fn(), - debug: jest.fn(), - error: jest.fn(), - }, +jest.mock('~/config/paths', () => ({ + root: '/fake/root/path', })); +const fs = require('fs'); +const path = require('path'); +const fetch = require('node-fetch'); +const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); +const { setupSaml, getCertificateContent } = require('./samlStrategy'); + +// Configure fs mock +jest.mocked(fs).existsSync = jest.fn(); +jest.mocked(fs).statSync = jest.fn(); +jest.mocked(fs).readFileSync = jest.fn(); + // To capture the verify callback from the strategy, we grab it from the mock constructor let verifyCallback; SamlStrategy.mockImplementation((options, verify) => { @@ -196,6 +213,18 @@ describe('setupSaml', () => { beforeEach(async () => { jest.clearAllMocks(); + // Configure mocks + const { findUser, createUser, updateUser } = require('~/models'); + findUser.mockResolvedValue(null); + createUser.mockImplementation(async (userData) => ({ + _id: 'mock-user-id', + ...userData, + })); + updateUser.mockImplementation(async (id, userData) => ({ + _id: id, + ...userData, + })); + const cert = ` -----BEGIN CERTIFICATE----- MIIDazCCAlOgAwIBAgIUKhXaFJGJJPx466rlwYORIsqCq7MwDQYJKoZIhvcNAQEL @@ -232,16 +261,6 @@ u7wlOSk+oFzDIO/UILIA delete process.env.SAML_PICTURE_CLAIM; delete process.env.SAML_NAME_CLAIM; - findUser.mockResolvedValue(null); - createUser.mockImplementation(async (userData) => ({ - _id: 'newUserId', - ...userData, - })); - updateUser.mockImplementation(async (id, userData) => ({ - _id: id, - ...userData, - })); - // Simulate image download const fakeBuffer = Buffer.from('fake image'); fetch.mockResolvedValue({ @@ -257,17 +276,10 @@ u7wlOSk+oFzDIO/UILIA const { user } = await validate(profile); expect(user.username).toBe(profile.username); - 
expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ - provider: 'saml', - samlId: profile.nameID, - username: profile.username, - email: profile.email, - name: `${profile.given_name} ${profile.family_name}`, - }), - true, - true, - ); + expect(user.provider).toBe('saml'); + expect(user.samlId).toBe(profile.nameID); + expect(user.email).toBe(profile.email); + expect(user.name).toBe(`${profile.given_name} ${profile.family_name}`); }); it('should use given_name as username when username claim is missing', async () => { @@ -278,11 +290,7 @@ u7wlOSk+oFzDIO/UILIA const { user } = await validate(profile); expect(user.username).toBe(expectUsername); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: expectUsername }), - true, - true, - ); + expect(user.provider).toBe('saml'); }); it('should use email as username when username and given_name are missing', async () => { @@ -294,11 +302,7 @@ u7wlOSk+oFzDIO/UILIA const { user } = await validate(profile); expect(user.username).toBe(expectUsername); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: expectUsername }), - true, - true, - ); + expect(user.provider).toBe('saml'); }); it('should override username with SAML_USERNAME_CLAIM when set', async () => { @@ -308,11 +312,7 @@ u7wlOSk+oFzDIO/UILIA const { user } = await validate(profile); expect(user.username).toBe(profile.nameID); - expect(createUser).toHaveBeenCalledWith( - expect.objectContaining({ username: profile.nameID }), - true, - true, - ); + expect(user.provider).toBe('saml'); }); it('should set the full name correctly when given_name and family_name exist', async () => { @@ -378,34 +378,26 @@ u7wlOSk+oFzDIO/UILIA }); it('should update an existing user on login', async () => { + // Set up findUser to return an existing user + const { findUser } = require('~/models'); const existingUser = { - _id: 'existingUserId', + _id: 'existing-user-id', provider: 'local', email: baseProfile.email, 
samlId: '', - username: '', - name: '', + username: 'oldusername', + name: 'Old Name', }; - - findUser.mockImplementation(async (query) => { - if (query.samlId === baseProfile.nameID || query.email === baseProfile.email) { - return existingUser; - } - return null; - }); + findUser.mockResolvedValue(existingUser); const profile = { ...baseProfile }; - await validate(profile); + const { user } = await validate(profile); - expect(updateUser).toHaveBeenCalledWith( - existingUser._id, - expect.objectContaining({ - provider: 'saml', - samlId: baseProfile.nameID, - username: baseProfile.username, - name: `${baseProfile.given_name} ${baseProfile.family_name}`, - }), - ); + expect(user.provider).toBe('saml'); + expect(user.samlId).toBe(baseProfile.nameID); + expect(user.username).toBe(baseProfile.username); + expect(user.name).toBe(`${baseProfile.given_name} ${baseProfile.family_name}`); + expect(user.email).toBe(baseProfile.email); }); it('should attempt to download and save the avatar if picture is provided', async () => { diff --git a/api/strategies/socialLogin.js b/api/strategies/socialLogin.js index 925c2de34d..4f9462316a 100644 --- a/api/strategies/socialLogin.js +++ b/api/strategies/socialLogin.js @@ -1,7 +1,7 @@ +const { logger } = require('@librechat/data-schemas'); const { createSocialUser, handleExistingUser } = require('./process'); const { isEnabled } = require('~/server/utils'); const { findUser } = require('~/models'); -const { logger } = require('~/config'); const socialLogin = (provider, getProfileDetails) => async (accessToken, refreshToken, idToken, profile, cb) => { diff --git a/api/test/__mocks__/logger.js b/api/test/__mocks__/logger.js index f9f6d78c87..56fb28cbab 100644 --- a/api/test/__mocks__/logger.js +++ b/api/test/__mocks__/logger.js @@ -41,10 +41,7 @@ jest.mock('winston-daily-rotate-file', () => { }); jest.mock('~/config', () => { - const actualModule = jest.requireActual('~/config'); return { - sendEvent: actualModule.sendEvent, - 
createAxiosInstance: actualModule.createAxiosInstance, logger: { info: jest.fn(), warn: jest.fn(), diff --git a/api/typedefs.js b/api/typedefs.js index 8da5b34809..c0e0dd5f46 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -476,11 +476,18 @@ * @memberof typedefs */ +/** + * @exports ToolCallChunk + * @typedef {import('librechat-data-provider').Agents.ToolCallChunk} ToolCallChunk + * @memberof typedefs + */ + /** * @exports MessageContentImageUrl * @typedef {import('librechat-data-provider').Agents.MessageContentImageUrl} MessageContentImageUrl * @memberof typedefs */ + /** Web Search */ /** @@ -1073,7 +1080,7 @@ /** * @exports MCPServers - * @typedef {import('librechat-mcp').MCPServers} MCPServers + * @typedef {import('@librechat/api').MCPServers} MCPServers * @memberof typedefs */ @@ -1085,31 +1092,31 @@ /** * @exports MCPManager - * @typedef {import('librechat-mcp').MCPManager} MCPManager + * @typedef {import('@librechat/api').MCPManager} MCPManager * @memberof typedefs */ /** * @exports FlowStateManager - * @typedef {import('librechat-mcp').FlowStateManager} FlowStateManager + * @typedef {import('@librechat/api').FlowStateManager} FlowStateManager * @memberof typedefs */ /** * @exports LCAvailableTools - * @typedef {import('librechat-mcp').LCAvailableTools} LCAvailableTools + * @typedef {import('@librechat/api').LCAvailableTools} LCAvailableTools * @memberof typedefs */ /** * @exports LCTool - * @typedef {import('librechat-mcp').LCTool} LCTool + * @typedef {import('@librechat/api').LCTool} LCTool * @memberof typedefs */ /** * @exports FormattedContent - * @typedef {import('librechat-mcp').FormattedContent} FormattedContent + * @typedef {import('@librechat/api').FormattedContent} FormattedContent * @memberof typedefs */ @@ -1232,7 +1239,7 @@ * @typedef {Object} AgentClientOptions * @property {Agent} agent - The agent configuration object * @property {string} endpoint - The endpoint identifier for the agent - * @property {Object} req - The request object + * 
@property {ServerRequest} req - The request object * @property {string} [name] - The username * @property {string} [modelLabel] - The label for the model being used * @property {number} [maxContextTokens] - Maximum number of tokens allowed in context @@ -1496,7 +1503,6 @@ * @property {boolean|{userProvide: boolean}} [anthropic] - Flag to indicate if Anthropic endpoint is user provided, or its configuration. * @property {boolean|{userProvide: boolean}} [google] - Flag to indicate if Google endpoint is user provided, or its configuration. * @property {boolean|{userProvide: boolean, userProvideURL: boolean, name: string}} [custom] - Custom Endpoint configuration. - * @property {boolean|GptPlugins} [gptPlugins] - Configuration for GPT plugins. * @memberof typedefs */ diff --git a/api/utils/axios.js b/api/utils/axios.js deleted file mode 100644 index 91c1fbb223..0000000000 --- a/api/utils/axios.js +++ /dev/null @@ -1,46 +0,0 @@ -const { logger } = require('~/config'); - -/** - * Logs Axios errors based on the error object and a custom message. - * - * @param {Object} options - The options object. - * @param {string} options.message - The custom message to be logged. - * @param {import('axios').AxiosError} options.error - The Axios error object. - * @returns {string} The log message. - */ -const logAxiosError = ({ message, error }) => { - let logMessage = message; - try { - const stack = error.stack || 'No stack trace available'; - - if (error.response?.status) { - const { status, headers, data } = error.response; - logMessage = `${message} The server responded with status ${status}: ${error.message}`; - logger.error(logMessage, { - status, - headers, - data, - stack, - }); - } else if (error.request) { - const { method, url } = error.config || {}; - logMessage = `${message} No response received for ${method ? 
method.toUpperCase() : ''} ${url || ''}: ${error.message}`; - logger.error(logMessage, { - requestInfo: { method, url }, - stack, - }); - } else if (error?.message?.includes("Cannot read properties of undefined (reading 'status')")) { - logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`; - logger.error(logMessage, { stack }); - } else { - logMessage = `${message} An error occurred while setting up the request: ${error.message}`; - logger.error(logMessage, { stack }); - } - } catch (err) { - logMessage = `Error in logAxiosError: ${err.message}`; - logger.error(logMessage, { stack: err.stack || 'No stack trace available' }); - } - return logMessage; -}; - -module.exports = { logAxiosError }; diff --git a/api/utils/index.js b/api/utils/index.js index 62d61586bf..b80c9b0c31 100644 --- a/api/utils/index.js +++ b/api/utils/index.js @@ -1,17 +1,11 @@ -const loadYaml = require('./loadYaml'); -const axiosHelpers = require('./axios'); const tokenHelpers = require('./tokens'); -const azureUtils = require('./azureUtils'); const deriveBaseURL = require('./deriveBaseURL'); const extractBaseURL = require('./extractBaseURL'); const findMessageContent = require('./findMessageContent'); module.exports = { - loadYaml, deriveBaseURL, extractBaseURL, - ...azureUtils, - ...axiosHelpers, ...tokenHelpers, findMessageContent, }; diff --git a/api/utils/loadYaml.js b/api/utils/loadYaml.js deleted file mode 100644 index 50e5d23ec3..0000000000 --- a/api/utils/loadYaml.js +++ /dev/null @@ -1,13 +0,0 @@ -const fs = require('fs'); -const yaml = require('js-yaml'); - -function loadYaml(filepath) { - try { - let fileContents = fs.readFileSync(filepath, 'utf8'); - return yaml.load(fileContents); - } catch (e) { - return e; - } -} - -module.exports = loadYaml; diff --git a/client/package.json b/client/package.json index 8e4be78764..9c86cd5d4d 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/frontend", - 
"version": "v0.7.8", + "version": "v0.7.9-rc1", "description": "", "type": "module", "scripts": { @@ -65,6 +65,7 @@ "export-from-json": "^1.7.2", "filenamify": "^6.0.0", "framer-motion": "^11.5.4", + "heic-to": "^1.1.14", "html-to-image": "^1.11.11", "i18next": "^24.2.2", "i18next-browser-languagedetector": "^8.0.3", @@ -74,6 +75,7 @@ "lodash": "^4.17.21", "lucide-react": "^0.394.0", "match-sorter": "^6.3.4", + "micromark-extension-llm-math": "^3.1.0", "qrcode.react": "^4.2.0", "rc-input-number": "^7.4.2", "react": "^18.2.0", @@ -139,7 +141,6 @@ "postcss": "^8.4.31", "postcss-loader": "^7.1.0", "postcss-preset-env": "^8.2.0", - "rollup-plugin-visualizer": "^6.0.0", "tailwindcss": "^3.4.1", "ts-jest": "^29.2.5", "typescript": "^5.3.3", diff --git a/client/src/Providers/ActivePanelContext.tsx b/client/src/Providers/ActivePanelContext.tsx new file mode 100644 index 0000000000..4a8d6ccfc4 --- /dev/null +++ b/client/src/Providers/ActivePanelContext.tsx @@ -0,0 +1,37 @@ +import { createContext, useContext, useState, ReactNode } from 'react'; + +interface ActivePanelContextType { + active: string | undefined; + setActive: (id: string) => void; +} + +const ActivePanelContext = createContext(undefined); + +export function ActivePanelProvider({ + children, + defaultActive, +}: { + children: ReactNode; + defaultActive?: string; +}) { + const [active, _setActive] = useState(defaultActive); + + const setActive = (id: string) => { + localStorage.setItem('side:active-panel', id); + _setActive(id); + }; + + return ( + + {children} + + ); +} + +export function useActivePanel() { + const context = useContext(ActivePanelContext); + if (context === undefined) { + throw new Error('useActivePanel must be used within an ActivePanelProvider'); + } + return context; +} diff --git a/client/src/Providers/AgentPanelContext.tsx b/client/src/Providers/AgentPanelContext.tsx new file mode 100644 index 0000000000..b15d334078 --- /dev/null +++ b/client/src/Providers/AgentPanelContext.tsx @@ -0,0 
+1,96 @@ +import React, { createContext, useContext, useState } from 'react'; +import { Constants, EModelEndpoint } from 'librechat-data-provider'; +import type { TPlugin, AgentToolType, Action, MCP } from 'librechat-data-provider'; +import type { AgentPanelContextType } from '~/common'; +import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider'; +import { useLocalize } from '~/hooks'; +import { Panel } from '~/common'; + +const AgentPanelContext = createContext(undefined); + +export function useAgentPanelContext() { + const context = useContext(AgentPanelContext); + if (context === undefined) { + throw new Error('useAgentPanelContext must be used within an AgentPanelProvider'); + } + return context; +} + +/** Houses relevant state for the Agent Form Panels (formerly 'commonProps') */ +export function AgentPanelProvider({ children }: { children: React.ReactNode }) { + const localize = useLocalize(); + const [mcp, setMcp] = useState(undefined); + const [mcps, setMcps] = useState(undefined); + const [action, setAction] = useState(undefined); + const [activePanel, setActivePanel] = useState(Panel.builder); + const [agent_id, setCurrentAgentId] = useState(undefined); + + const { data: actions } = useGetActionsQuery(EModelEndpoint.agents, { + enabled: !!agent_id, + }); + + const { data: pluginTools } = useAvailableToolsQuery(EModelEndpoint.agents, { + enabled: !!agent_id, + }); + + const tools = + pluginTools?.map((tool) => ({ + tool_id: tool.pluginKey, + metadata: tool as TPlugin, + agent_id: agent_id || '', + })) || []; + + const groupedTools = tools?.reduce( + (acc, tool) => { + if (tool.tool_id.includes(Constants.mcp_delimiter)) { + const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter); + const groupKey = `${serverName.toLowerCase()}`; + if (!acc[groupKey]) { + acc[groupKey] = { + tool_id: groupKey, + metadata: { + name: `${serverName}`, + pluginKey: groupKey, + description: `${localize('com_ui_tool_collection_prefix')} 
${serverName}`, + icon: tool.metadata.icon || '', + } as TPlugin, + agent_id: agent_id || '', + tools: [], + }; + } + acc[groupKey].tools?.push({ + tool_id: tool.tool_id, + metadata: tool.metadata, + agent_id: agent_id || '', + }); + } else { + acc[tool.tool_id] = { + tool_id: tool.tool_id, + metadata: tool.metadata, + agent_id: agent_id || '', + }; + } + return acc; + }, + {} as Record, + ); + + const value = { + action, + setAction, + mcp, + setMcp, + mcps, + setMcps, + activePanel, + setActivePanel, + setCurrentAgentId, + agent_id, + groupedTools, + /** Query data for actions and tools */ + actions, + tools, + }; + + return {children}; +} diff --git a/client/src/Providers/AgentsContext.tsx b/client/src/Providers/AgentsContext.tsx index e793a3f087..a90a53ecb5 100644 --- a/client/src/Providers/AgentsContext.tsx +++ b/client/src/Providers/AgentsContext.tsx @@ -1,8 +1,8 @@ import { useForm, FormProvider } from 'react-hook-form'; import { createContext, useContext } from 'react'; -import { defaultAgentFormValues } from 'librechat-data-provider'; import type { UseFormReturn } from 'react-hook-form'; import type { AgentForm } from '~/common'; +import { getDefaultAgentFormValues } from '~/utils'; type AgentsContextType = UseFormReturn; @@ -20,7 +20,7 @@ export function useAgentsContext() { export default function AgentsProvider({ children }) { const methods = useForm({ - defaultValues: defaultAgentFormValues, + defaultValues: getDefaultAgentFormValues(), }); return {children}; diff --git a/client/src/Providers/BadgeRowContext.tsx b/client/src/Providers/BadgeRowContext.tsx new file mode 100644 index 0000000000..01590b1948 --- /dev/null +++ b/client/src/Providers/BadgeRowContext.tsx @@ -0,0 +1,89 @@ +import React, { createContext, useContext } from 'react'; +import { Tools, LocalStorageKeys } from 'librechat-data-provider'; +import { useMCPSelect, useToolToggle, useCodeApiKeyForm, useSearchApiKeyForm } from '~/hooks'; +import { useGetStartupConfig } from 
'~/data-provider'; + +interface BadgeRowContextType { + conversationId?: string | null; + mcpSelect: ReturnType; + webSearch: ReturnType; + codeInterpreter: ReturnType; + fileSearch: ReturnType; + codeApiKeyForm: ReturnType; + searchApiKeyForm: ReturnType; + startupConfig: ReturnType['data']; +} + +const BadgeRowContext = createContext(undefined); + +export function useBadgeRowContext() { + const context = useContext(BadgeRowContext); + if (context === undefined) { + throw new Error('useBadgeRowContext must be used within a BadgeRowProvider'); + } + return context; +} + +interface BadgeRowProviderProps { + children: React.ReactNode; + conversationId?: string | null; +} + +export default function BadgeRowProvider({ children, conversationId }: BadgeRowProviderProps) { + /** Startup config */ + const { data: startupConfig } = useGetStartupConfig(); + + /** MCPSelect hook */ + const mcpSelect = useMCPSelect({ conversationId }); + + /** CodeInterpreter hooks */ + const codeApiKeyForm = useCodeApiKeyForm({}); + const { setIsDialogOpen: setCodeDialogOpen } = codeApiKeyForm; + + const codeInterpreter = useToolToggle({ + conversationId, + setIsDialogOpen: setCodeDialogOpen, + toolKey: Tools.execute_code, + localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_, + authConfig: { + toolId: Tools.execute_code, + queryOptions: { retry: 1 }, + }, + }); + + /** WebSearch hooks */ + const searchApiKeyForm = useSearchApiKeyForm({}); + const { setIsDialogOpen: setWebSearchDialogOpen } = searchApiKeyForm; + + const webSearch = useToolToggle({ + conversationId, + toolKey: Tools.web_search, + localStorageKey: LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_, + setIsDialogOpen: setWebSearchDialogOpen, + authConfig: { + toolId: Tools.web_search, + queryOptions: { retry: 1 }, + }, + }); + + /** FileSearch hook */ + const fileSearch = useToolToggle({ + conversationId, + toolKey: Tools.file_search, + localStorageKey: LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_, + isAuthenticated: true, + }); + + const 
value: BadgeRowContextType = { + mcpSelect, + webSearch, + fileSearch, + startupConfig, + conversationId, + codeApiKeyForm, + codeInterpreter, + searchApiKeyForm, + }; + + return {children}; +} diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts index 00191318e0..b455cb3f1e 100644 --- a/client/src/Providers/index.ts +++ b/client/src/Providers/index.ts @@ -1,6 +1,8 @@ -export { default as ToastProvider } from './ToastContext'; export { default as AssistantsProvider } from './AssistantsContext'; export { default as AgentsProvider } from './AgentsContext'; +export { default as ToastProvider } from './ToastContext'; +export * from './ActivePanelContext'; +export * from './AgentPanelContext'; export * from './ChatContext'; export * from './ShareContext'; export * from './ToastContext'; @@ -21,3 +23,5 @@ export * from './CodeBlockContext'; export * from './ToolCallsMapContext'; export * from './SetConvoContext'; export * from './SearchContext'; +export * from './BadgeRowContext'; +export { default as BadgeRowProvider } from './BadgeRowContext'; diff --git a/client/src/common/mcp.ts b/client/src/common/mcp.ts new file mode 100644 index 0000000000..b4f44a1f94 --- /dev/null +++ b/client/src/common/mcp.ts @@ -0,0 +1,26 @@ +import { + AuthorizationTypeEnum, + AuthTypeEnum, + TokenExchangeMethodEnum, +} from 'librechat-data-provider'; +import { MCPForm } from '~/common/types'; + +export const defaultMCPFormValues: MCPForm = { + type: AuthTypeEnum.None, + saved_auth_fields: false, + api_key: '', + authorization_type: AuthorizationTypeEnum.Basic, + custom_auth_header: '', + oauth_client_id: '', + oauth_client_secret: '', + authorization_url: '', + client_url: '', + scope: '', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, + name: '', + description: '', + url: '', + tools: [], + icon: '', + trust: false, +}; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 6837869e8e..9349b7695e 100644 --- 
a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -143,6 +143,7 @@ export enum Panel { actions = 'actions', model = 'model', version = 'version', + mcp = 'mcp', } export type FileSetter = @@ -166,6 +167,15 @@ export type ActionAuthForm = { token_exchange_method: t.TokenExchangeMethodEnum; }; +export type MCPForm = ActionAuthForm & { + name?: string; + description?: string; + url?: string; + tools?: string[]; + icon?: string; + trust?: boolean; +}; + export type ActionWithNullableMetadata = Omit & { metadata: t.ActionMetadata | null; }; @@ -188,16 +198,35 @@ export type AgentPanelProps = { index?: number; agent_id?: string; activePanel?: string; + mcp?: t.MCP; + mcps?: t.MCP[]; action?: t.Action; actions?: t.Action[]; createMutation: UseMutationResult; setActivePanel: React.Dispatch>; + setMcp: React.Dispatch>; setAction: React.Dispatch>; endpointsConfig?: t.TEndpointsConfig; setCurrentAgentId: React.Dispatch>; agentsConfig?: t.TAgentsEndpoint | null; }; +export type AgentPanelContextType = { + action?: t.Action; + actions?: t.Action[]; + setAction: React.Dispatch>; + mcp?: t.MCP; + mcps?: t.MCP[]; + setMcp: React.Dispatch>; + setMcps: React.Dispatch>; + tools: t.AgentToolType[]; + activePanel?: string; + setActivePanel: React.Dispatch>; + setCurrentAgentId: React.Dispatch>; + groupedTools?: Record; + agent_id?: string; +}; + export type AgentModelPanelProps = { agent_id?: string; providers: Option[]; @@ -307,6 +336,11 @@ export type TAskProps = { export type TOptions = { editedMessageId?: string | null; editedText?: string | null; + editedContent?: { + index: number; + text: string; + type: 'text' | 'think'; + }; isRegenerate?: boolean; isContinued?: boolean; isEdited?: boolean; @@ -457,11 +491,20 @@ export type VoiceOption = { }; export type TMessageAudio = { - messageId?: string; - content?: t.TMessageContentParts[] | string; - className?: string; - isLast: boolean; + isLast?: boolean; index: number; + messageId: string; + content: string; + 
className?: string; + renderButton?: (props: { + onClick: (e?: React.MouseEvent) => void; + title: string; + icon: React.ReactNode; + isActive?: boolean; + isVisible?: boolean; + isDisabled?: boolean; + className?: string; + }) => React.ReactNode; }; export type OptionWithIcon = Option & { icon?: React.ReactNode }; diff --git a/client/src/components/Artifacts/Artifact.tsx b/client/src/components/Artifacts/Artifact.tsx index 2b06a2ccc0..902ac9191a 100644 --- a/client/src/components/Artifacts/Artifact.tsx +++ b/client/src/components/Artifacts/Artifact.tsx @@ -40,7 +40,7 @@ const defaultType = 'unknown'; const defaultIdentifier = 'lc-no-identifier'; export function Artifact({ - node, + node: _node, ...props }: Artifact & { children: React.ReactNode | { props: { children: React.ReactNode } }; @@ -95,7 +95,7 @@ export function Artifact({ setArtifacts((prevArtifacts) => { if ( prevArtifacts?.[artifactKey] != null && - prevArtifacts[artifactKey].content === content + prevArtifacts[artifactKey]?.content === content ) { return prevArtifacts; } diff --git a/client/src/components/Artifacts/useDebounceCodeBlock.ts b/client/src/components/Artifacts/useDebounceCodeBlock.ts deleted file mode 100644 index 27aaf5bc83..0000000000 --- a/client/src/components/Artifacts/useDebounceCodeBlock.ts +++ /dev/null @@ -1,37 +0,0 @@ -// client/src/hooks/useDebounceCodeBlock.ts -import { useCallback, useEffect } from 'react'; -import debounce from 'lodash/debounce'; -import { useSetRecoilState } from 'recoil'; -import { codeBlocksState, codeBlockIdsState } from '~/store/artifacts'; -import type { CodeBlock } from '~/common'; - -export function useDebounceCodeBlock() { - const setCodeBlocks = useSetRecoilState(codeBlocksState); - const setCodeBlockIds = useSetRecoilState(codeBlockIdsState); - - const updateCodeBlock = useCallback((codeBlock: CodeBlock) => { - console.log('Updating code block:', codeBlock); - setCodeBlocks((prev) => ({ - ...prev, - [codeBlock.id]: codeBlock, - })); - 
setCodeBlockIds((prev) => - prev.includes(codeBlock.id) ? prev : [...prev, codeBlock.id], - ); - }, [setCodeBlocks, setCodeBlockIds]); - - const debouncedUpdateCodeBlock = useCallback( - debounce((codeBlock: CodeBlock) => { - updateCodeBlock(codeBlock); - }, 25), - [updateCodeBlock], - ); - - useEffect(() => { - return () => { - debouncedUpdateCodeBlock.cancel(); - }; - }, [debouncedUpdateCodeBlock]); - - return debouncedUpdateCodeBlock; -} diff --git a/client/src/components/Audio/TTS.tsx b/client/src/components/Audio/TTS.tsx index 3ceacb7f8d..9343b483d2 100644 --- a/client/src/components/Audio/TTS.tsx +++ b/client/src/components/Audio/TTS.tsx @@ -1,5 +1,5 @@ /* eslint-disable jsx-a11y/media-has-caption */ -import { useEffect, useMemo } from 'react'; +import { useEffect } from 'react'; import { useRecoilValue } from 'recoil'; import type { TMessageAudio } from '~/common'; import { useLocalize, useTTSBrowser, useTTSExternal } from '~/hooks'; @@ -7,7 +7,14 @@ import { VolumeIcon, VolumeMuteIcon, Spinner } from '~/components'; import { logger } from '~/utils'; import store from '~/store'; -export function BrowserTTS({ isLast, index, messageId, content, className }: TMessageAudio) { +export function BrowserTTS({ + isLast, + index, + messageId, + content, + className, + renderButton, +}: TMessageAudio) { const localize = useLocalize(); const playbackRate = useRecoilValue(store.playbackRate); @@ -18,16 +25,16 @@ export function BrowserTTS({ isLast, index, messageId, content, className }: TMe content, }); - const renderIcon = (size: string) => { + const renderIcon = () => { if (isLoading === true) { - return ; + return ; } if (isSpeaking === true) { - return ; + return ; } - return ; + return ; }; useEffect(() => { @@ -46,21 +53,30 @@ export function BrowserTTS({ isLast, index, messageId, content, className }: TMe audioRef.current, ); + const handleClick = () => { + if (audioRef.current) { + audioRef.current.muted = false; + } + toggleSpeech(); + }; + + const title = 
isSpeaking === true ? localize('com_ui_stop') : localize('com_ui_read_aloud'); + return ( <> - + {renderButton ? ( + renderButton({ + onClick: handleClick, + title: title, + icon: renderIcon(), + isActive: isSpeaking, + className, + }) + ) : ( + + )}