diff --git a/.env.example b/.env.example index f79b89a155..876535b345 100644 --- a/.env.example +++ b/.env.example @@ -515,6 +515,18 @@ EMAIL_PASSWORD= EMAIL_FROM_NAME= EMAIL_FROM=noreply@librechat.ai +#========================# +# Mailgun API # +#========================# + +# MAILGUN_API_KEY=your-mailgun-api-key +# MAILGUN_DOMAIN=mg.yourdomain.com +# EMAIL_FROM=noreply@yourdomain.com +# EMAIL_FROM_NAME="LibreChat" + +# # Optional: For EU region +# MAILGUN_HOST=https://api.eu.mailgun.net + #========================# # Firebase CDN # #========================# diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 09444a1b44..207aa17e66 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -30,8 +30,8 @@ Project maintainers have the right and responsibility to remove, edit, or reject 2. Install typescript globally: `npm i -g typescript`. 3. Run `npm ci` to install dependencies. 4. Build the data provider: `npm run build:data-provider`. -5. Build MCP: `npm run build:mcp`. -6. Build data schemas: `npm run build:data-schemas`. +5. Build data schemas: `npm run build:data-schemas`. +6. Build API methods: `npm run build:api`. 7. Setup and run unit tests: - Copy `.env.test`: `cp api/test/.env.test.example api/test/.env.test`. - Run backend unit tests: `npm run test:api`. diff --git a/.github/workflows/backend-review.yml b/.github/workflows/backend-review.yml index b7bccecae8..7637b8cdc0 100644 --- a/.github/workflows/backend-review.yml +++ b/.github/workflows/backend-review.yml @@ -7,6 +7,7 @@ on: - release/* paths: - 'api/**' + - 'packages/api/**' jobs: tests_Backend: name: Run Backend unit tests @@ -36,12 +37,12 @@ jobs: - name: Install Data Provider Package run: npm run build:data-provider - - name: Install MCP Package - run: npm run build:mcp - - name: Install Data Schemas Package run: npm run build:data-schemas + - name: Install API Package + run: npm run build:api + - name: Create empty auth.json file run: | mkdir -p api/data @@ -66,5 +67,8 @@ jobs: - name: Run librechat-data-provider unit tests run: cd packages/data-provider && npm run test:ci - - name: Run librechat-mcp unit tests - run: cd packages/mcp && npm run test:ci \ No newline at end of file + - name: Run @librechat/data-schemas unit tests + run: cd packages/data-schemas && npm run test:ci + + - name: Run @librechat/api unit tests + run: cd packages/api && npm run test:ci \ No newline at end of file diff --git a/.github/workflows/unused-packages.yml b/.github/workflows/unused-packages.yml index 442e70e52c..dc6ce3ba56 100644 --- a/.github/workflows/unused-packages.yml +++ b/.github/workflows/unused-packages.yml @@ -98,6 +98,8 @@ jobs: cd client UNUSED=$(depcheck --json | jq -r '.dependencies | join("\n")' || echo "") UNUSED=$(comm -23 <(echo "$UNUSED" | sort) <(cat ../client_used_deps.txt ../client_used_code.txt | sort) || echo "") + # Filter out false positives + UNUSED=$(echo "$UNUSED" | grep -v "^micromark-extension-llm-math$" || echo "") echo "CLIENT_UNUSED<> $GITHUB_ENV echo "$UNUSED" >> $GITHUB_ENV echo "EOF" >> $GITHUB_ENV diff --git a/.gitignore b/.gitignore index f49594afdf..c9658f17e6 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,7 @@ bower_components/ # AI .clineignore .cursor +.aider* # Floobits .floo diff --git a/Dockerfile.multi b/Dockerfile.multi index 991f805bec..17a9847323 100644 --- a/Dockerfile.multi +++ b/Dockerfile.multi @@ -14,7 +14,7 @@ RUN npm config set fetch-retry-maxtimeout 600000 && \ npm config set fetch-retry-mintimeout 15000 COPY package*.json ./ COPY packages/data-provider/package*.json ./packages/data-provider/ -COPY packages/mcp/package*.json ./packages/mcp/ +COPY packages/api/package*.json ./packages/api/ COPY packages/data-schemas/package*.json ./packages/data-schemas/ COPY client/package*.json ./client/ COPY api/package*.json ./api/ @@ -24,26 +24,27 @@ FROM base-min AS base WORKDIR /app RUN npm ci -# Build data-provider +# Build `data-provider` package FROM base AS data-provider-build WORKDIR /app/packages/data-provider COPY packages/data-provider ./ RUN npm run build -# Build mcp package -FROM base AS mcp-build -WORKDIR /app/packages/mcp -COPY packages/mcp ./ -COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist -RUN npm run build - -# Build data-schemas +# Build `data-schemas` package FROM base AS data-schemas-build WORKDIR /app/packages/data-schemas COPY packages/data-schemas ./ COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist RUN npm run build +# Build `api` package +FROM base AS api-package-build +WORKDIR /app/packages/api +COPY packages/api ./ +COPY --from=data-provider-build /app/packages/data-provider/dist /app/packages/data-provider/dist +COPY --from=data-schemas-build /app/packages/data-schemas/dist /app/packages/data-schemas/dist +RUN npm run build + # Client build FROM base AS client-build WORKDIR /app/client @@ -63,8 +64,8 @@ RUN npm ci --omit=dev COPY api ./api COPY config ./config COPY --from=data-provider-build /app/packages/data-provider/dist ./packages/data-provider/dist -COPY --from=mcp-build /app/packages/mcp/dist ./packages/mcp/dist COPY --from=data-schemas-build /app/packages/data-schemas/dist ./packages/data-schemas/dist +COPY --from=api-package-build /app/packages/api/dist ./packages/api/dist COPY --from=client-build /app/client/dist ./client/dist WORKDIR /app/api EXPOSE 3080 diff --git a/README.md b/README.md index cc9533b2d2..d6bd19ab43 100644 --- a/README.md +++ b/README.md @@ -150,8 +150,8 @@ Click on the thumbnail to open the video☝️ **Other:** - **Website:** [librechat.ai](https://librechat.ai) - - **Documentation:** [docs.librechat.ai](https://docs.librechat.ai) - - **Blog:** [blog.librechat.ai](https://blog.librechat.ai) + - **Documentation:** [librechat.ai/docs](https://librechat.ai/docs) + - **Blog:** [librechat.ai/blog](https://librechat.ai/blog) --- diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 0da331ced5..a3fba29d5c 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -10,6 +10,7 @@ const { validateVisionModel, } = require('librechat-data-provider'); const { SplitStreamHandler: _Handler } = require('@librechat/agents'); +const { Tokenizer, createFetch, createStreamEventHandlers } = require('@librechat/api'); const { truncateText, formatMessage, @@ -26,8 +27,6 @@ const { const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { createFetch, createStreamEventHandlers } = require('./generators'); -const Tokenizer = require('~/server/services/Tokenizer'); const { sleep } = require('~/server/utils'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); @@ -191,10 +190,11 @@ class AnthropicClient extends BaseClient { reverseProxyUrl: this.options.reverseProxyUrl, }), apiKey: this.apiKey, + fetchOptions: {}, }; if (this.options.proxy) { - options.httpAgent = new HttpsProxyAgent(this.options.proxy); + options.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy); } if (this.options.reverseProxyUrl) { diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js index 36a3f4936a..555028dc3f 100644 --- a/api/app/clients/ChatGPTClient.js +++ b/api/app/clients/ChatGPTClient.js @@ -2,6 +2,7 @@ const { Keyv } = require('keyv'); const crypto = require('crypto'); const { CohereClient } = require('cohere-ai'); const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source'); +const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api'); const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken'); const { ImageDetail, @@ -10,9 +11,9 @@ const { CohereConstants, mapModelToAzureConfig, } = require('librechat-data-provider'); -const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils'); const { createContextHandlers } = require('./prompts'); const { createCoherePayload } = require('./llm'); +const { extractBaseURL } = require('~/utils'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 4151e6663a..817239d14f 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -1,4 +1,5 @@ const { google } = require('googleapis'); +const { Tokenizer } = require('@librechat/api'); const { concat } = require('@langchain/core/utils/stream'); const { ChatVertexAI } = require('@langchain/google-vertexai'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); @@ -19,7 +20,6 @@ const { } = require('librechat-data-provider'); const { getSafetySettings } = require('~/server/services/Endpoints/google/llm'); const { encodeAndFormat } = require('~/server/services/Files/images'); -const Tokenizer = require('~/server/services/Tokenizer'); const { spendTokens } = require('~/models/spendTokens'); const { getModelMaxTokens } = require('~/utils'); const { sleep } = require('~/server/utils'); @@ -34,7 +34,8 @@ const BaseClient = require('./BaseClient'); const loc = process.env.GOOGLE_LOC || 'us-central1'; const publisher = 'google'; -const endpointPrefix = `${loc}-aiplatform.googleapis.com`; +const endpointPrefix = + loc === 'global' ? 'aiplatform.googleapis.com' : `${loc}-aiplatform.googleapis.com`; const settings = endpointSettings[EModelEndpoint.google]; const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/; diff --git a/api/app/clients/OllamaClient.js b/api/app/clients/OllamaClient.js index 77d007580c..032781f1f1 100644 --- a/api/app/clients/OllamaClient.js +++ b/api/app/clients/OllamaClient.js @@ -1,10 +1,11 @@ const { z } = require('zod'); const axios = require('axios'); const { Ollama } = require('ollama'); +const { sleep } = require('@librechat/agents'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Constants } = require('librechat-data-provider'); -const { deriveBaseURL, logAxiosError } = require('~/utils'); -const { sleep } = require('~/server/utils'); -const { logger } = require('~/config'); +const { deriveBaseURL } = require('~/utils'); const ollamaPayloadSchema = z.object({ mirostat: z.number().optional(), @@ -67,7 +68,7 @@ class OllamaClient { return models; } catch (error) { const logMessage = - 'Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn\'t start with `ollama` (case-insensitive).'; + "Failed to fetch models from Ollama API. If you are not using Ollama directly, and instead, through some aggregator or reverse proxy that handles fetching via OpenAI spec, ensure the name of the endpoint doesn't start with `ollama` (case-insensitive)."; logAxiosError({ message: logMessage, error }); return []; } diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index 280db89284..2d4146bd9c 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -1,6 +1,14 @@ const { OllamaClient } = require('./OllamaClient'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { SplitStreamHandler, CustomOpenAIClient: OpenAI } = require('@librechat/agents'); +const { + isEnabled, + Tokenizer, + createFetch, + constructAzureURL, + genAzureChatCompletion, + createStreamEventHandlers, +} = require('@librechat/api'); const { Constants, ImageDetail, @@ -16,13 +24,6 @@ const { validateVisionModel, mapModelToAzureConfig, } = require('librechat-data-provider'); -const { - extractBaseURL, - constructAzureURL, - getModelMaxTokens, - genAzureChatCompletion, - getModelMaxOutputTokens, -} = require('~/utils'); const { truncateText, formatMessage, @@ -30,10 +31,9 @@ const { titleInstruction, createContextHandlers, } = require('./prompts'); +const { extractBaseURL, getModelMaxTokens, getModelMaxOutputTokens } = require('~/utils'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); -const { createFetch, createStreamEventHandlers } = require('./generators'); -const { addSpaceIfNeeded, isEnabled, sleep } = require('~/server/utils'); -const Tokenizer = require('~/server/services/Tokenizer'); +const { addSpaceIfNeeded, sleep } = require('~/server/utils'); const { spendTokens } = require('~/models/spendTokens'); const { handleOpenAIErrors } = require('./tools/util'); const { createLLM, RunManager } = require('./llm'); @@ -1159,6 +1159,7 @@ ${convo} logger.debug('[OpenAIClient] chatCompletion', { baseURL, modelOptions }); const opts = { baseURL, + fetchOptions: {}, }; if (this.useOpenRouter) { @@ -1177,7 +1178,7 @@ ${convo} } if (this.options.proxy) { - opts.httpAgent = new HttpsProxyAgent(this.options.proxy); + opts.fetchOptions.agent = new HttpsProxyAgent(this.options.proxy); } /** @type {TAzureConfig | undefined} */ @@ -1395,7 +1396,7 @@ ${convo} ...modelOptions, stream: true, }; - const stream = await openai.beta.chat.completions + const stream = await openai.chat.completions .stream(params) .on('abort', () => { /* Do nothing here */ diff --git a/api/app/clients/generators.js b/api/app/clients/generators.js deleted file mode 100644 index 9814cac7a5..0000000000 --- a/api/app/clients/generators.js +++ /dev/null @@ -1,71 +0,0 @@ -const fetch = require('node-fetch'); -const { GraphEvents } = require('@librechat/agents'); -const { logger, sendEvent } = require('~/config'); -const { sleep } = require('~/server/utils'); - -/** - * Makes a function to make HTTP request and logs the process. - * @param {Object} params - * @param {boolean} [params.directEndpoint] - Whether to use a direct endpoint. - * @param {string} [params.reverseProxyUrl] - The reverse proxy URL to use for the request. - * @returns {Promise} - A promise that resolves to the response of the fetch request. - */ -function createFetch({ directEndpoint = false, reverseProxyUrl = '' }) { - /** - * Makes an HTTP request and logs the process. - * @param {RequestInfo} url - The URL to make the request to. Can be a string or a Request object. - * @param {RequestInit} [init] - Optional init options for the request. - * @returns {Promise} - A promise that resolves to the response of the fetch request. - */ - return async (_url, init) => { - let url = _url; - if (directEndpoint) { - url = reverseProxyUrl; - } - logger.debug(`Making request to ${url}`); - if (typeof Bun !== 'undefined') { - return await fetch(url, init); - } - return await fetch(url, init); - }; -} - -// Add this at the module level outside the class -/** - * Creates event handlers for stream events that don't capture client references - * @param {Object} res - The response object to send events to - * @returns {Object} Object containing handler functions - */ -function createStreamEventHandlers(res) { - return { - [GraphEvents.ON_RUN_STEP]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - [GraphEvents.ON_MESSAGE_DELTA]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - [GraphEvents.ON_REASONING_DELTA]: (event) => { - if (res) { - sendEvent(res, event); - } - }, - }; -} - -function createHandleLLMNewToken(streamRate) { - return async () => { - if (streamRate) { - await sleep(streamRate); - } - }; -} - -module.exports = { - createFetch, - createHandleLLMNewToken, - createStreamEventHandlers, -}; diff --git a/api/app/clients/llm/createLLM.js b/api/app/clients/llm/createLLM.js index c8d6666bce..846c4d8e9c 100644 --- a/api/app/clients/llm/createLLM.js +++ b/api/app/clients/llm/createLLM.js @@ -1,6 +1,5 @@ const { ChatOpenAI } = require('@langchain/openai'); -const { sanitizeModelName, constructAzureURL } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); +const { isEnabled, sanitizeModelName, constructAzureURL } = require('@librechat/api'); /** * Creates a new instance of a language model (LLM) for chat interactions. diff --git a/api/app/clients/specs/AnthropicClient.test.js b/api/app/clients/specs/AnthropicClient.test.js index 9867859087..fbcd2b75e4 100644 --- a/api/app/clients/specs/AnthropicClient.test.js +++ b/api/app/clients/specs/AnthropicClient.test.js @@ -309,7 +309,7 @@ describe('AnthropicClient', () => { }; client.setOptions({ modelOptions, promptCache: true }); const anthropicClient = client.getClient(modelOptions); - expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta'); + expect(anthropicClient._options.defaultHeaders).toBeUndefined(); }); it('should not add beta header for other models', () => { @@ -320,7 +320,7 @@ describe('AnthropicClient', () => { }, }); const anthropicClient = client.getClient(); - expect(anthropicClient.defaultHeaders).not.toHaveProperty('anthropic-beta'); + expect(anthropicClient._options.defaultHeaders).toBeUndefined(); }); }); diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index 0ba77db6fa..6d44915804 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -33,7 +33,9 @@ jest.mock('~/models', () => ({ const { getConvo, saveConvo } = require('~/models'); jest.mock('@librechat/agents', () => { + const { Providers } = jest.requireActual('@librechat/agents'); return { + Providers, ChatOpenAI: jest.fn().mockImplementation(() => { return {}; }), diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.js deleted file mode 100644 index acc3a64d32..0000000000 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.js +++ /dev/null @@ -1,184 +0,0 @@ -require('dotenv').config(); -const fs = require('fs'); -const { z } = require('zod'); -const path = require('path'); -const yaml = require('js-yaml'); -const { createOpenAPIChain } = require('langchain/chains'); -const { DynamicStructuredTool } = require('@langchain/core/tools'); -const { ChatPromptTemplate, HumanMessagePromptTemplate } = require('@langchain/core/prompts'); -const { logger } = require('~/config'); - -function addLinePrefix(text, prefix = '// ') { - return text - .split('\n') - .map((line) => prefix + line) - .join('\n'); -} - -function createPrompt(name, functions) { - const prefix = `// The ${name} tool has the following functions. Determine the desired or most optimal function for the user's query:`; - const functionDescriptions = functions - .map((func) => `// - ${func.name}: ${func.description}`) - .join('\n'); - return `${prefix}\n${functionDescriptions} -// You are an expert manager and scrum master. You must provide a detailed intent to better execute the function. -// Always format as such: {{"func": "function_name", "intent": "intent and expected result"}}`; -} - -const AuthBearer = z - .object({ - type: z.string().includes('service_http'), - authorization_type: z.string().includes('bearer'), - verification_tokens: z.object({ - openai: z.string(), - }), - }) - .catch(() => false); - -const AuthDefinition = z - .object({ - type: z.string(), - authorization_type: z.string(), - verification_tokens: z.object({ - openai: z.string(), - }), - }) - .catch(() => false); - -async function readSpecFile(filePath) { - try { - const fileContents = await fs.promises.readFile(filePath, 'utf8'); - if (path.extname(filePath) === '.json') { - return JSON.parse(fileContents); - } - return yaml.load(fileContents); - } catch (e) { - logger.error('[readSpecFile] error', e); - return false; - } -} - -async function getSpec(url) { - const RegularUrl = z - .string() - .url() - .catch(() => false); - - if (RegularUrl.parse(url) && path.extname(url) === '.json') { - const response = await fetch(url); - return await response.json(); - } - - const ValidSpecPath = z - .string() - .url() - .catch(async () => { - const spec = path.join(__dirname, '..', '.well-known', 'openapi', url); - if (!fs.existsSync(spec)) { - return false; - } - - return await readSpecFile(spec); - }); - - return ValidSpecPath.parse(url); -} - -async function createOpenAPIPlugin({ data, llm, user, message, memory, signal }) { - let spec; - try { - spec = await getSpec(data.api.url); - } catch (error) { - logger.error('[createOpenAPIPlugin] getSpec error', error); - return null; - } - - if (!spec) { - logger.warn('[createOpenAPIPlugin] No spec found'); - return null; - } - - const headers = {}; - const { auth, name_for_model, description_for_model, description_for_human } = data; - if (auth && AuthDefinition.parse(auth)) { - logger.debug('[createOpenAPIPlugin] auth detected', auth); - const { openai } = auth.verification_tokens; - if (AuthBearer.parse(auth)) { - headers.authorization = `Bearer ${openai}`; - logger.debug('[createOpenAPIPlugin] added auth bearer', headers); - } - } - - const chainOptions = { llm }; - - if (data.headers && data.headers['librechat_user_id']) { - logger.debug('[createOpenAPIPlugin] id detected', headers); - headers[data.headers['librechat_user_id']] = user; - } - - if (Object.keys(headers).length > 0) { - logger.debug('[createOpenAPIPlugin] headers detected', headers); - chainOptions.headers = headers; - } - - if (data.params) { - logger.debug('[createOpenAPIPlugin] params detected', data.params); - chainOptions.params = data.params; - } - - let history = ''; - if (memory) { - logger.debug('[createOpenAPIPlugin] openAPI chain: memory detected', memory); - const { history: chat_history } = await memory.loadMemoryVariables({}); - history = chat_history?.length > 0 ? `\n\n## Chat History:\n${chat_history}\n` : ''; - } - - chainOptions.prompt = ChatPromptTemplate.fromMessages([ - HumanMessagePromptTemplate.fromTemplate( - `# Use the provided API's to respond to this query:\n\n{query}\n\n## Instructions:\n${addLinePrefix( - description_for_model, - )}${history}`, - ), - ]); - - const chain = await createOpenAPIChain(spec, chainOptions); - - const { functions } = chain.chains[0].lc_kwargs.llmKwargs; - - return new DynamicStructuredTool({ - name: name_for_model, - description_for_model: `${addLinePrefix(description_for_human)}${createPrompt( - name_for_model, - functions, - )}`, - description: `${description_for_human}`, - schema: z.object({ - func: z - .string() - .describe( - `The function to invoke. The functions available are: ${functions - .map((func) => func.name) - .join(', ')}`, - ), - intent: z - .string() - .describe('Describe your intent with the function and your expected result'), - }), - func: async ({ func = '', intent = '' }) => { - const filteredFunctions = functions.filter((f) => f.name === func); - chain.chains[0].lc_kwargs.llmKwargs.functions = filteredFunctions; - const query = `${message}${func?.length > 0 ? `\n// Intent: ${intent}` : ''}`; - const result = await chain.call({ - query, - signal, - }); - return result.response; - }, - }); -} - -module.exports = { - getSpec, - readSpecFile, - createOpenAPIPlugin, -}; diff --git a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js b/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js deleted file mode 100644 index 83bc5e9397..0000000000 --- a/api/app/clients/tools/dynamic/OpenAPIPlugin.spec.js +++ /dev/null @@ -1,72 +0,0 @@ -const fs = require('fs'); -const { createOpenAPIPlugin, getSpec, readSpecFile } = require('./OpenAPIPlugin'); - -global.fetch = jest.fn().mockImplementationOnce(() => { - return new Promise((resolve) => { - resolve({ - ok: true, - json: () => Promise.resolve({ key: 'value' }), - }); - }); -}); -jest.mock('fs', () => ({ - promises: { - readFile: jest.fn(), - }, - existsSync: jest.fn(), -})); - -describe('readSpecFile', () => { - it('reads JSON file correctly', async () => { - fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' })); - const result = await readSpecFile('test.json'); - expect(result).toEqual({ test: 'value' }); - }); - - it('reads YAML file correctly', async () => { - fs.promises.readFile.mockResolvedValue('test: value'); - const result = await readSpecFile('test.yaml'); - expect(result).toEqual({ test: 'value' }); - }); - - it('handles error correctly', async () => { - fs.promises.readFile.mockRejectedValue(new Error('test error')); - const result = await readSpecFile('test.json'); - expect(result).toBe(false); - }); -}); - -describe('getSpec', () => { - it('fetches spec from url correctly', async () => { - const parsedJson = await getSpec('https://www.instacart.com/.well-known/ai-plugin.json'); - const isObject = typeof parsedJson === 'object'; - expect(isObject).toEqual(true); - }); - - it('reads spec from file correctly', async () => { - fs.existsSync.mockReturnValue(true); - fs.promises.readFile.mockResolvedValue(JSON.stringify({ test: 'value' })); - const result = await getSpec('test.json'); - expect(result).toEqual({ test: 'value' }); - }); - - it('returns false when file does not exist', async () => { - fs.existsSync.mockReturnValue(false); - const result = await getSpec('test.json'); - expect(result).toBe(false); - }); -}); - -describe('createOpenAPIPlugin', () => { - it('returns null when getSpec throws an error', async () => { - const result = await createOpenAPIPlugin({ data: { api: { url: 'invalid' } } }); - expect(result).toBe(null); - }); - - it('returns null when no spec is found', async () => { - const result = await createOpenAPIPlugin({}); - expect(result).toBe(null); - }); - - // Add more tests here for different scenarios -}); diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index fc0f1851f6..7c2a56fe71 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -8,10 +8,10 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getImageBasename } = require('~/server/services/Files/images'); const extractBaseURL = require('~/utils/extractBaseURL'); -const { logger } = require('~/config'); +const logger = require('~/config/winston'); const displayMessage = - 'DALL-E displayed an image. All generated images are already plainly visible, so don\'t repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user.'; + "DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user."; class DALLE3 extends Tool { constructor(fields = {}) { super(); diff --git a/api/app/clients/tools/structured/OpenAIImageTools.js b/api/app/clients/tools/structured/OpenAIImageTools.js index 499cda3ea7..08e15a7fad 100644 --- a/api/app/clients/tools/structured/OpenAIImageTools.js +++ b/api/app/clients/tools/structured/OpenAIImageTools.js @@ -4,12 +4,13 @@ const { v4 } = require('uuid'); const OpenAI = require('openai'); const FormData = require('form-data'); const { tool } = require('@langchain/core/tools'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { ContentTypes, EImageOutputType } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { logAxiosError, extractBaseURL } = require('~/utils'); +const { extractBaseURL } = require('~/utils'); const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); /** Default descriptions for image generation tool */ const DEFAULT_IMAGE_GEN_DESCRIPTION = ` diff --git a/api/app/clients/tools/structured/specs/DALLE3.spec.js b/api/app/clients/tools/structured/specs/DALLE3.spec.js index 1b28de2faf..2def575fb3 100644 --- a/api/app/clients/tools/structured/specs/DALLE3.spec.js +++ b/api/app/clients/tools/structured/specs/DALLE3.spec.js @@ -1,10 +1,29 @@ const OpenAI = require('openai'); const DALLE3 = require('../DALLE3'); - -const { logger } = require('~/config'); +const logger = require('~/config/winston'); jest.mock('openai'); +jest.mock('@librechat/data-schemas', () => { + return { + logger: { + info: jest.fn(), + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + }; +}); + +jest.mock('tiktoken', () => { + return { + encoding_for_model: jest.fn().mockReturnValue({ + encode: jest.fn(), + decode: jest.fn(), + }), + }; +}); + const processFileURL = jest.fn(); jest.mock('~/server/services/Files/images', () => ({ @@ -37,6 +56,11 @@ jest.mock('fs', () => { return { existsSync: jest.fn(), mkdirSync: jest.fn(), + promises: { + writeFile: jest.fn(), + readFile: jest.fn(), + unlink: jest.fn(), + }, }; }); diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index 54da483362..19d3a79edb 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -135,7 +135,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => { query: z .string() .describe( - 'A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you\'re looking for. The query will be used for semantic similarity matching against the file contents.', + "A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. The query will be used for semantic similarity matching against the file contents.", ), }), }, diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 51f0c87ef9..c233c0f762 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -1,14 +1,14 @@ +const { mcpToolPattern } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { SerpAPI } = require('@langchain/community/tools/serpapi'); const { Calculator } = require('@langchain/community/tools/calculator'); const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents'); const { Tools, - Constants, EToolResources, loadWebSearchAuth, replaceSpecialVars, } = require('librechat-data-provider'); -const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { availableTools, manifestToolMap, @@ -28,11 +28,10 @@ const { } = require('../'); const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch'); +const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { getCachedTools } = require('~/server/services/Config'); const { createMCPTool } = require('~/server/services/MCP'); -const { logger } = require('~/config'); - -const mcpToolPattern = new RegExp(`^.+${Constants.mcp_delimiter}.+$`); /** * Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values. @@ -93,7 +92,7 @@ const validateTools = async (user, tools = []) => { return Array.from(validToolsSet.values()); } catch (err) { logger.error('[validateTools] There was a problem validating tools', err); - throw new Error('There was a problem validating tools'); + throw new Error(err); } }; @@ -236,7 +235,7 @@ const loadTools = async ({ /** @type {Record} */ const toolContextMap = {}; - const appTools = options.req?.app?.locals?.availableTools ?? {}; + const appTools = (await getCachedTools({ includeGlobal: true })) ?? {}; for (const tool of tools) { if (tool === Tools.execute_code) { @@ -299,6 +298,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} requestedTools[tool] = async () => createMCPTool({ req: options.req, + res: options.res, toolKey: tool, model: agent?.model ?? model, provider: agent?.provider ?? endpoint, diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 2478bf40d9..06cadf9f64 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -29,6 +29,10 @@ const roles = isRedisEnabled ? new Keyv({ store: keyvRedis }) : new Keyv({ namespace: CacheKeys.ROLES }); +const mcpTools = isRedisEnabled + ? new Keyv({ store: keyvRedis }) + : new Keyv({ namespace: CacheKeys.MCP_TOOLS }); + const audioRuns = isRedisEnabled ? new Keyv({ store: keyvRedis, ttl: Time.TEN_MINUTES }) : new Keyv({ namespace: CacheKeys.AUDIO_RUNS, ttl: Time.TEN_MINUTES }); @@ -67,6 +71,7 @@ const openIdExchangedTokensCache = isRedisEnabled const namespaces = { [CacheKeys.ROLES]: roles, + [CacheKeys.MCP_TOOLS]: mcpTools, [CacheKeys.CONFIG_STORE]: config, [CacheKeys.PENDING_REQ]: pending_req, [ViolationTypes.BAN]: new Keyv({ store: keyvMongo, namespace: CacheKeys.BANS, ttl: duration }), diff --git a/api/config/index.js b/api/config/index.js index e238f700be..2e69e87118 100644 --- a/api/config/index.js +++ b/api/config/index.js @@ -1,7 +1,6 @@ -const axios = require('axios'); const { EventSource } = require('eventsource'); -const { Time, CacheKeys } = require('librechat-data-provider'); -const { MCPManager, FlowStateManager } = require('librechat-mcp'); +const { Time } = require('librechat-data-provider'); +const { MCPManager, FlowStateManager } = require('@librechat/api'); const logger = require('./winston'); global.EventSource = EventSource; @@ -16,7 +15,7 @@ let flowManager = null; */ function getMCPManager(userId) { if (!mcpManager) { - mcpManager = MCPManager.getInstance(logger); + mcpManager = MCPManager.getInstance(); } else { mcpManager.checkIdleConnections(userId); } @@ -31,66 +30,13 @@ function getFlowStateManager(flowsCache) { if (!flowManager) { flowManager = new FlowStateManager(flowsCache, { ttl: Time.ONE_MINUTE * 3, - logger, }); } return flowManager; } -/** - * Sends message data in Server Sent Events format. - * @param {ServerResponse} res - The server response. - * @param {{ data: string | Record, event?: string }} event - The message event. - * @param {string} event.event - The type of event. - * @param {string} event.data - The message to be sent. - */ -const sendEvent = (res, event) => { - if (typeof event.data === 'string' && event.data.length === 0) { - return; - } - res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); -}; - -/** - * Creates and configures an Axios instance with optional proxy settings. - * - * @typedef {import('axios').AxiosInstance} AxiosInstance - * @typedef {import('axios').AxiosProxyConfig} AxiosProxyConfig - * - * @returns {AxiosInstance} A configured Axios instance - * @throws {Error} If there's an issue creating the Axios instance or parsing the proxy URL - */ -function createAxiosInstance() { - const instance = axios.create(); - - if (process.env.proxy) { - try { - const url = new URL(process.env.proxy); - - /** @type {AxiosProxyConfig} */ - const proxyConfig = { - host: url.hostname.replace(/^\[|\]$/g, ''), - protocol: url.protocol.replace(':', ''), - }; - - if (url.port) { - proxyConfig.port = parseInt(url.port, 10); - } - - instance.defaults.proxy = proxyConfig; - } catch (error) { - console.error('Error parsing proxy URL:', error); - throw new Error(`Invalid proxy URL: ${process.env.proxy}`); - } - } - - return instance; -} - module.exports = { logger, - sendEvent, getMCPManager, - createAxiosInstance, getFlowStateManager, }; diff --git a/api/db/indexSync.js b/api/db/indexSync.js index e8bcd55e37..945346a906 100644 --- a/api/db/indexSync.js +++ b/api/db/indexSync.js @@ -1,8 +1,11 @@ const mongoose = require('mongoose'); const { MeiliSearch } = require('meilisearch'); const { logger } = require('@librechat/data-schemas'); +const { FlowStateManager } = require('@librechat/api'); +const { CacheKeys } = require('librechat-data-provider'); const { isEnabled } = require('~/server/utils'); +const { getLogStores } = require('~/cache'); const Conversation = mongoose.models.Conversation; const Message = mongoose.models.Message; @@ -28,43 +31,123 @@ class MeiliSearchClient { } } +/** + * Performs the actual sync operations for messages and conversations + */ +async function performSync() { + const client = MeiliSearchClient.getInstance(); + + const { status } = await client.health(); + if (status !== 'available') { + throw new Error('Meilisearch not available'); + } + + if (indexingDisabled === true) { + logger.info('[indexSync] Indexing is disabled, skipping...'); + return { messagesSync: false, convosSync: false }; + } + + let messagesSync = false; + let convosSync = false; + + // Check if we need to sync messages + const messageProgress = await Message.getSyncProgress(); + if (!messageProgress.isComplete) { + logger.info( + `[indexSync] Messages need syncing: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`, + ); + + // Check if we should do a full sync or incremental + const messageCount = await Message.countDocuments(); + const messagesIndexed = messageProgress.totalProcessed; + const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10); + + if (messageCount - messagesIndexed > syncThreshold) { + logger.info('[indexSync] Starting full message sync due to large difference'); + await Message.syncWithMeili(); + messagesSync = true; + } else if (messageCount !== messagesIndexed) { + logger.warn('[indexSync] Messages out of sync, performing incremental sync'); + await Message.syncWithMeili(); + messagesSync = true; + } + } else { + logger.info( + `[indexSync] Messages are fully synced: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments}`, + ); + } + + // Check if we need to sync conversations + const convoProgress = await Conversation.getSyncProgress(); + if (!convoProgress.isComplete) { + logger.info( + `[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`, + ); + + const convoCount = await Conversation.countDocuments(); + const convosIndexed = convoProgress.totalProcessed; + const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10); + + if (convoCount - convosIndexed > syncThreshold) { + logger.info('[indexSync] Starting full conversation sync due to large difference'); + await Conversation.syncWithMeili(); + convosSync = true; + } else if (convoCount !== convosIndexed) { + logger.warn('[indexSync] Convos out of sync, performing incremental sync'); + await Conversation.syncWithMeili(); + convosSync = true; + } + } else { + logger.info( + `[indexSync] Conversations are fully synced: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments}`, + ); + } + + return { messagesSync, convosSync }; +} + +/** + * Main index sync function that uses FlowStateManager to prevent concurrent execution + */ async function indexSync() { if (!searchEnabled) { return; } - try { - const client = MeiliSearchClient.getInstance(); - const { status } = await client.health(); - if (status !== 'available') { - throw new Error('Meilisearch not available'); + logger.info('[indexSync] Starting index synchronization check...'); + + try { + // Get or create FlowStateManager instance + const flowsCache = getLogStores(CacheKeys.FLOWS); + if (!flowsCache) { + logger.warn('[indexSync] Flows cache not available, falling back to direct sync'); + return await performSync(); } - if (indexingDisabled === true) { - logger.info('[indexSync] Indexing is disabled, skipping...'); + const flowManager = new FlowStateManager(flowsCache, { + ttl: 60000 * 10, // 10 minutes TTL for sync operations + }); + + // Use a unique flow ID for the sync operation + const flowId = 'meili-index-sync'; + const flowType = 'MEILI_SYNC'; + + // This will only execute the handler if no other instance is running the sync + const result = await flowManager.createFlowWithHandler(flowId, flowType, performSync); + + if (result.messagesSync || result.convosSync) { + logger.info('[indexSync] Sync completed successfully'); + } else { + logger.debug('[indexSync] No sync was needed'); + } + + return result; + } catch (err) { + if (err.message.includes('flow already exists')) { + logger.info('[indexSync] Sync already running on another instance'); return; } - const messageCount = await Message.countDocuments(); - const convoCount = await Conversation.countDocuments(); - const messages = await client.index('messages').getStats(); - const convos = await client.index('convos').getStats(); - const messagesIndexed = messages.numberOfDocuments; - const convosIndexed = convos.numberOfDocuments; - - logger.debug(`[indexSync] There are ${messageCount} messages and ${messagesIndexed} indexed`); - logger.debug(`[indexSync] There are ${convoCount} convos and ${convosIndexed} indexed`); - - if (messageCount !== messagesIndexed) { - logger.debug('[indexSync] Messages out of sync, indexing'); - Message.syncWithMeili(); - } - - if (convoCount !== convosIndexed) { - logger.debug('[indexSync] Convos out of sync, indexing'); - Conversation.syncWithMeili(); - } - } catch (err) { if (err.message.includes('not found')) { logger.debug('[indexSync] Creating indices...'); currentTimeout = setTimeout(async () => { diff --git a/api/models/Agent.js b/api/models/Agent.js index 808dbf09e0..d33ca8a8bf 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -11,6 +11,7 @@ const { removeAgentIdsFromProject, removeAgentFromAllProjects, } = require('./Project'); +const { getCachedTools } = require('~/server/services/Config'); const getLogStores = require('~/cache/getLogStores'); const { getActions } = require('./Action'); const { Agent } = require('~/db/models'); @@ -55,12 +56,12 @@ const getAgent = async (searchParameter) => await Agent.findOne(searchParameter) * @param {string} params.agent_id * @param {string} params.endpoint * @param {import('@librechat/agents').ClientOptions} [params.model_parameters] - * @returns {Agent|null} The agent document as a plain object, or null if not found. + * @returns {Promise} The agent document as a plain object, or null if not found. */ -const loadEphemeralAgent = ({ req, agent_id, endpoint, model_parameters: _m }) => { +const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _m }) => { const { model, ...model_parameters } = _m; /** @type {Record} */ - const availableTools = req.app.locals.availableTools; + const availableTools = await getCachedTools({ includeGlobal: true }); /** @type {TEphemeralAgent | null} */ const ephemeralAgent = req.body.ephemeralAgent; const mcpServers = new Set(ephemeralAgent?.mcp); @@ -111,7 +112,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => { return null; } if (agent_id === EPHEMERAL_AGENT_ID) { - return loadEphemeralAgent({ req, agent_id, endpoint, model_parameters }); + return await loadEphemeralAgent({ req, agent_id, endpoint, model_parameters }); } const agent = await getAgent({ id: agent_id, @@ -170,7 +171,6 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul 'created_at', 'updated_at', '__v', - 'agent_ids', 'versions', 'actionsHash', // Exclude actionsHash from direct comparison ]; @@ -260,11 +260,12 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul * @param {Object} [options] - Optional configuration object. * @param {string} [options.updatingUserId] - The ID of the user performing the update (used for tracking non-author updates). * @param {boolean} [options.forceVersion] - Force creation of a new version even if no fields changed. + * @param {boolean} [options.skipVersioning] - Skip version creation entirely (useful for isolated operations like sharing). * @returns {Promise} The updated or newly created agent document as a plain object. * @throws {Error} If the update would create a duplicate version */ const updateAgent = async (searchParameter, updateData, options = {}) => { - const { updatingUserId = null, forceVersion = false } = options; + const { updatingUserId = null, forceVersion = false, skipVersioning = false } = options; const mongoOptions = { new: true, upsert: false }; const currentAgent = await Agent.findOne(searchParameter); @@ -301,10 +302,8 @@ const updateAgent = async (searchParameter, updateData, options = {}) => { } const shouldCreateVersion = - forceVersion || - (versions && - versions.length > 0 && - (Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet)); + !skipVersioning && + (forceVersion || Object.keys(directUpdates).length > 0 || $push || $pull || $addToSet); if (shouldCreateVersion) { const duplicateVersion = isDuplicateVersion(updateData, versionData, versions, actionsHash); @@ -339,7 +338,7 @@ const updateAgent = async (searchParameter, updateData, options = {}) => { versionEntry.updatedBy = new mongoose.Types.ObjectId(updatingUserId); } - if (shouldCreateVersion || forceVersion) { + if (shouldCreateVersion) { updateData.$push = { ...($push || {}), versions: versionEntry, @@ -550,7 +549,10 @@ const updateAgentProjects = async ({ user, agentId, projectIds, removeProjectIds delete updateQuery.author; } - const updatedAgent = await updateAgent(updateQuery, updateOps, { updatingUserId: user.id }); + const updatedAgent = await updateAgent(updateQuery, updateOps, { + updatingUserId: user.id, + skipVersioning: true, + }); if (updatedAgent) { return updatedAgent; } diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js index bfa5a18259..0b0646f524 100644 --- a/api/models/Agent.spec.js +++ b/api/models/Agent.spec.js @@ -6,12 +6,17 @@ const originalEnv = { process.env.CREDS_KEY = '0123456789abcdef0123456789abcdef'; process.env.CREDS_IV = '0123456789abcdef'; +jest.mock('~/server/services/Config', () => ({ + getCachedTools: jest.fn(), +})); + const mongoose = require('mongoose'); const { v4: uuidv4 } = require('uuid'); const { agentSchema } = require('@librechat/data-schemas'); const { MongoMemoryServer } = require('mongodb-memory-server'); const { getAgent, + loadAgent, createAgent, updateAgent, deleteAgent, @@ -19,1063 +24,2637 @@ const { updateAgentProjects, addAgentResourceFile, removeAgentResourceFiles, + generateActionMetadataHash, + revertAgentVersion, } = require('./Agent'); +const { getCachedTools } = require('~/server/services/Config'); + /** * @type {import('mongoose').Model} */ let Agent; -describe('Agent Resource File Operations', () => { - let mongoServer; +describe('models/Agent', () => { + describe('Agent Resource File Operations', () => { + let mongoServer; - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); - await mongoose.connect(mongoUri); - }); - - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - process.env.CREDS_KEY = originalEnv.CREDS_KEY; - process.env.CREDS_IV = originalEnv.CREDS_IV; - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - const createBasicAgent = async () => { - const agentId = `agent_${uuidv4()}`; - const agent = await Agent.create({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), - }); - return agent; - }; - - test('should add tool_resource to tools if missing', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - const toolResource = 'file_search'; - - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId, + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); }); - expect(updatedAgent.tools).toContain(toolResource); - expect(Array.isArray(updatedAgent.tools)).toBe(true); - // Should not duplicate - const count = updatedAgent.tools.filter((t) => t === toolResource).length; - expect(count).toBe(1); - }); - - test('should not duplicate tool_resource in tools if already present', async () => { - const agent = await createBasicAgent(); - const fileId1 = uuidv4(); - const fileId2 = uuidv4(); - const toolResource = 'file_search'; - - // First add - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId1, + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + process.env.CREDS_KEY = originalEnv.CREDS_KEY; + process.env.CREDS_IV = originalEnv.CREDS_IV; }); - // Second add (should not duplicate) - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: toolResource, - file_id: fileId2, + beforeEach(async () => { + await Agent.deleteMany({}); }); - expect(updatedAgent.tools).toContain(toolResource); - expect(Array.isArray(updatedAgent.tools)).toBe(true); - const count = updatedAgent.tools.filter((t) => t === toolResource).length; - expect(count).toBe(1); - }); + test('should add tool_resource to tools if missing', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + const toolResource = 'file_search'; - test('should handle concurrent file additions', async () => { - const agent = await createBasicAgent(); - const fileIds = Array.from({ length: 10 }, () => uuidv4()); + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId, + }); - // Concurrent additions - const additionPromises = fileIds.map((fileId) => - addAgentResourceFile({ + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + // Should not duplicate + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + expect(count).toBe(1); + }); + + test('should not duplicate tool_resource in tools if already present', async () => { + const agent = await createBasicAgent(); + const fileId1 = uuidv4(); + const fileId2 = uuidv4(); + const toolResource = 'file_search'; + + // First add + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId1, + }); + + // Second add (should not duplicate) + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: toolResource, + file_id: fileId2, + }); + + expect(updatedAgent.tools).toContain(toolResource); + expect(Array.isArray(updatedAgent.tools)).toBe(true); + const count = updatedAgent.tools.filter((t) => t === toolResource).length; + expect(count).toBe(1); + }); + + test('should handle concurrent file additions', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); + + // Concurrent additions + const additionPromises = createFileOperations(agent.id, fileIds, 'add'); + + await Promise.all(additionPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); + expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); + }); + + test('should handle concurrent additions and removals', async () => { + const agent = await createBasicAgent(); + const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); + + await Promise.all(createFileOperations(agent.id, initialFileIds, 'add')); + + const newFileIds = Array.from({ length: 5 }, () => uuidv4()); + const operations = [ + ...newFileIds.map((fileId) => + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }), + ), + ...initialFileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ), + ]; + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); + }); + + test('should initialize array when adding to non-existent tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + const updatedAgent = await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'new_tool', + file_id: fileId, + }); + + expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); + expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); + expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); + }); + + test('should handle rapid sequential modifications to same tool resource', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + for (let i = 0; i < 10; i++) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: `${fileId}_${i}`, + }); + + if (i % 2 === 0) { + await removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], + }); + } + } + + const updatedAgent = await Agent.findOne({ id: agent.id }); + expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); + expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); + }); + + test('should handle multiple tool resources concurrently', async () => { + const agent = await createBasicAgent(); + const toolResources = ['tool1', 'tool2', 'tool3']; + const operations = []; + + toolResources.forEach((tool) => { + const fileIds = Array.from({ length: 5 }, () => uuidv4()); + fileIds.forEach((fileId) => { + operations.push( + addAgentResourceFile({ + agent_id: agent.id, + tool_resource: tool, + file_id: fileId, + }), + ); + }); + }); + + await Promise.all(operations); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + toolResources.forEach((tool) => { + expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); + expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); + }); + }); + + test.each([ + { + name: 'duplicate additions', + operation: 'add', + duplicateCount: 5, + expectedLength: 1, + expectedContains: true, + }, + { + name: 'duplicate removals', + operation: 'remove', + duplicateCount: 5, + expectedLength: 0, + expectedContains: false, + setupFile: true, + }, + ])( + 'should handle concurrent $name', + async ({ operation, duplicateCount, expectedLength, expectedContains, setupFile }) => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + if (setupFile) { + await addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }); + } + + const promises = Array.from({ length: duplicateCount }).map(() => + operation === 'add' + ? addAgentResourceFile({ + agent_id: agent.id, + tool_resource: 'test_tool', + file_id: fileId, + }) + : removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(promises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; + + expect(fileIds).toHaveLength(expectedLength); + if (expectedContains) { + expect(fileIds[0]).toBe(fileId); + } else { + expect(fileIds).not.toContain(fileId); + } + }, + ); + + test('should handle concurrent add and remove of the same file', async () => { + const agent = await createBasicAgent(); + const fileId = uuidv4(); + + await addAgentResourceFile({ agent_id: agent.id, tool_resource: 'test_tool', file_id: fileId, - }), - ); + }); - await Promise.all(additionPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(10); - expect(new Set(updatedAgent.tool_resources.test_tool.file_ids).size).toBe(10); - }); - - test('should handle concurrent additions and removals', async () => { - const agent = await createBasicAgent(); - const initialFileIds = Array.from({ length: 5 }, () => uuidv4()); - - await Promise.all( - initialFileIds.map((fileId) => + const operations = [ addAgentResourceFile({ agent_id: agent.id, tool_resource: 'test_tool', file_id: fileId, }), - ), - ); - - const newFileIds = Array.from({ length: 5 }, () => uuidv4()); - const operations = [ - ...newFileIds.map((fileId) => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - ), - ...initialFileIds.map((fileId) => removeAgentResourceFiles({ agent_id: agent.id, files: [{ tool_resource: 'test_tool', file_id: fileId }], }), - ), - ]; + ]; - await Promise.all(operations); + await Promise.all(operations); - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(5); - }); + const updatedAgent = await Agent.findOne({ id: agent.id }); + const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids; + const count = finalFileIds.filter((id) => id === fileId).length; - test('should initialize array when adding to non-existent tool resource', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - const updatedAgent = await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'new_tool', - file_id: fileId, + expect(count).toBeLessThanOrEqual(1); + if (count === 0) { + expect(finalFileIds).toHaveLength(0); + } else { + expect(finalFileIds).toHaveLength(1); + expect(finalFileIds[0]).toBe(fileId); + } }); - expect(updatedAgent.tool_resources.new_tool.file_ids).toBeDefined(); - expect(updatedAgent.tool_resources.new_tool.file_ids).toHaveLength(1); - expect(updatedAgent.tool_resources.new_tool.file_ids[0]).toBe(fileId); - }); + test('should handle concurrent removals of different files', async () => { + const agent = await createBasicAgent(); + const fileIds = Array.from({ length: 10 }, () => uuidv4()); - test('should handle rapid sequential modifications to same tool resource', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - for (let i = 0; i < 10; i++) { - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: `${fileId}_${i}`, - }); - - if (i % 2 === 0) { - await removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: `${fileId}_${i}` }], - }); - } - } - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - expect(Array.isArray(updatedAgent.tool_resources.test_tool.file_ids)).toBe(true); - }); - - test('should handle multiple tool resources concurrently', async () => { - const agent = await createBasicAgent(); - const toolResources = ['tool1', 'tool2', 'tool3']; - const operations = []; - - toolResources.forEach((tool) => { - const fileIds = Array.from({ length: 5 }, () => uuidv4()); - fileIds.forEach((fileId) => { - operations.push( + // Add all files first + await Promise.all( + fileIds.map((fileId) => addAgentResourceFile({ agent_id: agent.id, - tool_resource: tool, + tool_resource: 'test_tool', file_id: fileId, }), - ); + ), + ); + + // Concurrently remove all files + const removalPromises = fileIds.map((fileId) => + removeAgentResourceFiles({ + agent_id: agent.id, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); + + await Promise.all(removalPromises); + + const updatedAgent = await Agent.findOne({ id: agent.id }); + // Check if the array is empty or the tool resource itself is removed + const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; + expect(finalFileIds).toHaveLength(0); + }); + + describe('Edge Cases', () => { + describe.each([ + { + operation: 'add', + name: 'empty file_id', + needsAgent: true, + params: { tool_resource: 'file_search', file_id: '' }, + shouldResolve: true, + }, + { + operation: 'add', + name: 'non-existent agent', + needsAgent: false, + params: { tool_resource: 'file_search', file_id: 'file123' }, + shouldResolve: false, + error: 'Agent not found for adding resource file', + }, + ])('addAgentResourceFile with $name', ({ needsAgent, params, shouldResolve, error }) => { + test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => { + const agent = needsAgent ? await createBasicAgent() : null; + const agent_id = needsAgent ? agent.id : `agent_${uuidv4()}`; + + if (shouldResolve) { + await expect(addAgentResourceFile({ agent_id, ...params })).resolves.toBeDefined(); + } else { + await expect(addAgentResourceFile({ agent_id, ...params })).rejects.toThrow(error); + } + }); + }); + + describe.each([ + { + name: 'empty files array', + files: [], + needsAgent: true, + shouldResolve: true, + }, + { + name: 'non-existent tool_resource', + files: [{ tool_resource: 'non_existent_tool', file_id: 'file123' }], + needsAgent: true, + shouldResolve: true, + }, + { + name: 'non-existent agent', + files: [{ tool_resource: 'file_search', file_id: 'file123' }], + needsAgent: false, + shouldResolve: false, + error: 'Agent not found for removing resource files', + }, + ])('removeAgentResourceFiles with $name', ({ files, needsAgent, shouldResolve, error }) => { + test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => { + const agent = needsAgent ? await createBasicAgent() : null; + const agent_id = needsAgent ? agent.id : `agent_${uuidv4()}`; + + if (shouldResolve) { + const result = await removeAgentResourceFiles({ agent_id, files }); + expect(result).toBeDefined(); + if (agent) { + expect(result.id).toBe(agent.id); + } + } else { + await expect(removeAgentResourceFiles({ agent_id, files })).rejects.toThrow(error); + } + }); }); }); - - await Promise.all(operations); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - toolResources.forEach((tool) => { - expect(updatedAgent.tool_resources[tool].file_ids).toBeDefined(); - expect(updatedAgent.tool_resources[tool].file_ids).toHaveLength(5); - }); }); - test('should handle concurrent duplicate additions', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); + describe('Agent CRUD Operations', () => { + let mongoServer; - // Concurrent additions of the same file - const additionPromises = Array.from({ length: 5 }).map(() => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - ); - - await Promise.all(additionPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - // Should only contain one instance of the fileId - expect(updatedAgent.tool_resources.test_tool.file_ids).toHaveLength(1); - expect(updatedAgent.tool_resources.test_tool.file_ids[0]).toBe(fileId); - }); - - test('should handle concurrent add and remove of the same file', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - // First, ensure the file exists (or test might be trivial if remove runs first) - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); }); - // Concurrent add (which should be ignored) and remove - const operations = [ - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ]; - - await Promise.all(operations); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // The final state should ideally be that the file is removed, - // but the key point is consistency (not duplicated or error state). - // Depending on execution order, the file might remain if the add operation's - // findOneAndUpdate runs after the remove operation completes. - // A more robust check might be that the length is <= 1. - // Given the remove uses an update pipeline, it might be more likely to win. - // The final state depends on race condition timing (add or remove might "win"). - // The critical part is that the state is consistent (no duplicates, no errors). - // Assert that the fileId is either present exactly once or not present at all. - expect(updatedAgent.tool_resources.test_tool.file_ids).toBeDefined(); - const finalFileIds = updatedAgent.tool_resources.test_tool.file_ids; - const count = finalFileIds.filter((id) => id === fileId).length; - expect(count).toBeLessThanOrEqual(1); // Should be 0 or 1, never more - // Optional: Check overall length is consistent with the count - if (count === 0) { - expect(finalFileIds).toHaveLength(0); - } else { - expect(finalFileIds).toHaveLength(1); - expect(finalFileIds[0]).toBe(fileId); - } - }); - - test('should handle concurrent duplicate removals', async () => { - const agent = await createBasicAgent(); - const fileId = uuidv4(); - - // Add the file first - await addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); }); - // Concurrent removals of the same file - const removalPromises = Array.from({ length: 5 }).map(() => - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ); - - await Promise.all(removalPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // Check if the array is empty or the tool resource itself is removed - const fileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; - expect(fileIds).toHaveLength(0); - expect(fileIds).not.toContain(fileId); - }); - - test('should handle concurrent removals of different files', async () => { - const agent = await createBasicAgent(); - const fileIds = Array.from({ length: 10 }, () => uuidv4()); - - // Add all files first - await Promise.all( - fileIds.map((fileId) => - addAgentResourceFile({ - agent_id: agent.id, - tool_resource: 'test_tool', - file_id: fileId, - }), - ), - ); - - // Concurrently remove all files - const removalPromises = fileIds.map((fileId) => - removeAgentResourceFiles({ - agent_id: agent.id, - files: [{ tool_resource: 'test_tool', file_id: fileId }], - }), - ); - - await Promise.all(removalPromises); - - const updatedAgent = await Agent.findOne({ id: agent.id }); - // Check if the array is empty or the tool resource itself is removed - const finalFileIds = updatedAgent.tool_resources?.test_tool?.file_ids ?? []; - expect(finalFileIds).toHaveLength(0); - }); -}); - -describe('Agent CRUD Operations', () => { - let mongoServer; - - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); - - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - test('should create and get an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - - const newAgent = await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: authorId, - description: 'Test description', + beforeEach(async () => { + await Agent.deleteMany({}); }); - expect(newAgent).toBeDefined(); - expect(newAgent.id).toBe(agentId); - expect(newAgent.name).toBe('Test Agent'); + test('should create and get an agent', async () => { + const { agentId, authorId } = createTestIds(); - const retrievedAgent = await getAgent({ id: agentId }); - expect(retrievedAgent).toBeDefined(); - expect(retrievedAgent.id).toBe(agentId); - expect(retrievedAgent.name).toBe('Test Agent'); - expect(retrievedAgent.description).toBe('Test description'); - }); + const newAgent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Test description', + }); - test('should delete an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); + expect(newAgent).toBeDefined(); + expect(newAgent.id).toBe(agentId); + expect(newAgent.name).toBe('Test Agent'); - await createAgent({ - id: agentId, - name: 'Agent To Delete', - provider: 'test', - model: 'test-model', - author: authorId, + const retrievedAgent = await getAgent({ id: agentId }); + expect(retrievedAgent).toBeDefined(); + expect(retrievedAgent.id).toBe(agentId); + expect(retrievedAgent.name).toBe('Test Agent'); + expect(retrievedAgent.description).toBe('Test description'); }); - const agentBeforeDelete = await getAgent({ id: agentId }); - expect(agentBeforeDelete).toBeDefined(); + test('should delete an agent', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - await deleteAgent({ id: agentId }); - - const agentAfterDelete = await getAgent({ id: agentId }); - expect(agentAfterDelete).toBeNull(); - }); - - test('should list agents by author', async () => { - const authorId = new mongoose.Types.ObjectId(); - const otherAuthorId = new mongoose.Types.ObjectId(); - - const agentIds = []; - for (let i = 0; i < 5; i++) { - const id = `agent_${uuidv4()}`; - agentIds.push(id); await createAgent({ - id, - name: `Agent ${i}`, + id: agentId, + name: 'Agent To Delete', provider: 'test', model: 'test-model', author: authorId, }); - } - for (let i = 0; i < 3; i++) { + const agentBeforeDelete = await getAgent({ id: agentId }); + expect(agentBeforeDelete).toBeDefined(); + + await deleteAgent({ id: agentId }); + + const agentAfterDelete = await getAgent({ id: agentId }); + expect(agentAfterDelete).toBeNull(); + }); + + test('should list agents by author', async () => { + const authorId = new mongoose.Types.ObjectId(); + const otherAuthorId = new mongoose.Types.ObjectId(); + + const agentIds = []; + for (let i = 0; i < 5; i++) { + const id = `agent_${uuidv4()}`; + agentIds.push(id); + await createAgent({ + id, + name: `Agent ${i}`, + provider: 'test', + model: 'test-model', + author: authorId, + }); + } + + for (let i = 0; i < 3; i++) { + await createAgent({ + id: `other_agent_${uuidv4()}`, + name: `Other Agent ${i}`, + provider: 'test', + model: 'test-model', + author: otherAuthorId, + }); + } + + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toBeDefined(); + expect(result.data).toHaveLength(5); + expect(result.has_more).toBe(true); + + for (const agent of result.data) { + expect(agent.author).toBe(authorId.toString()); + } + }); + + test('should update agent projects', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + const projectId3 = new mongoose.Types.ObjectId(); + await createAgent({ - id: `other_agent_${uuidv4()}`, - name: `Other Agent ${i}`, + id: agentId, + name: 'Project Test Agent', provider: 'test', model: 'test-model', - author: otherAuthorId, + author: authorId, + projectIds: [projectId1], }); - } - const result = await getListAgents({ author: authorId.toString() }); + await updateAgent( + { id: agentId }, + { $addToSet: { projectIds: { $each: [projectId2, projectId3] } } }, + ); - expect(result).toBeDefined(); - expect(result.data).toBeDefined(); - expect(result.data).toHaveLength(5); - expect(result.has_more).toBe(true); + await updateAgent({ id: agentId }, { $pull: { projectIds: projectId1 } }); - for (const agent of result.data) { - expect(agent.author).toBe(authorId.toString()); - } - }); + await updateAgent({ id: agentId }, { projectIds: [projectId2, projectId3] }); - test('should update agent projects', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const projectId1 = new mongoose.Types.ObjectId(); - const projectId2 = new mongoose.Types.ObjectId(); - const projectId3 = new mongoose.Types.ObjectId(); + const updatedAgent = await getAgent({ id: agentId }); + expect(updatedAgent.projectIds).toHaveLength(2); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId3.toString()); + expect(updatedAgent.projectIds.map((id) => id.toString())).not.toContain( + projectId1.toString(), + ); - await createAgent({ - id: agentId, - name: 'Project Test Agent', - provider: 'test', - model: 'test-model', - author: authorId, - projectIds: [projectId1], + await updateAgent({ id: agentId }, { projectIds: [] }); + + const emptyProjectsAgent = await getAgent({ id: agentId }); + expect(emptyProjectsAgent.projectIds).toHaveLength(0); + + const nonExistentId = `agent_${uuidv4()}`; + await expect( + updateAgentProjects({ + id: nonExistentId, + projectIds: [projectId1], + }), + ).rejects.toThrow(); }); - await updateAgent( - { id: agentId }, - { $addToSet: { projectIds: { $each: [projectId2, projectId3] } } }, - ); + test('should handle ephemeral agent loading', async () => { + const agentId = 'ephemeral_test'; + const endpoint = 'openai'; - await updateAgent({ id: agentId }, { $pull: { projectIds: projectId1 } }); + const originalModule = jest.requireActual('librechat-data-provider'); - await updateAgent({ id: agentId }, { projectIds: [projectId2, projectId3] }); + const mockDataProvider = { + ...originalModule, + Constants: { + ...originalModule.Constants, + EPHEMERAL_AGENT_ID: 'ephemeral_test', + }, + }; - const updatedAgent = await getAgent({ id: agentId }); - expect(updatedAgent.projectIds).toHaveLength(2); - expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); - expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId3.toString()); - expect(updatedAgent.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + jest.doMock('librechat-data-provider', () => mockDataProvider); - await updateAgent({ id: agentId }, { projectIds: [] }); + expect(agentId).toBeDefined(); + expect(endpoint).toBeDefined(); - const emptyProjectsAgent = await getAgent({ id: agentId }); - expect(emptyProjectsAgent.projectIds).toHaveLength(0); + jest.dontMock('librechat-data-provider'); + }); - const nonExistentId = `agent_${uuidv4()}`; - await expect( - updateAgentProjects({ - id: nonExistentId, - projectIds: [projectId1], - }), - ).rejects.toThrow(); + test('should handle loadAgent functionality and errors', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Load Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1', 'tool2'], + }); + + const agent = await getAgent({ id: agentId }); + + expect(agent).toBeDefined(); + expect(agent.id).toBe(agentId); + expect(agent.name).toBe('Test Load Agent'); + expect(agent.tools).toEqual(expect.arrayContaining(['tool1', 'tool2'])); + + const mockLoadAgent = jest.fn().mockResolvedValue(agent); + const loadedAgent = await mockLoadAgent(); + expect(loadedAgent).toBeDefined(); + expect(loadedAgent.id).toBe(agentId); + + const nonExistentId = `agent_${uuidv4()}`; + const nonExistentAgent = await getAgent({ id: nonExistentId }); + expect(nonExistentAgent).toBeNull(); + + const mockLoadAgentError = jest.fn().mockRejectedValue(new Error('No agent found with ID')); + await expect(mockLoadAgentError()).rejects.toThrow('No agent found with ID'); + }); + + describe('Edge Cases', () => { + test.each([ + { + name: 'getAgent with undefined search parameters', + fn: () => getAgent(undefined), + expected: null, + }, + { + name: 'deleteAgent with non-existent agent', + fn: () => deleteAgent({ id: 'non-existent' }), + expected: null, + }, + ])('$name should return null', async ({ fn, expected }) => { + const result = await fn(); + expect(result).toBe(expected); + }); + + test('should handle getListAgents with invalid author format', async () => { + try { + const result = await getListAgents({ author: 'invalid-object-id' }); + expect(result.data).toEqual([]); + } catch (error) { + expect(error).toBeDefined(); + } + }); + + test('should handle getListAgents with no agents', async () => { + const authorId = new mongoose.Types.ObjectId(); + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toEqual([]); + expect(result.has_more).toBe(false); + expect(result.first_id).toBeNull(); + expect(result.last_id).toBeNull(); + }); + + test('should handle updateAgentProjects with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const userId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + const result = await updateAgentProjects({ + user: { id: userId.toString() }, + agentId: nonExistentId, + projectIds: [projectId.toString()], + }); + + expect(result).toBeNull(); + }); + }); }); - test('should handle ephemeral agent loading', async () => { - const agentId = 'ephemeral_test'; - const endpoint = 'openai'; + describe('Agent Version History', () => { + let mongoServer; - const originalModule = jest.requireActual('librechat-data-provider'); + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); - const mockDataProvider = { - ...originalModule, - Constants: { - ...originalModule.Constants, - EPHEMERAL_AGENT_ID: 'ephemeral_test', - }, - }; + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); - jest.doMock('librechat-data-provider', () => mockDataProvider); + beforeEach(async () => { + await Agent.deleteMany({}); + }); - const mockReq = { - user: { id: 'user123' }, - body: { - promptPrefix: 'This is a test instruction', - ephemeralAgent: { - execute_code: true, - mcp: ['server1', 'server2'], + test('should create an agent with a single entry in versions array', async () => { + const agent = await createBasicAgent(); + + expect(agent.versions).toBeDefined(); + expect(Array.isArray(agent.versions)).toBe(true); + expect(agent.versions).toHaveLength(1); + expect(agent.versions[0].name).toBe('Test Agent'); + expect(agent.versions[0].provider).toBe('test'); + expect(agent.versions[0].model).toBe('test-model'); + }); + + test('should accumulate version history across multiple updates', async () => { + const agentId = `agent_${uuidv4()}`; + const author = new mongoose.Types.ObjectId(); + await createAgent({ + id: agentId, + name: 'First Name', + provider: 'test', + model: 'test-model', + author, + description: 'First description', + }); + + await updateAgent( + { id: agentId }, + { name: 'Second Name', description: 'Second description' }, + ); + await updateAgent({ id: agentId }, { name: 'Third Name', model: 'new-model' }); + const finalAgent = await updateAgent({ id: agentId }, { description: 'Final description' }); + + expect(finalAgent.versions).toBeDefined(); + expect(Array.isArray(finalAgent.versions)).toBe(true); + expect(finalAgent.versions).toHaveLength(4); + + expect(finalAgent.versions[0].name).toBe('First Name'); + expect(finalAgent.versions[0].description).toBe('First description'); + expect(finalAgent.versions[0].model).toBe('test-model'); + + expect(finalAgent.versions[1].name).toBe('Second Name'); + expect(finalAgent.versions[1].description).toBe('Second description'); + expect(finalAgent.versions[1].model).toBe('test-model'); + + expect(finalAgent.versions[2].name).toBe('Third Name'); + expect(finalAgent.versions[2].description).toBe('Second description'); + expect(finalAgent.versions[2].model).toBe('new-model'); + + expect(finalAgent.versions[3].name).toBe('Third Name'); + expect(finalAgent.versions[3].description).toBe('Final description'); + expect(finalAgent.versions[3].model).toBe('new-model'); + + expect(finalAgent.name).toBe('Third Name'); + expect(finalAgent.description).toBe('Final description'); + expect(finalAgent.model).toBe('new-model'); + }); + + test('should not include metadata fields in version history', async () => { + const agentId = `agent_${uuidv4()}`; + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + + const updatedAgent = await updateAgent({ id: agentId }, { description: 'New description' }); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[0]._id).toBeUndefined(); + expect(updatedAgent.versions[0].__v).toBeUndefined(); + expect(updatedAgent.versions[0].name).toBe('Test Agent'); + expect(updatedAgent.versions[0].author).toBeUndefined(); + + expect(updatedAgent.versions[1]._id).toBeUndefined(); + expect(updatedAgent.versions[1].__v).toBeUndefined(); + }); + + test('should not recursively include previous versions', async () => { + const agentId = `agent_${uuidv4()}`; + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }); + + await updateAgent({ id: agentId }, { name: 'Updated Name 1' }); + await updateAgent({ id: agentId }, { name: 'Updated Name 2' }); + const finalAgent = await updateAgent({ id: agentId }, { name: 'Updated Name 3' }); + + expect(finalAgent.versions).toHaveLength(4); + + finalAgent.versions.forEach((version) => { + expect(version.versions).toBeUndefined(); + }); + }); + + test('should handle MongoDB operators and field updates correctly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'MongoDB Operator Test', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + }); + + await updateAgent( + { id: agentId }, + { + description: 'Updated description', + $push: { tools: 'tool2' }, + $addToSet: { projectIds: projectId }, }, - }, - app: { - locals: { - availableTools: { - tool__server1: {}, - tool__server2: {}, - another_tool: {}, + ); + + const firstUpdate = await getAgent({ id: agentId }); + expect(firstUpdate.description).toBe('Updated description'); + expect(firstUpdate.tools).toContain('tool1'); + expect(firstUpdate.tools).toContain('tool2'); + expect(firstUpdate.projectIds.map((id) => id.toString())).toContain(projectId.toString()); + expect(firstUpdate.versions).toHaveLength(2); + + await updateAgent( + { id: agentId }, + { + tools: ['tool2', 'tool3'], + }, + ); + + const secondUpdate = await getAgent({ id: agentId }); + expect(secondUpdate.tools).toHaveLength(2); + expect(secondUpdate.tools).toContain('tool2'); + expect(secondUpdate.tools).toContain('tool3'); + expect(secondUpdate.tools).not.toContain('tool1'); + expect(secondUpdate.versions).toHaveLength(3); + + await updateAgent( + { id: agentId }, + { + $push: { tools: 'tool3' }, + }, + ); + + const thirdUpdate = await getAgent({ id: agentId }); + const toolCount = thirdUpdate.tools.filter((t) => t === 'tool3').length; + expect(toolCount).toBe(2); + expect(thirdUpdate.versions).toHaveLength(4); + }); + + test('should handle parameter objects correctly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Parameters Test', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: { temperature: 0.7 }, + }); + + const updatedAgent = await updateAgent( + { id: agentId }, + { model_parameters: { temperature: 0.8 } }, + ); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.model_parameters.temperature).toBe(0.8); + + await updateAgent( + { id: agentId }, + { + model_parameters: { + temperature: 0.8, + max_tokens: 1000, }, }, - }, - }; + ); - const params = { - req: mockReq, - agent_id: agentId, - endpoint, - model_parameters: { - model: 'gpt-4', - temperature: 0.7, - }, - }; + const complexAgent = await getAgent({ id: agentId }); + expect(complexAgent.versions).toHaveLength(3); + expect(complexAgent.model_parameters.temperature).toBe(0.8); + expect(complexAgent.model_parameters.max_tokens).toBe(1000); - expect(agentId).toBeDefined(); - expect(endpoint).toBeDefined(); + await updateAgent({ id: agentId }, { model_parameters: {} }); - jest.dontMock('librechat-data-provider'); - }); - - test('should handle loadAgent functionality and errors', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Test Load Agent', - provider: 'test', - model: 'test-model', - author: authorId, - tools: ['tool1', 'tool2'], + const emptyParamsAgent = await getAgent({ id: agentId }); + expect(emptyParamsAgent.versions).toHaveLength(4); + expect(emptyParamsAgent.model_parameters).toEqual({}); }); - const agent = await getAgent({ id: agentId }); + test('should detect duplicate versions and reject updates', async () => { + const originalConsoleError = console.error; + console.error = jest.fn(); - expect(agent).toBeDefined(); - expect(agent.id).toBe(agentId); - expect(agent.name).toBe('Test Load Agent'); - expect(agent.tools).toEqual(expect.arrayContaining(['tool1', 'tool2'])); + try { + const authorId = new mongoose.Types.ObjectId(); + const testCases = generateVersionTestCases(); - const mockLoadAgent = jest.fn().mockResolvedValue(agent); - const loadedAgent = await mockLoadAgent(); - expect(loadedAgent).toBeDefined(); - expect(loadedAgent.id).toBe(agentId); + for (const testCase of testCases) { + const testAgentId = `agent_${uuidv4()}`; - const nonExistentId = `agent_${uuidv4()}`; - const nonExistentAgent = await getAgent({ id: nonExistentId }); - expect(nonExistentAgent).toBeNull(); + await createAgent({ + id: testAgentId, + provider: 'test', + model: 'test-model', + author: authorId, + ...testCase.initial, + }); - const mockLoadAgentError = jest.fn().mockRejectedValue(new Error('No agent found with ID')); - await expect(mockLoadAgentError()).rejects.toThrow('No agent found with ID'); - }); -}); + await updateAgent({ id: testAgentId }, testCase.update); -describe('Agent Version History', () => { - let mongoServer; + let error; + try { + await updateAgent({ id: testAgentId }, testCase.duplicate); + } catch (e) { + error = e; + } - beforeAll(async () => { - mongoServer = await MongoMemoryServer.create(); - const mongoUri = mongoServer.getUri(); - await mongoose.connect(mongoUri); - }); + expect(error).toBeDefined(); + expect(error.message).toContain('Duplicate version'); + expect(error.statusCode).toBe(409); + expect(error.details).toBeDefined(); + expect(error.details.duplicateVersion).toBeDefined(); - afterAll(async () => { - await mongoose.disconnect(); - await mongoServer.stop(); - }); - - beforeEach(async () => { - await Agent.deleteMany({}); - }); - - test('should create an agent with a single entry in versions array', async () => { - const agentId = `agent_${uuidv4()}`; - const agent = await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + const agent = await getAgent({ id: testAgentId }); + expect(agent.versions).toHaveLength(2); + } + } finally { + console.error = originalConsoleError; + } }); - expect(agent.versions).toBeDefined(); - expect(Array.isArray(agent.versions)).toBe(true); - expect(agent.versions).toHaveLength(1); - expect(agent.versions[0].name).toBe('Test Agent'); - expect(agent.versions[0].provider).toBe('test'); - expect(agent.versions[0].model).toBe('test-model'); - }); + test('should track updatedBy when a different user updates an agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const updatingUser = new mongoose.Types.ObjectId(); - test('should accumulate version history across multiple updates', async () => { - const agentId = `agent_${uuidv4()}`; - const author = new mongoose.Types.ObjectId(); - await createAgent({ - id: agentId, - name: 'First Name', - provider: 'test', - model: 'test-model', - author, - description: 'First description', + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); + + const updatedAgent = await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: updatingUser.toString() }, + ); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[1].updatedBy.toString()).toBe(updatingUser.toString()); + expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); }); - await updateAgent({ id: agentId }, { name: 'Second Name', description: 'Second description' }); - await updateAgent({ id: agentId }, { name: 'Third Name', model: 'new-model' }); - const finalAgent = await updateAgent({ id: agentId }, { description: 'Final description' }); + test('should include updatedBy even when the original author updates the agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); - expect(finalAgent.versions).toBeDefined(); - expect(Array.isArray(finalAgent.versions)).toBe(true); - expect(finalAgent.versions).toHaveLength(4); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - expect(finalAgent.versions[0].name).toBe('First Name'); - expect(finalAgent.versions[0].description).toBe('First description'); - expect(finalAgent.versions[0].model).toBe('test-model'); + const updatedAgent = await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: originalAuthor.toString() }, + ); - expect(finalAgent.versions[1].name).toBe('Second Name'); - expect(finalAgent.versions[1].description).toBe('Second description'); - expect(finalAgent.versions[1].model).toBe('test-model'); - - expect(finalAgent.versions[2].name).toBe('Third Name'); - expect(finalAgent.versions[2].description).toBe('Second description'); - expect(finalAgent.versions[2].model).toBe('new-model'); - - expect(finalAgent.versions[3].name).toBe('Third Name'); - expect(finalAgent.versions[3].description).toBe('Final description'); - expect(finalAgent.versions[3].model).toBe('new-model'); - - expect(finalAgent.name).toBe('Third Name'); - expect(finalAgent.description).toBe('Final description'); - expect(finalAgent.model).toBe('new-model'); - }); - - test('should not include metadata fields in version history', async () => { - const agentId = `agent_${uuidv4()}`; - await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.versions[1].updatedBy.toString()).toBe(originalAuthor.toString()); + expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); }); - const updatedAgent = await updateAgent({ id: agentId }, { description: 'New description' }); + test('should track multiple different users updating the same agent', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const user1 = new mongoose.Types.ObjectId(); + const user2 = new mongoose.Types.ObjectId(); + const user3 = new mongoose.Types.ObjectId(); - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[0]._id).toBeUndefined(); - expect(updatedAgent.versions[0].__v).toBeUndefined(); - expect(updatedAgent.versions[0].name).toBe('Test Agent'); - expect(updatedAgent.versions[0].author).toBeUndefined(); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - expect(updatedAgent.versions[1]._id).toBeUndefined(); - expect(updatedAgent.versions[1].__v).toBeUndefined(); - }); + // User 1 makes an update + await updateAgent( + { id: agentId }, + { name: 'Updated by User 1', description: 'First update' }, + { updatingUserId: user1.toString() }, + ); - test('should not recursively include previous versions', async () => { - const agentId = `agent_${uuidv4()}`; - await createAgent({ - id: agentId, - name: 'Test Agent', - provider: 'test', - model: 'test-model', - author: new mongoose.Types.ObjectId(), + // Original author makes an update + await updateAgent( + { id: agentId }, + { description: 'Updated by original author' }, + { updatingUserId: originalAuthor.toString() }, + ); + + // User 2 makes an update + await updateAgent( + { id: agentId }, + { name: 'Updated by User 2', model: 'new-model' }, + { updatingUserId: user2.toString() }, + ); + + // User 3 makes an update + const finalAgent = await updateAgent( + { id: agentId }, + { description: 'Final update by User 3' }, + { updatingUserId: user3.toString() }, + ); + + expect(finalAgent.versions).toHaveLength(5); + expect(finalAgent.author.toString()).toBe(originalAuthor.toString()); + + // Check that each version has the correct updatedBy + expect(finalAgent.versions[0].updatedBy).toBeUndefined(); // Initial creation has no updatedBy + expect(finalAgent.versions[1].updatedBy.toString()).toBe(user1.toString()); + expect(finalAgent.versions[2].updatedBy.toString()).toBe(originalAuthor.toString()); + expect(finalAgent.versions[3].updatedBy.toString()).toBe(user2.toString()); + expect(finalAgent.versions[4].updatedBy.toString()).toBe(user3.toString()); + + // Verify the final state + expect(finalAgent.name).toBe('Updated by User 2'); + expect(finalAgent.description).toBe('Final update by User 3'); + expect(finalAgent.model).toBe('new-model'); }); - await updateAgent({ id: agentId }, { name: 'Updated Name 1' }); - await updateAgent({ id: agentId }, { name: 'Updated Name 2' }); - const finalAgent = await updateAgent({ id: agentId }, { name: 'Updated Name 3' }); + test('should preserve original author during agent restoration', async () => { + const agentId = `agent_${uuidv4()}`; + const originalAuthor = new mongoose.Types.ObjectId(); + const updatingUser = new mongoose.Types.ObjectId(); - expect(finalAgent.versions).toHaveLength(4); + await createAgent({ + id: agentId, + name: 'Original Agent', + provider: 'test', + model: 'test-model', + author: originalAuthor, + description: 'Original description', + }); - finalAgent.versions.forEach((version) => { - expect(version.versions).toBeUndefined(); - }); - }); + await updateAgent( + { id: agentId }, + { name: 'Updated Agent', description: 'Updated description' }, + { updatingUserId: updatingUser.toString() }, + ); - test('should handle MongoDB operators and field updates correctly', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const projectId = new mongoose.Types.ObjectId(); + const { revertAgentVersion } = require('./Agent'); + const revertedAgent = await revertAgentVersion({ id: agentId }, 0); - await createAgent({ - id: agentId, - name: 'MongoDB Operator Test', - provider: 'test', - model: 'test-model', - author: authorId, - tools: ['tool1'], + expect(revertedAgent.author.toString()).toBe(originalAuthor.toString()); + expect(revertedAgent.name).toBe('Original Agent'); + expect(revertedAgent.description).toBe('Original description'); }); - await updateAgent( - { id: agentId }, - { - description: 'Updated description', - $push: { tools: 'tool2' }, - $addToSet: { projectIds: projectId }, - }, - ); + test('should detect action metadata changes and force version update', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const actionId = 'testActionId123'; - const firstUpdate = await getAgent({ id: agentId }); - expect(firstUpdate.description).toBe('Updated description'); - expect(firstUpdate.tools).toContain('tool1'); - expect(firstUpdate.tools).toContain('tool2'); - expect(firstUpdate.projectIds.map((id) => id.toString())).toContain(projectId.toString()); - expect(firstUpdate.versions).toHaveLength(2); + // Create agent with actions + await createAgent({ + id: agentId, + name: 'Agent with Actions', + provider: 'test', + model: 'test-model', + author: authorId, + actions: [`test.com_action_${actionId}`], + tools: ['listEvents_action_test.com', 'createEvent_action_test.com'], + }); - await updateAgent( - { id: agentId }, - { - tools: ['tool2', 'tool3'], - }, - ); + // First update with forceVersion should create a version + const firstUpdate = await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: true }, + ); - const secondUpdate = await getAgent({ id: agentId }); - expect(secondUpdate.tools).toHaveLength(2); - expect(secondUpdate.tools).toContain('tool2'); - expect(secondUpdate.tools).toContain('tool3'); - expect(secondUpdate.tools).not.toContain('tool1'); - expect(secondUpdate.versions).toHaveLength(3); + expect(firstUpdate.versions).toHaveLength(2); - await updateAgent( - { id: agentId }, - { - $push: { tools: 'tool3' }, - }, - ); + // Second update with same data but forceVersion should still create a version + const secondUpdate = await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: true }, + ); - const thirdUpdate = await getAgent({ id: agentId }); - const toolCount = thirdUpdate.tools.filter((t) => t === 'tool3').length; - expect(toolCount).toBe(2); - expect(thirdUpdate.versions).toHaveLength(4); - }); + expect(secondUpdate.versions).toHaveLength(3); - test('should handle parameter objects correctly', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); + // Update without forceVersion and no changes should not create a version + let error; + try { + await updateAgent( + { id: agentId }, + { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, + { updatingUserId: authorId.toString(), forceVersion: false }, + ); + } catch (e) { + error = e; + } - await createAgent({ - id: agentId, - name: 'Parameters Test', - provider: 'test', - model: 'test-model', - author: authorId, - model_parameters: { temperature: 0.7 }, + expect(error).toBeDefined(); + expect(error.message).toContain('Duplicate version'); + expect(error.statusCode).toBe(409); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { model_parameters: { temperature: 0.8 } }, - ); + test('should handle isDuplicateVersion with arrays containing null/undefined values', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.model_parameters.temperature).toBe(0.8); + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1', null, 'tool2', undefined], + }); - await updateAgent( - { id: agentId }, - { - model_parameters: { - temperature: 0.8, - max_tokens: 1000, + // Update with same array but different null/undefined arrangement + const updatedAgent = await updateAgent({ id: agentId }, { tools: ['tool1', 'tool2'] }); + + expect(updatedAgent.versions).toHaveLength(2); + expect(updatedAgent.tools).toEqual(['tool1', 'tool2']); + }); + + test('should handle isDuplicateVersion with empty objects in tool_kwargs', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tool_kwargs: [ + { tool: 'tool1', config: { setting: 'value' } }, + {}, + { tool: 'tool2', config: {} }, + ], + }); + + // Try to update with reordered but equivalent tool_kwargs + const updatedAgent = await updateAgent( + { id: agentId }, + { + tool_kwargs: [ + { tool: 'tool2', config: {} }, + { tool: 'tool1', config: { setting: 'value' } }, + {}, + ], }, - }, - ); + ); - const complexAgent = await getAgent({ id: agentId }); - expect(complexAgent.versions).toHaveLength(3); - expect(complexAgent.model_parameters.temperature).toBe(0.8); - expect(complexAgent.model_parameters.max_tokens).toBe(1000); + // Should create new version as order matters for arrays + expect(updatedAgent.versions).toHaveLength(2); + }); - await updateAgent({ id: agentId }, { model_parameters: {} }); + test('should handle isDuplicateVersion with mixed primitive and object arrays', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - const emptyParamsAgent = await getAgent({ id: agentId }); - expect(emptyParamsAgent.versions).toHaveLength(4); - expect(emptyParamsAgent.model_parameters).toEqual({}); + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + mixed_array: [1, 'string', { key: 'value' }, true, null], + }); + + // Update with same values but different types + const updatedAgent = await updateAgent( + { id: agentId }, + { mixed_array: ['1', 'string', { key: 'value' }, 'true', null] }, + ); + + // Should create new version as types differ + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle isDuplicateVersion with deeply nested objects', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const deepObject = { + level1: { + level2: { + level3: { + level4: { + value: 'deep', + array: [1, 2, { nested: true }], + }, + }, + }, + }, + }; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: deepObject, + }); + + // First create a version with changes + await updateAgent({ id: agentId }, { description: 'Updated' }); + + // Then try to create duplicate of the original version + await updateAgent( + { id: agentId }, + { + model_parameters: deepObject, + description: undefined, + }, + ); + + // Since we're updating back to the same model_parameters but with a different description, + // it should create a new version + const agent = await getAgent({ id: agentId }); + expect(agent.versions).toHaveLength(3); + }); + + test('should handle version comparison with special field types', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + projectIds: [projectId], + model_parameters: { temperature: 0.7 }, + }); + + // Update with a real field change first + const firstUpdate = await updateAgent({ id: agentId }, { description: 'New description' }); + + expect(firstUpdate.versions).toHaveLength(2); + + // Update with model parameters change + const secondUpdate = await updateAgent( + { id: agentId }, + { model_parameters: { temperature: 0.8 } }, + ); + + expect(secondUpdate.versions).toHaveLength(3); + }); + + describe('Edge Cases', () => { + test('should handle extremely large version history', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Version Test', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + for (let i = 0; i < 20; i++) { + await updateAgent({ id: agentId }, { description: `Version ${i}` }); + } + + const agent = await getAgent({ id: agentId }); + expect(agent.versions).toHaveLength(21); + expect(agent.description).toBe('Version 19'); + }); + + test('should handle revertAgentVersion with invalid version index', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await expect(revertAgentVersion({ id: agentId }, 5)).rejects.toThrow('Version 5 not found'); + }); + + test('should handle revertAgentVersion with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect(revertAgentVersion({ id: nonExistentId }, 0)).rejects.toThrow( + 'Agent not found', + ); + }); + + test('should handle updateAgent with empty update object', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + const updatedAgent = await updateAgent({ id: agentId }, {}); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Test Agent'); + expect(updatedAgent.versions).toHaveLength(1); + }); + }); }); - test('should detect duplicate versions and reject updates', async () => { - const originalConsoleError = console.error; - console.error = jest.fn(); + describe('Action Metadata and Hash Generation', () => { + let mongoServer; - try { + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should generate consistent hash for same action metadata', async () => { + const actionIds = ['test.com_action_123', 'example.com_action_456']; + const actions = [ + { + action_id: '123', + metadata: { version: '1.0', endpoints: ['GET /api/test'], schema: { type: 'object' } }, + }, + { + action_id: '456', + metadata: { + version: '2.0', + endpoints: ['POST /api/example'], + schema: { type: 'string' }, + }, + }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds, actions); + const hash2 = await generateActionMetadataHash(actionIds, actions); + + expect(hash1).toBe(hash2); + expect(typeof hash1).toBe('string'); + expect(hash1.length).toBe(64); // SHA-256 produces 64 character hex string + }); + + test('should generate different hashes for different action metadata', async () => { + const actionIds = ['test.com_action_123']; + const actions1 = [ + { action_id: '123', metadata: { version: '1.0', endpoints: ['GET /api/test'] } }, + ]; + const actions2 = [ + { action_id: '123', metadata: { version: '2.0', endpoints: ['GET /api/test'] } }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds, actions1); + const hash2 = await generateActionMetadataHash(actionIds, actions2); + + expect(hash1).not.toBe(hash2); + }); + + test('should handle empty action arrays', async () => { + const hash = await generateActionMetadataHash([], []); + expect(hash).toBe(''); + }); + + test('should handle null or undefined action arrays', async () => { + const hash1 = await generateActionMetadataHash(null, []); + const hash2 = await generateActionMetadataHash(undefined, []); + + expect(hash1).toBe(''); + expect(hash2).toBe(''); + }); + + test('should handle missing action metadata gracefully', async () => { + const actionIds = ['test.com_action_123', 'missing.com_action_999']; + const actions = [ + { action_id: '123', metadata: { version: '1.0' } }, + // missing action with id '999' + ]; + + const hash = await generateActionMetadataHash(actionIds, actions); + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + test('should sort action IDs for consistent hashing', async () => { + const actionIds1 = ['b.com_action_2', 'a.com_action_1']; + const actionIds2 = ['a.com_action_1', 'b.com_action_2']; + const actions = [ + { action_id: '1', metadata: { version: '1.0' } }, + { action_id: '2', metadata: { version: '2.0' } }, + ]; + + const hash1 = await generateActionMetadataHash(actionIds1, actions); + const hash2 = await generateActionMetadataHash(actionIds2, actions); + + expect(hash1).toBe(hash2); + }); + + test('should handle complex nested metadata objects', async () => { + const actionIds = ['complex.com_action_1']; + const actions = [ + { + action_id: '1', + metadata: { + version: '1.0', + schema: { + type: 'object', + properties: { + name: { type: 'string' }, + nested: { + type: 'object', + properties: { + id: { type: 'number' }, + tags: { type: 'array', items: { type: 'string' } }, + }, + }, + }, + }, + endpoints: [ + { path: '/api/test', method: 'GET', params: ['id'] }, + { path: '/api/create', method: 'POST', body: true }, + ], + }, + }, + ]; + + const hash = await generateActionMetadataHash(actionIds, actions); + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + describe('Edge Cases', () => { + test('should handle generateActionMetadataHash with null metadata', async () => { + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: null }], + ); + expect(typeof hash).toBe('string'); + }); + + test('should handle generateActionMetadataHash with deeply nested metadata', async () => { + const deepMetadata = { + level1: { + level2: { + level3: { + level4: { + level5: 'deep value', + array: [1, 2, { nested: true }], + }, + }, + }, + }, + }; + + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: deepMetadata }], + ); + + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + + test('should handle generateActionMetadataHash with special characters', async () => { + const specialMetadata = { + unicode: '🚀🎉👍', + symbols: '!@#$%^&*()_+-=[]{}|;:,.<>?', + quotes: 'single\'s and "doubles"', + newlines: 'line1\nline2\r\nline3', + }; + + const hash = await generateActionMetadataHash( + ['test.com_action_1'], + [{ action_id: '1', metadata: specialMetadata }], + ); + + expect(typeof hash).toBe('string'); + expect(hash.length).toBe(64); + }); + }); + }); + + describe('Load Agent Functionality', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should return null when agent_id is not provided', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: null, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should return null when agent_id is empty string', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: '', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should test ephemeral agent loading logic', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({ + tool1_mcp_server1: {}, + tool2_mcp_server2: {}, + another_tool: {}, + }); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test instructions', + ephemeralAgent: { + execute_code: true, + web_search: true, + mcp: ['server1', 'server2'], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4', temperature: 0.7 }, + }); + + if (result) { + expect(result.id).toBe(EPHEMERAL_AGENT_ID); + expect(result.instructions).toBe('Test instructions'); + expect(result.provider).toBe('openai'); + expect(result.model).toBe('gpt-4'); + expect(result.model_parameters.temperature).toBe(0.7); + expect(result.tools).toContain('execute_code'); + expect(result.tools).toContain('web_search'); + expect(result.tools).toContain('tool1_mcp_server1'); + expect(result.tools).toContain('tool2_mcp_server2'); + } else { + expect(result).toBeNull(); + } + }); + + test('should return null for non-existent agent', async () => { + const mockReq = { user: { id: 'user123' } }; + const result = await loadAgent({ + req: mockReq, + agent_id: 'non_existent_agent', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should load agent when user is the author', async () => { + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: userId, + description: 'Test description', + tools: ['web_search'], + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeDefined(); + expect(result.id).toBe(agentId); + expect(result.name).toBe('Test Agent'); + expect(result.author.toString()).toBe(userId.toString()); + expect(result.version).toBe(1); + }); + + test('should return null when user is not author and agent has no projectIds', async () => { + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeFalsy(); + }); + + test('should handle ephemeral agent with no MCP servers', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({}); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Simple instructions', + ephemeralAgent: { + execute_code: false, + web_search: false, + mcp: [], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-3.5-turbo' }, + }); + + if (result) { + expect(result.tools).toEqual([]); + expect(result.instructions).toBe('Simple instructions'); + } else { + expect(result).toBeFalsy(); + } + }); + + test('should handle ephemeral agent with undefined ephemeralAgent in body', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({}); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Basic instructions', + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools).toEqual([]); + } else { + expect(result).toBeFalsy(); + } + }); + + describe('Edge Cases', () => { + test('should handle loadAgent with malformed req object', async () => { + const result = await loadAgent({ + req: null, + agent_id: 'test', + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeNull(); + }); + + test('should handle ephemeral agent with extremely large tool list', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + const largeToolList = Array.from({ length: 100 }, (_, i) => `tool_${i}_mcp_server1`); + const availableTools = largeToolList.reduce((acc, tool) => { + acc[tool] = {}; + return acc; + }, {}); + + getCachedTools.mockResolvedValue(availableTools); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test', + ephemeralAgent: { + execute_code: true, + web_search: true, + mcp: ['server1'], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools.length).toBeGreaterThan(100); + } + }); + + test('should handle loadAgent with agent from different project', async () => { + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + const agentId = `agent_${uuidv4()}`; + const projectId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Project Agent', + provider: 'openai', + model: 'gpt-4', + author: authorId, + projectIds: [projectId], + }); + + const mockReq = { user: { id: userId.toString() } }; + const result = await loadAgent({ + req: mockReq, + agent_id: agentId, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + expect(result).toBeFalsy(); + }); + }); + }); + + describe('Agent Edge Cases and Error Handling', () => { + let mongoServer; + + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); + }); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + await Agent.deleteMany({}); + }); + + test('should handle agent creation with minimal required fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + provider: 'test', + model: 'test-model', + author: authorId, + }); + + expect(agent).toBeDefined(); + expect(agent.id).toBe(agentId); + expect(agent.versions).toHaveLength(1); + expect(agent.versions[0].provider).toBe('test'); + expect(agent.versions[0].model).toBe('test-model'); + }); + + test('should handle agent creation with all optional fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + name: 'Complex Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Complex description', + instructions: 'Complex instructions', + tools: ['tool1', 'tool2'], + actions: ['action1', 'action2'], + model_parameters: { temperature: 0.8, max_tokens: 1000 }, + projectIds: [projectId], + avatar: 'https://example.com/avatar.png', + isCollaborative: true, + tool_resources: { + file_search: { file_ids: ['file1', 'file2'] }, + }, + }); + + expect(agent).toBeDefined(); + expect(agent.name).toBe('Complex Agent'); + expect(agent.description).toBe('Complex description'); + expect(agent.instructions).toBe('Complex instructions'); + expect(agent.tools).toEqual(['tool1', 'tool2']); + expect(agent.actions).toEqual(['action1', 'action2']); + expect(agent.model_parameters.temperature).toBe(0.8); + expect(agent.model_parameters.max_tokens).toBe(1000); + expect(agent.projectIds.map((id) => id.toString())).toContain(projectId.toString()); + expect(agent.avatar).toBe('https://example.com/avatar.png'); + expect(agent.isCollaborative).toBe(true); + expect(agent.tool_resources.file_search.file_ids).toEqual(['file1', 'file2']); + }); + + test('should handle updateAgent with empty update object', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + const updatedAgent = await updateAgent({ id: agentId }, {}); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Test Agent'); + expect(updatedAgent.versions).toHaveLength(1); // No new version should be created + }); + + test('should handle concurrent updates to different agents', async () => { + const agent1Id = `agent_${uuidv4()}`; + const agent2Id = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agent1Id, + name: 'Agent 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agent2Id, + name: 'Agent 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Concurrent updates to different agents + const [updated1, updated2] = await Promise.all([ + updateAgent({ id: agent1Id }, { description: 'Updated Agent 1' }), + updateAgent({ id: agent2Id }, { description: 'Updated Agent 2' }), + ]); + + expect(updated1.description).toBe('Updated Agent 1'); + expect(updated2.description).toBe('Updated Agent 2'); + expect(updated1.versions).toHaveLength(2); + expect(updated2.versions).toHaveLength(2); + }); + + test('should handle agent deletion with non-existent ID', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const result = await deleteAgent({ id: nonExistentId }); + + expect(result).toBeNull(); + }); + + test('should handle getListAgents with no agents', async () => { + const authorId = new mongoose.Types.ObjectId(); + const result = await getListAgents({ author: authorId.toString() }); + + expect(result).toBeDefined(); + expect(result.data).toEqual([]); + expect(result.has_more).toBe(false); + expect(result.first_id).toBeNull(); + expect(result.last_id).toBeNull(); + }); + + test('should handle updateAgent with MongoDB operators mixed with direct updates', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + }); + + // Test with $push and direct field update + const updatedAgent = await updateAgent( + { id: agentId }, + { + name: 'Updated Name', + $push: { tools: 'tool2' }, + }, + ); + + expect(updatedAgent.name).toBe('Updated Name'); + expect(updatedAgent.tools).toContain('tool1'); + expect(updatedAgent.tools).toContain('tool2'); + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle revertAgentVersion with invalid version index', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Try to revert to non-existent version + await expect(revertAgentVersion({ id: agentId }, 5)).rejects.toThrow('Version 5 not found'); + }); + + test('should handle revertAgentVersion with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect(revertAgentVersion({ id: nonExistentId }, 0)).rejects.toThrow('Agent not found'); + }); + + test('should handle addAgentResourceFile with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const mockReq = { user: { id: 'user123' } }; + + await expect( + addAgentResourceFile({ + req: mockReq, + agent_id: nonExistentId, + tool_resource: 'file_search', + file_id: 'file123', + }), + ).rejects.toThrow('Agent not found for adding resource file'); + }); + + test('should handle removeAgentResourceFiles with non-existent agent', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + await expect( + removeAgentResourceFiles({ + agent_id: nonExistentId, + files: [{ tool_resource: 'file_search', file_id: 'file123' }], + }), + ).rejects.toThrow('Agent not found for removing resource files'); + }); + + test('should handle updateAgent with complex nested updates', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + model_parameters: { temperature: 0.5 }, + tools: ['tool1'], + }); + + // First update with $push operation + const firstUpdate = await updateAgent( + { id: agentId }, + { + $push: { tools: 'tool2' }, + }, + ); + + expect(firstUpdate.tools).toContain('tool1'); + expect(firstUpdate.tools).toContain('tool2'); + + // Second update with direct field update and $addToSet + const secondUpdate = await updateAgent( + { id: agentId }, + { + name: 'Updated Agent', + model_parameters: { temperature: 0.8, max_tokens: 500 }, + $addToSet: { tools: 'tool3' }, + }, + ); + + expect(secondUpdate.name).toBe('Updated Agent'); + expect(secondUpdate.model_parameters.temperature).toBe(0.8); + expect(secondUpdate.model_parameters.max_tokens).toBe(500); + expect(secondUpdate.tools).toContain('tool1'); + expect(secondUpdate.tools).toContain('tool2'); + expect(secondUpdate.tools).toContain('tool3'); + expect(secondUpdate.versions).toHaveLength(3); + }); + + test('should preserve version order in versions array', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Version 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await updateAgent({ id: agentId }, { name: 'Version 2' }); + await updateAgent({ id: agentId }, { name: 'Version 3' }); + const finalAgent = await updateAgent({ id: agentId }, { name: 'Version 4' }); + + expect(finalAgent.versions).toHaveLength(4); + expect(finalAgent.versions[0].name).toBe('Version 1'); + expect(finalAgent.versions[1].name).toBe('Version 2'); + expect(finalAgent.versions[2].name).toBe('Version 3'); + expect(finalAgent.versions[3].name).toBe('Version 4'); + expect(finalAgent.name).toBe('Version 4'); + }); + + test('should handle updateAgentProjects error scenarios', async () => { + const nonExistentId = `agent_${uuidv4()}`; + const userId = new mongoose.Types.ObjectId(); + const projectId = new mongoose.Types.ObjectId(); + + // Test with non-existent agent + const result = await updateAgentProjects({ + user: { id: userId.toString() }, + agentId: nonExistentId, + projectIds: [projectId.toString()], + }); + + expect(result).toBeNull(); + }); + + test('should handle revertAgentVersion properly', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Original Name', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Original description', + }); + + await updateAgent( + { id: agentId }, + { name: 'Updated Name', description: 'Updated description' }, + ); + + const revertedAgent = await revertAgentVersion({ id: agentId }, 0); + + expect(revertedAgent.name).toBe('Original Name'); + expect(revertedAgent.description).toBe('Original description'); + expect(revertedAgent.author.toString()).toBe(authorId.toString()); + }); + + test('should handle action-related updates with getActions error', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + // Create agent with actions that might cause getActions to fail + await createAgent({ + id: agentId, + name: 'Agent with Actions', + provider: 'test', + model: 'test-model', + author: authorId, + actions: ['test.com_action_invalid_id'], + }); + + // Update should still work even if getActions fails + const updatedAgent = await updateAgent( + { id: agentId }, + { description: 'Updated description' }, + ); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.description).toBe('Updated description'); + expect(updatedAgent.versions).toHaveLength(2); + }); + + test('should handle updateAgent with combined MongoDB operators', async () => { const agentId = `agent_${uuidv4()}`; const authorId = new mongoose.Types.ObjectId(); const projectId1 = new mongoose.Types.ObjectId(); const projectId2 = new mongoose.Types.ObjectId(); - const testCases = [ + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tools: ['tool1'], + projectIds: [projectId1], + }); + + // Use multiple operators in single update - but avoid conflicting operations on same field + const updatedAgent = await updateAgent( + { id: agentId }, { - name: 'simple field update', - initial: { - name: 'Test Agent', - description: 'Initial description', - }, - update: { name: 'Updated Name' }, - duplicate: { name: 'Updated Name' }, + name: 'Updated Name', + $push: { tools: 'tool2' }, + $addToSet: { projectIds: projectId2 }, }, + ); + + const finalAgent = await updateAgent( + { id: agentId }, { - name: 'object field update', - initial: { - model_parameters: { temperature: 0.7 }, - }, - update: { model_parameters: { temperature: 0.8 } }, - duplicate: { model_parameters: { temperature: 0.8 } }, - }, - { - name: 'array field update', - initial: { - tools: ['tool1', 'tool2'], - }, - update: { tools: ['tool2', 'tool3'] }, - duplicate: { tools: ['tool2', 'tool3'] }, - }, - { - name: 'projectIds update', - initial: { - projectIds: [projectId1], - }, - update: { projectIds: [projectId1, projectId2] }, - duplicate: { projectIds: [projectId2, projectId1] }, + $pull: { projectIds: projectId1 }, }, + ); + + expect(updatedAgent).toBeDefined(); + expect(updatedAgent.name).toBe('Updated Name'); + expect(updatedAgent.tools).toContain('tool1'); + expect(updatedAgent.tools).toContain('tool2'); + expect(updatedAgent.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + expect(finalAgent).toBeDefined(); + expect(finalAgent.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + expect(finalAgent.versions).toHaveLength(3); + }); + + test('should handle updateAgent when agent does not exist', async () => { + const nonExistentId = `agent_${uuidv4()}`; + + const result = await updateAgent({ id: nonExistentId }, { name: 'New Name' }); + + expect(result).toBeNull(); + }); + + test('should handle concurrent updates with database errors', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Mock findOneAndUpdate to simulate database error + const cleanup = mockFindOneAndUpdateError(2); + + // Concurrent updates where one fails + const promises = [ + updateAgent({ id: agentId }, { name: 'Update 1' }), + updateAgent({ id: agentId }, { name: 'Update 2' }), + updateAgent({ id: agentId }, { name: 'Update 3' }), ]; - for (const testCase of testCases) { - const testAgentId = `agent_${uuidv4()}`; + const results = await Promise.allSettled(promises); - await createAgent({ - id: testAgentId, - provider: 'test', - model: 'test-model', - author: authorId, - ...testCase.initial, + cleanup(); + + const succeeded = results.filter((r) => r.status === 'fulfilled').length; + const failed = results.filter((r) => r.status === 'rejected').length; + + expect(succeeded).toBe(2); + expect(failed).toBe(1); + }); + + test('should handle removeAgentResourceFiles when agent is deleted during operation', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + tool_resources: { + file_search: { + file_ids: ['file1', 'file2', 'file3'], + }, + }, + }); + + // Mock findOneAndUpdate to return null (simulating deletion) + const originalFindOneAndUpdate = Agent.findOneAndUpdate; + Agent.findOneAndUpdate = jest.fn().mockImplementation(() => ({ + lean: jest.fn().mockResolvedValue(null), + })); + + // Try to remove files from deleted agent + await expect( + removeAgentResourceFiles({ + agent_id: agentId, + files: [ + { tool_resource: 'file_search', file_id: 'file1' }, + { tool_resource: 'file_search', file_id: 'file2' }, + ], + }), + ).rejects.toThrow('Failed to update agent during file removal (pull step)'); + + Agent.findOneAndUpdate = originalFindOneAndUpdate; + }); + + test('should handle loadEphemeralAgent with malformed MCP tool names', async () => { + const { EPHEMERAL_AGENT_ID } = require('librechat-data-provider').Constants; + + getCachedTools.mockResolvedValue({ + malformed_tool_name: {}, // No mcp delimiter + tool__server1: {}, // Wrong delimiter + tool_mcp_server1: {}, // Correct format + tool_mcp_server2: {}, // Different server + }); + + const mockReq = { + user: { id: 'user123' }, + body: { + promptPrefix: 'Test instructions', + ephemeralAgent: { + execute_code: false, + web_search: false, + mcp: ['server1'], + }, + }, + }; + + const result = await loadAgent({ + req: mockReq, + agent_id: EPHEMERAL_AGENT_ID, + endpoint: 'openai', + model_parameters: { model: 'gpt-4' }, + }); + + if (result) { + expect(result.tools).toEqual(['tool_mcp_server1']); + expect(result.tools).not.toContain('malformed_tool_name'); + expect(result.tools).not.toContain('tool__server1'); + expect(result.tools).not.toContain('tool_mcp_server2'); + } + }); + + test('should handle addAgentResourceFile when array initialization fails', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Mock the updateOne operation to fail but let updateAgent succeed + const originalUpdateOne = Agent.updateOne; + let updateOneCalled = false; + Agent.updateOne = jest.fn().mockImplementation((...args) => { + if (!updateOneCalled) { + updateOneCalled = true; + return Promise.reject(new Error('Database error')); + } + return originalUpdateOne.apply(Agent, args); + }); + + try { + const result = await addAgentResourceFile({ + agent_id: agentId, + tool_resource: 'new_tool', + file_id: 'file123', }); - await updateAgent({ id: testAgentId }, testCase.update); - - let error; - try { - await updateAgent({ id: testAgentId }, testCase.duplicate); - } catch (e) { - error = e; - } - - expect(error).toBeDefined(); - expect(error.message).toContain('Duplicate version'); - expect(error.statusCode).toBe(409); - expect(error.details).toBeDefined(); - expect(error.details.duplicateVersion).toBeDefined(); - - const agent = await getAgent({ id: testAgentId }); - expect(agent.versions).toHaveLength(2); + expect(result).toBeDefined(); + expect(result.tools).toContain('new_tool'); + } catch (error) { + expect(error.message).toBe('Database error'); } - } finally { - console.error = originalConsoleError; - } + + Agent.updateOne = originalUpdateOne; + }); }); - test('should track updatedBy when a different user updates an agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const updatingUser = new mongoose.Types.ObjectId(); + describe('Agent IDs Field in Version Detection', () => { + let mongoServer; - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + beforeAll(async () => { + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema); + await mongoose.connect(mongoUri); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - { updatingUserId: updatingUser.toString() }, - ); - - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[1].updatedBy.toString()).toBe(updatingUser.toString()); - expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); - }); - - test('should include updatedBy even when the original author updates the agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); }); - const updatedAgent = await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - { updatingUserId: originalAuthor.toString() }, - ); - - expect(updatedAgent.versions).toHaveLength(2); - expect(updatedAgent.versions[1].updatedBy.toString()).toBe(originalAuthor.toString()); - expect(updatedAgent.author.toString()).toBe(originalAuthor.toString()); - }); - - test('should track multiple different users updating the same agent', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const user1 = new mongoose.Types.ObjectId(); - const user2 = new mongoose.Types.ObjectId(); - const user3 = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', + beforeEach(async () => { + await Agent.deleteMany({}); }); - // User 1 makes an update - await updateAgent( - { id: agentId }, - { name: 'Updated by User 1', description: 'First update' }, - { updatingUserId: user1.toString() }, - ); + test('should now create new version when agent_ids field changes', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); - // Original author makes an update - await updateAgent( - { id: agentId }, - { description: 'Updated by original author' }, - { updatingUserId: originalAuthor.toString() }, - ); + const agent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); - // User 2 makes an update - await updateAgent( - { id: agentId }, - { name: 'Updated by User 2', model: 'new-model' }, - { updatingUserId: user2.toString() }, - ); + expect(agent).toBeDefined(); + expect(agent.versions).toHaveLength(1); - // User 3 makes an update - const finalAgent = await updateAgent( - { id: agentId }, - { description: 'Final update by User 3' }, - { updatingUserId: user3.toString() }, - ); - - expect(finalAgent.versions).toHaveLength(5); - expect(finalAgent.author.toString()).toBe(originalAuthor.toString()); - - // Check that each version has the correct updatedBy - expect(finalAgent.versions[0].updatedBy).toBeUndefined(); // Initial creation has no updatedBy - expect(finalAgent.versions[1].updatedBy.toString()).toBe(user1.toString()); - expect(finalAgent.versions[2].updatedBy.toString()).toBe(originalAuthor.toString()); - expect(finalAgent.versions[3].updatedBy.toString()).toBe(user2.toString()); - expect(finalAgent.versions[4].updatedBy.toString()).toBe(user3.toString()); - - // Verify the final state - expect(finalAgent.name).toBe('Updated by User 2'); - expect(finalAgent.description).toBe('Final update by User 3'); - expect(finalAgent.model).toBe('new-model'); - }); - - test('should preserve original author during agent restoration', async () => { - const agentId = `agent_${uuidv4()}`; - const originalAuthor = new mongoose.Types.ObjectId(); - const updatingUser = new mongoose.Types.ObjectId(); - - await createAgent({ - id: agentId, - name: 'Original Agent', - provider: 'test', - model: 'test-model', - author: originalAuthor, - description: 'Original description', - }); - - await updateAgent( - { id: agentId }, - { name: 'Updated Agent', description: 'Updated description' }, - { updatingUserId: updatingUser.toString() }, - ); - - const { revertAgentVersion } = require('./Agent'); - const revertedAgent = await revertAgentVersion({ id: agentId }, 0); - - expect(revertedAgent.author.toString()).toBe(originalAuthor.toString()); - expect(revertedAgent.name).toBe('Original Agent'); - expect(revertedAgent.description).toBe('Original description'); - }); - - test('should detect action metadata changes and force version update', async () => { - const agentId = `agent_${uuidv4()}`; - const authorId = new mongoose.Types.ObjectId(); - const actionId = 'testActionId123'; - - // Create agent with actions - await createAgent({ - id: agentId, - name: 'Agent with Actions', - provider: 'test', - model: 'test-model', - author: authorId, - actions: [`test.com_action_${actionId}`], - tools: ['listEvents_action_test.com', 'createEvent_action_test.com'], - }); - - // First update with forceVersion should create a version - const firstUpdate = await updateAgent( - { id: agentId }, - { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, - { updatingUserId: authorId.toString(), forceVersion: true }, - ); - - expect(firstUpdate.versions).toHaveLength(2); - - // Second update with same data but forceVersion should still create a version - const secondUpdate = await updateAgent( - { id: agentId }, - { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, - { updatingUserId: authorId.toString(), forceVersion: true }, - ); - - expect(secondUpdate.versions).toHaveLength(3); - - // Update without forceVersion and no changes should not create a version - let error; - try { - await updateAgent( + const updated = await updateAgent( { id: agentId }, - { tools: ['listEvents_action_test.com', 'createEvent_action_test.com'] }, - { updatingUserId: authorId.toString(), forceVersion: false }, + { agent_ids: ['agent1', 'agent2', 'agent3'] }, ); - } catch (e) { - error = e; - } - expect(error).toBeDefined(); - expect(error.message).toContain('Duplicate version'); - expect(error.statusCode).toBe(409); + // Since agent_ids is no longer excluded, this should create a new version + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1', 'agent2', 'agent3']); + }); + + test('should detect duplicate version if agent_ids is updated to same value', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] }); + + await expect( + updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2', 'agent3'] }), + ).rejects.toThrow('Duplicate version'); + }); + + test('should handle agent_ids field alongside other fields', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + description: 'Initial description', + agent_ids: ['agent1'], + }); + + const updated = await updateAgent( + { id: agentId }, + { + agent_ids: ['agent1', 'agent2'], + description: 'Updated description', + }, + ); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1', 'agent2']); + expect(updated.description).toBe('Updated description'); + + const updated2 = await updateAgent({ id: agentId }, { description: 'Another description' }); + + expect(updated2.versions).toHaveLength(3); + expect(updated2.agent_ids).toEqual(['agent1', 'agent2']); + expect(updated2.description).toBe('Another description'); + }); + + test('should skip version creation when skipVersioning option is used', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + + // Create agent with initial projectIds + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + projectIds: [projectId1], + }); + + // Share agent using updateAgentProjects (which uses skipVersioning) + const shared = await updateAgentProjects({ + user: { id: authorId.toString() }, // Use the same author ID + agentId: agentId, + projectIds: [projectId2.toString()], + }); + + // Should NOT create a new version due to skipVersioning + expect(shared.versions).toHaveLength(1); + expect(shared.projectIds.map((id) => id.toString())).toContain(projectId1.toString()); + expect(shared.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + // Unshare agent using updateAgentProjects + const unshared = await updateAgentProjects({ + user: { id: authorId.toString() }, + agentId: agentId, + removeProjectIds: [projectId1.toString()], + }); + + // Still should NOT create a new version + expect(unshared.versions).toHaveLength(1); + expect(unshared.projectIds.map((id) => id.toString())).not.toContain(projectId1.toString()); + expect(unshared.projectIds.map((id) => id.toString())).toContain(projectId2.toString()); + + // Regular update without skipVersioning should create a version + const regularUpdate = await updateAgent( + { id: agentId }, + { description: 'Updated description' }, + ); + + expect(regularUpdate.versions).toHaveLength(2); + expect(regularUpdate.description).toBe('Updated description'); + + // Direct updateAgent with MongoDB operators should still create versions + const directUpdate = await updateAgent( + { id: agentId }, + { $addToSet: { projectIds: { $each: [projectId1] } } }, + ); + + expect(directUpdate.versions).toHaveLength(3); + expect(directUpdate.projectIds.length).toBe(2); + }); + + test('should preserve agent_ids in version history', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1'], + }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent1', 'agent2'] }); + + await updateAgent({ id: agentId }, { agent_ids: ['agent3'] }); + + const finalAgent = await getAgent({ id: agentId }); + + expect(finalAgent.versions).toHaveLength(3); + expect(finalAgent.versions[0].agent_ids).toEqual(['agent1']); + expect(finalAgent.versions[1].agent_ids).toEqual(['agent1', 'agent2']); + expect(finalAgent.versions[2].agent_ids).toEqual(['agent3']); + expect(finalAgent.agent_ids).toEqual(['agent3']); + }); + + test('should handle empty agent_ids arrays', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + agent_ids: ['agent1', 'agent2'], + }); + + const updated = await updateAgent({ id: agentId }, { agent_ids: [] }); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual([]); + + await expect(updateAgent({ id: agentId }, { agent_ids: [] })).rejects.toThrow( + 'Duplicate version', + ); + }); + + test('should handle agent without agent_ids field', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + const agent = await createAgent({ + id: agentId, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + expect(agent.agent_ids).toEqual([]); + + const updated = await updateAgent({ id: agentId }, { agent_ids: ['agent1'] }); + + expect(updated.versions).toHaveLength(2); + expect(updated.agent_ids).toEqual(['agent1']); + }); }); }); + +function createBasicAgent(overrides = {}) { + const defaults = { + id: `agent_${uuidv4()}`, + name: 'Test Agent', + provider: 'test', + model: 'test-model', + author: new mongoose.Types.ObjectId(), + }; + return createAgent({ ...defaults, ...overrides }); +} + +function createTestIds() { + return { + agentId: `agent_${uuidv4()}`, + authorId: new mongoose.Types.ObjectId(), + projectId: new mongoose.Types.ObjectId(), + fileId: uuidv4(), + }; +} + +function createFileOperations(agentId, fileIds, operation = 'add') { + return fileIds.map((fileId) => + operation === 'add' + ? addAgentResourceFile({ agent_id: agentId, tool_resource: 'test_tool', file_id: fileId }) + : removeAgentResourceFiles({ + agent_id: agentId, + files: [{ tool_resource: 'test_tool', file_id: fileId }], + }), + ); +} + +function mockFindOneAndUpdateError(errorOnCall = 1) { + const original = Agent.findOneAndUpdate; + let callCount = 0; + + Agent.findOneAndUpdate = jest.fn().mockImplementation((...args) => { + callCount++; + if (callCount === errorOnCall) { + throw new Error('Database connection lost'); + } + return original.apply(Agent, args); + }); + + return () => { + Agent.findOneAndUpdate = original; + }; +} + +function generateVersionTestCases() { + const projectId1 = new mongoose.Types.ObjectId(); + const projectId2 = new mongoose.Types.ObjectId(); + + return [ + { + name: 'simple field update', + initial: { + name: 'Test Agent', + description: 'Initial description', + }, + update: { name: 'Updated Name' }, + duplicate: { name: 'Updated Name' }, + }, + { + name: 'object field update', + initial: { + model_parameters: { temperature: 0.7 }, + }, + update: { model_parameters: { temperature: 0.8 } }, + duplicate: { model_parameters: { temperature: 0.8 } }, + }, + { + name: 'array field update', + initial: { + tools: ['tool1', 'tool2'], + }, + update: { tools: ['tool2', 'tool3'] }, + duplicate: { tools: ['tool2', 'tool3'] }, + }, + { + name: 'projectIds update', + initial: { + projectIds: [projectId1], + }, + update: { projectIds: [projectId1, projectId2] }, + duplicate: { projectIds: [projectId2, projectId1] }, + }, + ]; +} diff --git a/api/models/Share.js b/api/models/Share.js deleted file mode 100644 index f8712c36ac..0000000000 --- a/api/models/Share.js +++ /dev/null @@ -1,346 +0,0 @@ -const { nanoid } = require('nanoid'); -const { logger } = require('@librechat/data-schemas'); -const { Constants } = require('librechat-data-provider'); -const { Conversation, SharedLink } = require('~/db/models'); -const { getMessages } = require('./Message'); - -class ShareServiceError extends Error { - constructor(message, code) { - super(message); - this.name = 'ShareServiceError'; - this.code = code; - } -} - -const memoizedAnonymizeId = (prefix) => { - const memo = new Map(); - return (id) => { - if (!memo.has(id)) { - memo.set(id, `${prefix}_${nanoid()}`); - } - return memo.get(id); - }; -}; - -const anonymizeConvoId = memoizedAnonymizeId('convo'); -const anonymizeAssistantId = memoizedAnonymizeId('a'); -const anonymizeMessageId = (id) => - id === Constants.NO_PARENT ? id : memoizedAnonymizeId('msg')(id); - -function anonymizeConvo(conversation) { - if (!conversation) { - return null; - } - - const newConvo = { ...conversation }; - if (newConvo.assistant_id) { - newConvo.assistant_id = anonymizeAssistantId(newConvo.assistant_id); - } - return newConvo; -} - -function anonymizeMessages(messages, newConvoId) { - if (!Array.isArray(messages)) { - return []; - } - - const idMap = new Map(); - return messages.map((message) => { - const newMessageId = anonymizeMessageId(message.messageId); - idMap.set(message.messageId, newMessageId); - - const anonymizedAttachments = message.attachments?.map((attachment) => { - return { - ...attachment, - messageId: newMessageId, - conversationId: newConvoId, - }; - }); - - return { - ...message, - messageId: newMessageId, - parentMessageId: - idMap.get(message.parentMessageId) || anonymizeMessageId(message.parentMessageId), - conversationId: newConvoId, - model: message.model?.startsWith('asst_') - ? anonymizeAssistantId(message.model) - : message.model, - attachments: anonymizedAttachments, - }; - }); -} - -async function getSharedMessages(shareId) { - try { - const share = await SharedLink.findOne({ shareId, isPublic: true }) - .populate({ - path: 'messages', - select: '-_id -__v -user', - }) - .select('-_id -__v -user') - .lean(); - - if (!share?.conversationId || !share.isPublic) { - return null; - } - - const newConvoId = anonymizeConvoId(share.conversationId); - const result = { - ...share, - conversationId: newConvoId, - messages: anonymizeMessages(share.messages, newConvoId), - }; - - return result; - } catch (error) { - logger.error('[getShare] Error getting share link', { - error: error.message, - shareId, - }); - throw new ShareServiceError('Error getting share link', 'SHARE_FETCH_ERROR'); - } -} - -async function getSharedLinks(user, pageParam, pageSize, isPublic, sortBy, sortDirection, search) { - try { - const query = { user, isPublic }; - - if (pageParam) { - if (sortDirection === 'desc') { - query[sortBy] = { $lt: pageParam }; - } else { - query[sortBy] = { $gt: pageParam }; - } - } - - if (search && search.trim()) { - try { - const searchResults = await Conversation.meiliSearch(search); - - if (!searchResults?.hits?.length) { - return { - links: [], - nextCursor: undefined, - hasNextPage: false, - }; - } - - const conversationIds = searchResults.hits.map((hit) => hit.conversationId); - query['conversationId'] = { $in: conversationIds }; - } catch (searchError) { - logger.error('[getSharedLinks] Meilisearch error', { - error: searchError.message, - user, - }); - return { - links: [], - nextCursor: undefined, - hasNextPage: false, - }; - } - } - - const sort = {}; - sort[sortBy] = sortDirection === 'desc' ? -1 : 1; - - if (Array.isArray(query.conversationId)) { - query.conversationId = { $in: query.conversationId }; - } - - const sharedLinks = await SharedLink.find(query) - .sort(sort) - .limit(pageSize + 1) - .select('-__v -user') - .lean(); - - const hasNextPage = sharedLinks.length > pageSize; - const links = sharedLinks.slice(0, pageSize); - - const nextCursor = hasNextPage ? links[links.length - 1][sortBy] : undefined; - - return { - links: links.map((link) => ({ - shareId: link.shareId, - title: link?.title || 'Untitled', - isPublic: link.isPublic, - createdAt: link.createdAt, - conversationId: link.conversationId, - })), - nextCursor, - hasNextPage, - }; - } catch (error) { - logger.error('[getSharedLinks] Error getting shares', { - error: error.message, - user, - }); - throw new ShareServiceError('Error getting shares', 'SHARES_FETCH_ERROR'); - } -} - -async function deleteAllSharedLinks(user) { - try { - const result = await SharedLink.deleteMany({ user }); - return { - message: 'All shared links deleted successfully', - deletedCount: result.deletedCount, - }; - } catch (error) { - logger.error('[deleteAllSharedLinks] Error deleting shared links', { - error: error.message, - user, - }); - throw new ShareServiceError('Error deleting shared links', 'BULK_DELETE_ERROR'); - } -} - -async function createSharedLink(user, conversationId) { - if (!user || !conversationId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - try { - const [existingShare, conversationMessages] = await Promise.all([ - SharedLink.findOne({ conversationId, isPublic: true }).select('-_id -__v -user').lean(), - getMessages({ conversationId }), - ]); - - if (existingShare && existingShare.isPublic) { - throw new ShareServiceError('Share already exists', 'SHARE_EXISTS'); - } else if (existingShare) { - await SharedLink.deleteOne({ conversationId }); - } - - const conversation = await Conversation.findOne({ conversationId }).lean(); - const title = conversation?.title || 'Untitled'; - - const shareId = nanoid(); - await SharedLink.create({ - shareId, - conversationId, - messages: conversationMessages, - title, - user, - }); - - return { shareId, conversationId }; - } catch (error) { - logger.error('[createSharedLink] Error creating shared link', { - error: error.message, - user, - conversationId, - }); - throw new ShareServiceError('Error creating shared link', 'SHARE_CREATE_ERROR'); - } -} - -async function getSharedLink(user, conversationId) { - if (!user || !conversationId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const share = await SharedLink.findOne({ conversationId, user, isPublic: true }) - .select('shareId -_id') - .lean(); - - if (!share) { - return { shareId: null, success: false }; - } - - return { shareId: share.shareId, success: true }; - } catch (error) { - logger.error('[getSharedLink] Error getting shared link', { - error: error.message, - user, - conversationId, - }); - throw new ShareServiceError('Error getting shared link', 'SHARE_FETCH_ERROR'); - } -} - -async function updateSharedLink(user, shareId) { - if (!user || !shareId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const share = await SharedLink.findOne({ shareId }).select('-_id -__v -user').lean(); - - if (!share) { - throw new ShareServiceError('Share not found', 'SHARE_NOT_FOUND'); - } - - const [updatedMessages] = await Promise.all([ - getMessages({ conversationId: share.conversationId }), - ]); - - const newShareId = nanoid(); - const update = { - messages: updatedMessages, - user, - shareId: newShareId, - }; - - const updatedShare = await SharedLink.findOneAndUpdate({ shareId, user }, update, { - new: true, - upsert: false, - runValidators: true, - }).lean(); - - if (!updatedShare) { - throw new ShareServiceError('Share update failed', 'SHARE_UPDATE_ERROR'); - } - - anonymizeConvo(updatedShare); - - return { shareId: newShareId, conversationId: updatedShare.conversationId }; - } catch (error) { - logger.error('[updateSharedLink] Error updating shared link', { - error: error.message, - user, - shareId, - }); - throw new ShareServiceError( - error.code === 'SHARE_UPDATE_ERROR' ? error.message : 'Error updating shared link', - error.code || 'SHARE_UPDATE_ERROR', - ); - } -} - -async function deleteSharedLink(user, shareId) { - if (!user || !shareId) { - throw new ShareServiceError('Missing required parameters', 'INVALID_PARAMS'); - } - - try { - const result = await SharedLink.findOneAndDelete({ shareId, user }).lean(); - - if (!result) { - return null; - } - - return { - success: true, - shareId, - message: 'Share deleted successfully', - }; - } catch (error) { - logger.error('[deleteSharedLink] Error deleting shared link', { - error: error.message, - user, - shareId, - }); - throw new ShareServiceError('Error deleting shared link', 'SHARE_DELETE_ERROR'); - } -} - -module.exports = { - getSharedLink, - getSharedLinks, - createSharedLink, - updateSharedLink, - deleteSharedLink, - getSharedMessages, - deleteAllSharedLinks, -}; diff --git a/api/models/Token.js b/api/models/Token.js deleted file mode 100644 index 6f130eb2c4..0000000000 --- a/api/models/Token.js +++ /dev/null @@ -1,42 +0,0 @@ -const { findToken, updateToken, createToken } = require('~/models'); -const { encryptV2 } = require('~/server/utils/crypto'); - -/** - * Handles the OAuth token by creating or updating the token. - * @param {object} fields - * @param {string} fields.userId - The user's ID. - * @param {string} fields.token - The full token to store. - * @param {string} fields.identifier - Unique, alternative identifier for the token. - * @param {number} fields.expiresIn - The number of seconds until the token expires. - * @param {object} fields.metadata - Additional metadata to store with the token. - * @param {string} [fields.type="oauth"] - The type of token. Default is 'oauth'. - */ -async function handleOAuthToken({ - token, - userId, - identifier, - expiresIn, - metadata, - type = 'oauth', -}) { - const encrypedToken = await encryptV2(token); - const tokenData = { - type, - userId, - metadata, - identifier, - token: encrypedToken, - expiresIn: parseInt(expiresIn, 10) || 3600, - }; - - const existingToken = await findToken({ userId, identifier }); - if (existingToken) { - return await updateToken({ identifier }, tokenData); - } else { - return await createToken(tokenData); - } -} - -module.exports = { - handleOAuthToken, -}; diff --git a/api/models/inviteUser.js b/api/models/inviteUser.js index 9f35b3f02b..eeb42841bf 100644 --- a/api/models/inviteUser.js +++ b/api/models/inviteUser.js @@ -1,6 +1,6 @@ const mongoose = require('mongoose'); +const { getRandomValues } = require('@librechat/api'); const { logger, hashToken } = require('@librechat/data-schemas'); -const { getRandomValues } = require('~/server/utils/crypto'); const { createToken, findToken } = require('~/models'); /** diff --git a/api/models/tx.js b/api/models/tx.js index ddd098b80f..f3ba38652d 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -78,7 +78,7 @@ const tokenValues = Object.assign( 'gpt-3.5-turbo-1106': { prompt: 1, completion: 2 }, 'o4-mini': { prompt: 1.1, completion: 4.4 }, 'o3-mini': { prompt: 1.1, completion: 4.4 }, - o3: { prompt: 10, completion: 40 }, + o3: { prompt: 2, completion: 8 }, 'o1-mini': { prompt: 1.1, completion: 4.4 }, 'o1-preview': { prompt: 15, completion: 60 }, o1: { prompt: 15, completion: 60 }, diff --git a/api/models/userMethods.js b/api/models/userMethods.js index e8bf5e4784..a36409ebcf 100644 --- a/api/models/userMethods.js +++ b/api/models/userMethods.js @@ -12,6 +12,10 @@ const comparePassword = async (user, candidatePassword) => { throw new Error('No user provided'); } + if (!user.password) { + throw new Error('No password, likely an email first registered via Social/OIDC login'); + } + return new Promise((resolve, reject) => { bcrypt.compare(candidatePassword, user.password, (err, isMatch) => { if (err) { diff --git a/api/package.json b/api/package.json index 3d3766bde8..6633a99c3f 100644 --- a/api/package.json +++ b/api/package.json @@ -34,21 +34,22 @@ }, "homepage": "https://librechat.ai", "dependencies": { - "@anthropic-ai/sdk": "^0.37.0", + "@anthropic-ai/sdk": "^0.52.0", "@aws-sdk/client-s3": "^3.758.0", "@aws-sdk/s3-request-presigner": "^3.758.0", "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", "@azure/storage-blob": "^12.27.0", - "@google/generative-ai": "^0.23.0", + "@google/generative-ai": "^0.24.0", "@googleapis/youtube": "^20.0.0", "@keyv/redis": "^4.3.3", - "@langchain/community": "^0.3.44", - "@langchain/core": "^0.3.57", - "@langchain/google-genai": "^0.2.9", - "@langchain/google-vertexai": "^0.2.9", + "@langchain/community": "^0.3.47", + "@langchain/core": "^0.3.60", + "@langchain/google-genai": "^0.2.13", + "@langchain/google-vertexai": "^0.2.13", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.4.37", + "@librechat/agents": "^2.4.41", + "@librechat/api": "*", "@librechat/data-schemas": "*", "@node-saml/passport-saml": "^5.0.0", "@waylaidwanderer/fetch-event-source": "^3.0.1", @@ -81,15 +82,15 @@ "keyv-file": "^5.1.2", "klona": "^2.0.6", "librechat-data-provider": "*", - "librechat-mcp": "*", "lodash": "^4.17.21", "meilisearch": "^0.38.0", "memorystore": "^1.6.7", "mime": "^3.0.0", "module-alias": "^2.2.3", "mongoose": "^8.12.1", - "multer": "^2.0.0", + "multer": "^2.0.1", "nanoid": "^3.3.7", + "node-fetch": "^2.7.0", "nodemailer": "^6.9.15", "ollama": "^0.5.0", "openai": "^4.96.2", @@ -109,8 +110,9 @@ "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", + "undici": "^7.10.0", "winston": "^3.11.0", - "winston-daily-rotate-file": "^4.7.1", + "winston-daily-rotate-file": "^5.0.0", "youtube-transcript": "^1.2.1", "zod": "^3.22.4" }, diff --git a/api/server/cleanup.js b/api/server/cleanup.js index 5bf336eed5..de7450cea0 100644 --- a/api/server/cleanup.js +++ b/api/server/cleanup.js @@ -220,6 +220,9 @@ function disposeClient(client) { if (client.maxResponseTokens) { client.maxResponseTokens = null; } + if (client.processMemory) { + client.processMemory = null; + } if (client.run) { // Break circular references in run if (client.run.Graph) { diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index 674e36002a..f7aad84aeb 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -1,9 +1,11 @@ +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, AuthType } = require('librechat-data-provider'); +const { getCustomConfig, getCachedTools } = require('~/server/services/Config'); const { getToolkitKey } = require('~/server/services/ToolService'); -const { getCustomConfig } = require('~/server/services/Config'); +const { getMCPManager, getFlowStateManager } = require('~/config'); const { availableTools } = require('~/app/clients/tools'); -const { getMCPManager } = require('~/config'); const { getLogStores } = require('~/cache'); +const { Constants } = require('librechat-data-provider'); /** * Filters out duplicate plugins from the list of plugins. @@ -84,6 +86,45 @@ const getAvailablePluginsController = async (req, res) => { } }; +function createServerToolsCallback() { + /** + * @param {string} serverName + * @param {TPlugin[] | null} serverTools + */ + return async function (serverName, serverTools) { + try { + const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS); + if (!serverName || !mcpToolsCache) { + return; + } + await mcpToolsCache.set(serverName, serverTools); + logger.debug(`MCP tools for ${serverName} added to cache.`); + } catch (error) { + logger.error('Error retrieving MCP tools from cache:', error); + } + }; +} + +function createGetServerTools() { + /** + * Retrieves cached server tools + * @param {string} serverName + * @returns {Promise} + */ + return async function (serverName) { + try { + const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS); + if (!mcpToolsCache) { + return null; + } + return await mcpToolsCache.get(serverName); + } catch (error) { + logger.error('Error retrieving MCP tools from cache:', error); + return null; + } + }; +} + /** * Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file. * @@ -109,7 +150,16 @@ const getAvailableTools = async (req, res) => { const customConfig = await getCustomConfig(); if (customConfig?.mcpServers != null) { const mcpManager = getMCPManager(); - pluginManifest = await mcpManager.loadManifestTools(pluginManifest); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null; + const serverToolsCallback = createServerToolsCallback(); + const getServerTools = createGetServerTools(); + const mcpTools = await mcpManager.loadManifestTools({ + flowManager, + serverToolsCallback, + getServerTools, + }); + pluginManifest = [...mcpTools, ...pluginManifest]; } /** @type {TPlugin[]} */ @@ -123,17 +173,57 @@ const getAvailableTools = async (req, res) => { } }); - const toolDefinitions = req.app.locals.availableTools; - const tools = authenticatedPlugins.filter( - (plugin) => - toolDefinitions[plugin.pluginKey] !== undefined || - (plugin.toolkit === true && - Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey)), - ); + const toolDefinitions = await getCachedTools({ includeGlobal: true }); - await cache.set(CacheKeys.TOOLS, tools); - res.status(200).json(tools); + const toolsOutput = []; + for (const plugin of authenticatedPlugins) { + const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined; + const isToolkit = + plugin.toolkit === true && + Object.keys(toolDefinitions).some((key) => getToolkitKey(key) === plugin.pluginKey); + + if (!isToolDefined && !isToolkit) { + continue; + } + + const toolToAdd = { ...plugin }; + + if (!plugin.pluginKey.includes(Constants.mcp_delimiter)) { + toolsOutput.push(toolToAdd); + continue; + } + + const parts = plugin.pluginKey.split(Constants.mcp_delimiter); + const serverName = parts[parts.length - 1]; + const serverConfig = customConfig?.mcpServers?.[serverName]; + + if (!serverConfig?.customUserVars) { + toolsOutput.push(toolToAdd); + continue; + } + + const customVarKeys = Object.keys(serverConfig.customUserVars); + + if (customVarKeys.length === 0) { + toolToAdd.authConfig = []; + toolToAdd.authenticated = true; + } else { + toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({ + authField: key, + label: value.title || key, + description: value.description || '', + })); + toolToAdd.authenticated = false; + } + + toolsOutput.push(toolToAdd); + } + + const finalTools = filterUniquePlugins(toolsOutput); + await cache.set(CacheKeys.TOOLS, finalTools); + res.status(200).json(finalTools); } catch (error) { + logger.error('[getAvailableTools]', error); res.status(500).json({ message: error.message }); } }; diff --git a/api/server/controllers/TwoFactorController.js b/api/server/controllers/TwoFactorController.js index 6e22db2e5c..44baf92ee7 100644 --- a/api/server/controllers/TwoFactorController.js +++ b/api/server/controllers/TwoFactorController.js @@ -1,3 +1,4 @@ +const { encryptV3 } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { verifyTOTP, @@ -7,7 +8,6 @@ const { generateBackupCodes, } = require('~/server/services/twoFactorService'); const { getUserById, updateUser } = require('~/models'); -const { encryptV3 } = require('~/server/utils/crypto'); const safeAppTitle = (process.env.APP_TITLE || 'LibreChat').replace(/\s+/g, ''); diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index a2fbc3c485..69791dd7a5 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -1,5 +1,6 @@ const { Tools, + Constants, FileSources, webSearchKeys, extractWebSearchEnvVars, @@ -21,8 +22,9 @@ const { verifyEmail, resendVerificationEmail } = require('~/server/services/Auth const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud'); const { processDeleteRequest } = require('~/server/services/Files/process'); const { Transaction, Balance, User } = require('~/db/models'); -const { deleteAllSharedLinks } = require('~/models/Share'); const { deleteToolCalls } = require('~/models/ToolCall'); +const { deleteAllSharedLinks } = require('~/models'); +const { getMCPManager } = require('~/config'); const getUserController = async (req, res) => { /** @type {MongoUser} */ @@ -102,10 +104,22 @@ const updateUserPluginsController = async (req, res) => { } let keys = Object.keys(auth); - if (keys.length === 0 && pluginKey !== Tools.web_search) { + const values = Object.values(auth); // Used in 'install' block + + const isMCPTool = pluginKey.startsWith('mcp_') || pluginKey.includes(Constants.mcp_delimiter); + + // Early exit condition: + // If keys are empty (meaning auth: {} was likely sent for uninstall, or auth was empty for install) + // AND it's not web_search (which has special key handling to populate `keys` for uninstall) + // AND it's NOT (an uninstall action FOR an MCP tool - we need to proceed for this case to clear all its auth) + // THEN return. + if ( + keys.length === 0 && + pluginKey !== Tools.web_search && + !(action === 'uninstall' && isMCPTool) + ) { return res.status(200).send(); } - const values = Object.values(auth); /** @type {number} */ let status = 200; @@ -132,16 +146,53 @@ const updateUserPluginsController = async (req, res) => { } } } else if (action === 'uninstall') { - for (let i = 0; i < keys.length; i++) { - authService = await deleteUserPluginAuth(user.id, keys[i]); + // const isMCPTool was defined earlier + if (isMCPTool && keys.length === 0) { + // This handles the case where auth: {} is sent for an MCP tool uninstall. + // It means "delete all credentials associated with this MCP pluginKey". + authService = await deleteUserPluginAuth(user.id, null, true, pluginKey); if (authService instanceof Error) { - logger.error('[authService]', authService); + logger.error( + `[authService] Error deleting all auth for MCP tool ${pluginKey}:`, + authService, + ); ({ status, message } = authService); } + } else { + // This handles: + // 1. Web_search uninstall (keys will be populated with all webSearchKeys if auth was {}). + // 2. Other tools uninstall (if keys were provided). + // 3. MCP tool uninstall if specific keys were provided in `auth` (not current frontend behavior). + // If keys is empty for non-MCP tools (and not web_search), this loop won't run, and nothing is deleted. + for (let i = 0; i < keys.length; i++) { + authService = await deleteUserPluginAuth(user.id, keys[i]); // Deletes by authField name + if (authService instanceof Error) { + logger.error('[authService] Error deleting specific auth key:', authService); + ({ status, message } = authService); + } + } } } if (status === 200) { + // If auth was updated successfully, disconnect MCP sessions as they might use these credentials + if (pluginKey.startsWith(Constants.mcp_prefix)) { + try { + const mcpManager = getMCPManager(user.id); + if (mcpManager) { + logger.info( + `[updateUserPluginsController] Disconnecting MCP connections for user ${user.id} after plugin auth update for ${pluginKey}.`, + ); + await mcpManager.disconnectUserConnections(user.id); + } + } catch (disconnectError) { + logger.error( + `[updateUserPluginsController] Error disconnecting MCP connections for user ${user.id} after plugin auth update:`, + disconnectError, + ); + // Do not fail the request for this, but log it. + } + } return res.status(status).send(); } @@ -163,7 +214,11 @@ const deleteUserController = async (req, res) => { await Balance.deleteMany({ user: user._id }); // delete user balances await deletePresets(user.id); // delete user presets /* TODO: Delete Assistant Threads */ - await deleteConvos(user.id); // delete user convos + try { + await deleteConvos(user.id); // delete user convos + } catch (error) { + logger.error('[deleteUserController] Error deleting user convos, likely no convos', error); + } await deleteUserPluginAuth(user.id, null, true); // delete user plugin auth await deleteUserById(user.id); // delete user await deleteAllSharedLinks(user.id); // delete user shared links diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index cedfc6bd62..60e68b5f2d 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,4 +1,6 @@ const { nanoid } = require('nanoid'); +const { sendEvent } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { Tools, StepTypes, FileContext } = require('librechat-data-provider'); const { EnvVar, @@ -12,7 +14,6 @@ const { const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); const { saveBase64Image } = require('~/server/services/Files/process'); -const { logger, sendEvent } = require('~/config'); class ModelEndHandler { /** @@ -240,9 +241,7 @@ function createToolEndCallback({ req, res, artifactPromises }) { if (output.artifact[Tools.web_search]) { artifactPromises.push( (async () => { - const name = `${output.name}_${output.tool_call_id}_${nanoid()}`; const attachment = { - name, type: Tools.web_search, messageId: metadata.run_id, toolCallId: output.tool_call_id, diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 31fd56930e..6769348d95 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -1,13 +1,12 @@ -// const { HttpsProxyAgent } = require('https-proxy-agent'); -// const { -// Constants, -// ImageDetail, -// EModelEndpoint, -// resolveHeaders, -// validateVisionModel, -// mapModelToAzureConfig, -// } = require('librechat-data-provider'); require('events').EventEmitter.defaultMaxListeners = 100; +const { logger } = require('@librechat/data-schemas'); +const { + sendEvent, + createRun, + Tokenizer, + memoryInstructions, + createMemoryProcessor, +} = require('@librechat/api'); const { Callback, GraphEvents, @@ -19,25 +18,34 @@ const { } = require('@librechat/agents'); const { Constants, + Permissions, VisionModes, ContentTypes, EModelEndpoint, KnownEndpoints, + PermissionTypes, isAgentsEndpoint, AgentCapabilities, bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); -const { getCustomEndpointConfig, checkCapability } = require('~/server/services/Config'); -const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); -const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { DynamicStructuredTool } = require('@langchain/core/tools'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); +const { + getCustomEndpointConfig, + createGetMCPAuthMap, + checkCapability, +} = require('~/server/services/Config'); +const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); +const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { getFormattedMemories, deleteMemory, setMemory } = require('~/models'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const Tokenizer = require('~/server/services/Tokenizer'); +const { checkAccess } = require('~/server/middleware/roles/access'); const BaseClient = require('~/app/clients/BaseClient'); -const { logger, sendEvent } = require('~/config'); -const { createRun } = require('./run'); +const { loadAgent } = require('~/models/Agent'); +const { getMCPManager } = require('~/config'); /** * @param {ServerRequest} req @@ -57,12 +65,8 @@ const legacyContentEndpoints = new Set([KnownEndpoints.groq, KnownEndpoints.deep const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; -// const { processMemory, memoryInstructions } = require('~/server/services/Endpoints/agents/memory'); -// const { getFormattedMemories } = require('~/models/Memory'); -// const { getCurrentDateTime } = require('~/utils'); - function createTokenCounter(encoding) { - return (message) => { + return function (message) { const countTokens = (text) => Tokenizer.getTokenCount(text, encoding); return getTokenCountForMessage(message, countTokens); }; @@ -123,6 +127,8 @@ class AgentClient extends BaseClient { this.usage; /** @type {Record} */ this.indexTokenCountMap = {}; + /** @type {(messages: BaseMessage[]) => Promise} */ + this.processMemory; } /** @@ -137,55 +143,10 @@ class AgentClient extends BaseClient { } /** - * - * Checks if the model is a vision model based on request attachments and sets the appropriate options: - * - Sets `this.modelOptions.model` to `gpt-4-vision-preview` if the request is a vision request. - * - Sets `this.isVisionModel` to `true` if vision request. - * - Deletes `this.modelOptions.stop` if vision request. + * `AgentClient` is not opinionated about vision requests, so we don't do anything here * @param {MongoFile[]} attachments */ - checkVisionRequest(attachments) { - // if (!attachments) { - // return; - // } - // const availableModels = this.options.modelsConfig?.[this.options.endpoint]; - // if (!availableModels) { - // return; - // } - // let visionRequestDetected = false; - // for (const file of attachments) { - // if (file?.type?.includes('image')) { - // visionRequestDetected = true; - // break; - // } - // } - // if (!visionRequestDetected) { - // return; - // } - // this.isVisionModel = validateVisionModel({ model: this.modelOptions.model, availableModels }); - // if (this.isVisionModel) { - // delete this.modelOptions.stop; - // return; - // } - // for (const model of availableModels) { - // if (!validateVisionModel({ model, availableModels })) { - // continue; - // } - // this.modelOptions.model = model; - // this.isVisionModel = true; - // delete this.modelOptions.stop; - // return; - // } - // if (!availableModels.includes(this.defaultVisionModel)) { - // return; - // } - // if (!validateVisionModel({ model: this.defaultVisionModel, availableModels })) { - // return; - // } - // this.modelOptions.model = this.defaultVisionModel; - // this.isVisionModel = true; - // delete this.modelOptions.stop; - } + checkVisionRequest() {} getSaveOptions() { // TODO: @@ -269,24 +230,6 @@ class AgentClient extends BaseClient { .filter(Boolean) .join('\n') .trim(); - // this.systemMessage = getCurrentDateTime(); - // const { withKeys, withoutKeys } = await getFormattedMemories({ - // userId: this.options.req.user.id, - // }); - // processMemory({ - // userId: this.options.req.user.id, - // message: this.options.req.body.text, - // parentMessageId, - // memory: withKeys, - // thread_id: this.conversationId, - // }).catch((error) => { - // logger.error('Memory Agent failed to process memory', error); - // }); - - // this.systemMessage += '\n\n' + memoryInstructions; - // if (withoutKeys) { - // this.systemMessage += `\n\n# Existing memory about the user:\n${withoutKeys}`; - // } if (this.options.attachments) { const attachments = await this.options.attachments; @@ -370,6 +313,37 @@ class AgentClient extends BaseClient { systemContent = this.augmentedPrompt + systemContent; } + // Inject MCP server instructions if available + const ephemeralAgent = this.options.req.body.ephemeralAgent; + let mcpServers = []; + + // Check for ephemeral agent MCP servers + if (ephemeralAgent && ephemeralAgent.mcp && ephemeralAgent.mcp.length > 0) { + mcpServers = ephemeralAgent.mcp; + } + // Check for regular agent MCP tools + else if (this.options.agent && this.options.agent.tools) { + mcpServers = this.options.agent.tools + .filter( + (tool) => + tool instanceof DynamicStructuredTool && tool.name.includes(Constants.mcp_delimiter), + ) + .map((tool) => tool.name.split(Constants.mcp_delimiter).pop()) + .filter(Boolean); + } + + if (mcpServers.length > 0) { + try { + const mcpInstructions = getMCPManager().formatInstructionsForContext(mcpServers); + if (mcpInstructions) { + systemContent = [systemContent, mcpInstructions].filter(Boolean).join('\n\n'); + logger.debug('[AgentClient] Injected MCP instructions for servers:', mcpServers); + } + } catch (error) { + logger.error('[AgentClient] Failed to inject MCP instructions:', error); + } + } + if (systemContent) { this.options.agent.instructions = systemContent; } @@ -399,9 +373,150 @@ class AgentClient extends BaseClient { opts.getReqData({ promptTokens }); } + const withoutKeys = await this.useMemory(); + if (withoutKeys) { + systemContent += `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`; + } + + if (systemContent) { + this.options.agent.instructions = systemContent; + } + return result; } + /** + * @returns {Promise} + */ + async useMemory() { + const user = this.options.req.user; + if (user.personalization?.memories === false) { + return; + } + const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]); + + if (!hasAccess) { + logger.debug( + `[api/server/controllers/agents/client.js #useMemory] User ${user.id} does not have USE permission for memories`, + ); + return; + } + /** @type {TCustomConfig['memory']} */ + const memoryConfig = this.options.req?.app?.locals?.memory; + if (!memoryConfig || memoryConfig.disabled === true) { + return; + } + + /** @type {Agent} */ + let prelimAgent; + const allowedProviders = new Set( + this.options.req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders, + ); + try { + if (memoryConfig.agent?.id != null && memoryConfig.agent.id !== this.options.agent.id) { + prelimAgent = await loadAgent({ + req: this.options.req, + agent_id: memoryConfig.agent.id, + endpoint: EModelEndpoint.agents, + }); + } else if ( + memoryConfig.agent?.id == null && + memoryConfig.agent?.model != null && + memoryConfig.agent?.provider != null + ) { + prelimAgent = { id: Constants.EPHEMERAL_AGENT_ID, ...memoryConfig.agent }; + } + } catch (error) { + logger.error( + '[api/server/controllers/agents/client.js #useMemory] Error loading agent for memory', + error, + ); + } + + const agent = await initializeAgent({ + req: this.options.req, + res: this.options.res, + agent: prelimAgent, + allowedProviders, + }); + + if (!agent) { + logger.warn( + '[api/server/controllers/agents/client.js #useMemory] No agent found for memory', + memoryConfig, + ); + return; + } + + const llmConfig = Object.assign( + { + provider: agent.provider, + model: agent.model, + }, + agent.model_parameters, + ); + + /** @type {import('@librechat/api').MemoryConfig} */ + const config = { + validKeys: memoryConfig.validKeys, + instructions: agent.instructions, + llmConfig, + tokenLimit: memoryConfig.tokenLimit, + }; + + const userId = this.options.req.user.id + ''; + const messageId = this.responseMessageId + ''; + const conversationId = this.conversationId + ''; + const [withoutKeys, processMemory] = await createMemoryProcessor({ + userId, + config, + messageId, + conversationId, + memoryMethods: { + setMemory, + deleteMemory, + getFormattedMemories, + }, + res: this.options.res, + }); + + this.processMemory = processMemory; + return withoutKeys; + } + + /** + * @param {BaseMessage[]} messages + * @returns {Promise} + */ + async runMemory(messages) { + try { + if (this.processMemory == null) { + return; + } + /** @type {TCustomConfig['memory']} */ + const memoryConfig = this.options.req?.app?.locals?.memory; + const messageWindowSize = memoryConfig?.messageWindowSize ?? 5; + + let messagesToProcess = [...messages]; + if (messages.length > messageWindowSize) { + for (let i = messages.length - messageWindowSize; i >= 0; i--) { + const potentialWindow = messages.slice(i, i + messageWindowSize); + if (potentialWindow[0]?.role === 'user') { + messagesToProcess = [...potentialWindow]; + break; + } + } + + if (messagesToProcess.length === messages.length) { + messagesToProcess = [...messages.slice(-messageWindowSize)]; + } + } + return await this.processMemory(messagesToProcess); + } catch (error) { + logger.error('Memory Agent failed to process memory', error); + } + } + /** @type {sendCompletion} */ async sendCompletion(payload, opts = {}) { await this.chatCompletion({ @@ -544,100 +659,13 @@ class AgentClient extends BaseClient { let config; /** @type {ReturnType} */ let run; + /** @type {Promise<(TAttachment | null)[] | undefined>} */ + let memoryPromise; try { if (!abortController) { abortController = new AbortController(); } - // if (this.options.headers) { - // opts.defaultHeaders = { ...opts.defaultHeaders, ...this.options.headers }; - // } - - // if (this.options.proxy) { - // opts.httpAgent = new HttpsProxyAgent(this.options.proxy); - // } - - // if (this.isVisionModel) { - // modelOptions.max_tokens = 4000; - // } - - // /** @type {TAzureConfig | undefined} */ - // const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI]; - - // if ( - // (this.azure && this.isVisionModel && azureConfig) || - // (azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI) - // ) { - // const { modelGroupMap, groupMap } = azureConfig; - // const { - // azureOptions, - // baseURL, - // headers = {}, - // serverless, - // } = mapModelToAzureConfig({ - // modelName: modelOptions.model, - // modelGroupMap, - // groupMap, - // }); - // opts.defaultHeaders = resolveHeaders(headers); - // this.langchainProxy = extractBaseURL(baseURL); - // this.apiKey = azureOptions.azureOpenAIApiKey; - - // const groupName = modelGroupMap[modelOptions.model].group; - // this.options.addParams = azureConfig.groupMap[groupName].addParams; - // this.options.dropParams = azureConfig.groupMap[groupName].dropParams; - // // Note: `forcePrompt` not re-assigned as only chat models are vision models - - // this.azure = !serverless && azureOptions; - // this.azureEndpoint = - // !serverless && genAzureChatCompletion(this.azure, modelOptions.model, this); - // } - - // if (this.azure || this.options.azure) { - // /* Azure Bug, extremely short default `max_tokens` response */ - // if (!modelOptions.max_tokens && modelOptions.model === 'gpt-4-vision-preview') { - // modelOptions.max_tokens = 4000; - // } - - // /* Azure does not accept `model` in the body, so we need to remove it. */ - // delete modelOptions.model; - - // opts.baseURL = this.langchainProxy - // ? constructAzureURL({ - // baseURL: this.langchainProxy, - // azureOptions: this.azure, - // }) - // : this.azureEndpoint.split(/(? { - // delete modelOptions[param]; - // }); - // logger.debug('[api/server/controllers/agents/client.js #chatCompletion] dropped params', { - // dropParams: this.options.dropParams, - // modelOptions, - // }); - // } - /** @type {TCustomConfig['endpoints']['agents']} */ const agentsEConfig = this.options.req.app.locals[EModelEndpoint.agents]; @@ -647,6 +675,7 @@ class AgentClient extends BaseClient { last_agent_index: this.agentConfigs?.size ?? 0, user_id: this.user ?? this.options.req.user?.id, hide_sequential_outputs: this.options.agent.hide_sequential_outputs, + user: this.options.req.user, }, recursionLimit: agentsEConfig?.recursionLimit, signal: abortController.signal, @@ -654,6 +683,8 @@ class AgentClient extends BaseClient { version: 'v2', }; + const getUserMCPAuthMap = await createGetMCPAuthMap(); + const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name)); let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages( payload, @@ -734,6 +765,10 @@ class AgentClient extends BaseClient { messages = addCacheControl(messages); } + if (i === 0) { + memoryPromise = this.runMemory(messages); + } + run = await createRun({ agent, req: this.options.req, @@ -769,10 +804,23 @@ class AgentClient extends BaseClient { run.Graph.contentData = contentData; } - const encoding = this.getEncoding(); + try { + if (getUserMCPAuthMap) { + config.configurable.userMCPAuthMap = await getUserMCPAuthMap({ + tools: agent.tools, + userId: this.options.req.user.id, + }); + } + } catch (err) { + logger.error( + `[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`, + err, + ); + } + await run.processStream({ messages }, config, { keepContent: i !== 0, - tokenCounter: createTokenCounter(encoding), + tokenCounter: createTokenCounter(this.getEncoding()), indexTokenCountMap: currentIndexCountMap, maxContextTokens: agent.maxContextTokens, callbacks: { @@ -887,6 +935,12 @@ class AgentClient extends BaseClient { }); try { + if (memoryPromise) { + const attachments = await memoryPromise; + if (attachments && attachments.length > 0) { + this.artifactPromises.push(...attachments); + } + } await this.recordCollectedUsage({ context: 'message' }); } catch (err) { logger.error( @@ -895,6 +949,12 @@ class AgentClient extends BaseClient { ); } } catch (err) { + if (memoryPromise) { + const attachments = await memoryPromise; + if (attachments && attachments.length > 0) { + this.artifactPromises.push(...attachments); + } + } logger.error( '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', err, diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js deleted file mode 100644 index 2452e66233..0000000000 --- a/api/server/controllers/agents/run.js +++ /dev/null @@ -1,94 +0,0 @@ -const { Run, Providers } = require('@librechat/agents'); -const { providerEndpointMap, KnownEndpoints } = require('librechat-data-provider'); - -/** - * @typedef {import('@librechat/agents').t} t - * @typedef {import('@librechat/agents').StandardGraphConfig} StandardGraphConfig - * @typedef {import('@librechat/agents').StreamEventData} StreamEventData - * @typedef {import('@librechat/agents').EventHandler} EventHandler - * @typedef {import('@librechat/agents').GraphEvents} GraphEvents - * @typedef {import('@librechat/agents').LLMConfig} LLMConfig - * @typedef {import('@librechat/agents').IState} IState - */ - -const customProviders = new Set([ - Providers.XAI, - Providers.OLLAMA, - Providers.DEEPSEEK, - Providers.OPENROUTER, -]); - -/** - * Creates a new Run instance with custom handlers and configuration. - * - * @param {Object} options - The options for creating the Run instance. - * @param {ServerRequest} [options.req] - The server request. - * @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated. - * @param {Agent} options.agent - The agent for this run. - * @param {AbortSignal} options.signal - The signal for this run. - * @param {Record | undefined} [options.customHandlers] - Custom event handlers. - * @param {boolean} [options.streaming=true] - Whether to use streaming. - * @param {boolean} [options.streamUsage=true] - Whether to stream usage information. - * @returns {Promise>} A promise that resolves to a new Run instance. - */ -async function createRun({ - runId, - agent, - signal, - customHandlers, - streaming = true, - streamUsage = true, -}) { - const provider = providerEndpointMap[agent.provider] ?? agent.provider; - /** @type {LLMConfig} */ - const llmConfig = Object.assign( - { - provider, - streaming, - streamUsage, - }, - agent.model_parameters, - ); - - /** Resolves issues with new OpenAI usage field */ - if ( - customProviders.has(agent.provider) || - (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider) - ) { - llmConfig.streamUsage = false; - llmConfig.usage = true; - } - - /** @type {'reasoning_content' | 'reasoning'} */ - let reasoningKey; - if ( - llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) || - (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) - ) { - reasoningKey = 'reasoning'; - } - - /** @type {StandardGraphConfig} */ - const graphConfig = { - signal, - llmConfig, - reasoningKey, - tools: agent.tools, - instructions: agent.instructions, - additional_instructions: agent.additional_instructions, - // toolEnd: agent.end_after_tools, - }; - - // TEMPORARY FOR TESTING - if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) { - graphConfig.streamBuffer = 2000; - } - - return Run.create({ - runId, - graphConfig, - customHandlers, - }); -} - -module.exports = { createRun }; diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 65189a18a1..18bd7190f0 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -1,9 +1,9 @@ const fs = require('fs').promises; const { nanoid } = require('nanoid'); +const { logger } = require('@librechat/data-schemas'); const { Tools, Constants, - FileContext, FileSources, SystemRoles, EToolResources, @@ -16,16 +16,16 @@ const { deleteAgent, getListAgents, } = require('~/models/Agent'); -const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const { refreshS3Url } = require('~/server/services/Files/S3/crud'); +const { filterFile } = require('~/server/services/Files/process'); const { updateAction, getActions } = require('~/models/Action'); +const { getCachedTools } = require('~/server/services/Config'); const { updateAgentProjects } = require('~/models/Agent'); const { getProjectByName } = require('~/models/Project'); -const { deleteFileByFilter } = require('~/models/File'); const { revertAgentVersion } = require('~/models/Agent'); -const { logger } = require('~/config'); +const { deleteFileByFilter } = require('~/models/File'); const systemTools = { [Tools.execute_code]: true, @@ -47,8 +47,9 @@ const createAgentHandler = async (req, res) => { agentData.tools = []; + const availableTools = await getCachedTools({ includeGlobal: true }); for (const tool of tools) { - if (req.app.locals.availableTools[tool]) { + if (availableTools[tool]) { agentData.tools.push(tool); } @@ -169,12 +170,18 @@ const updateAgentHandler = async (req, res) => { }); } + /** @type {boolean} */ + const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0; + let updatedAgent = Object.keys(updateData).length > 0 - ? await updateAgent({ id }, updateData, { updatingUserId: req.user.id }) + ? await updateAgent({ id }, updateData, { + updatingUserId: req.user.id, + skipVersioning: isProjectUpdate, + }) : existingAgent; - if (projectIds || removeProjectIds) { + if (isProjectUpdate) { updatedAgent = await updateAgentProjects({ user: req.user, agentId: id, @@ -387,6 +394,7 @@ const uploadAgentAvatarHandler = async (req, res) => { buffer: resizedBuffer, userId: req.user.id, manual: 'false', + agentId: agent_id, }); const image = { @@ -438,7 +446,7 @@ const uploadAgentAvatarHandler = async (req, res) => { try { await fs.unlink(req.file.path); logger.debug('[/:agent_id/avatar] Temp. image upload file deleted'); - } catch (error) { + } catch { logger.debug('[/:agent_id/avatar] Temp. image upload file already deleted'); } } diff --git a/api/server/controllers/assistants/v1.js b/api/server/controllers/assistants/v1.js index 8fb73167c1..e723cda4fc 100644 --- a/api/server/controllers/assistants/v1.js +++ b/api/server/controllers/assistants/v1.js @@ -1,4 +1,5 @@ const fs = require('fs').promises; +const { logger } = require('@librechat/data-schemas'); const { FileContext } = require('librechat-data-provider'); const { uploadImageBuffer, filterFile } = require('~/server/services/Files/process'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); @@ -6,9 +7,9 @@ const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { deleteAssistantActions } = require('~/server/services/ActionService'); const { updateAssistantDoc, getAssistants } = require('~/models/Assistant'); const { getOpenAIClient, fetchAssistants } = require('./helpers'); +const { getCachedTools } = require('~/server/services/Config'); const { manifestToolMap } = require('~/app/clients/tools'); const { deleteFileByFilter } = require('~/models/File'); -const { logger } = require('~/config'); /** * Create an assistant. @@ -30,21 +31,20 @@ const createAssistant = async (req, res) => { delete assistantData.conversation_starters; delete assistantData.append_current_datetime; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); + assistantData.tools = tools .map((tool) => { if (typeof tool !== 'string') { return tool; } - const toolDefinitions = req.app.locals.availableTools; const toolDef = toolDefinitions[tool]; if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { - return ( - Object.entries(toolDefinitions) - .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars - .map(([_, val]) => val) - ); + return Object.entries(toolDefinitions) + .filter(([key]) => key.startsWith(`${tool}_`)) + + .map(([_, val]) => val); } return toolDef; @@ -135,21 +135,21 @@ const patchAssistant = async (req, res) => { append_current_datetime, ...updateData } = req.body; + + const toolDefinitions = await getCachedTools({ includeGlobal: true }); + updateData.tools = (updateData.tools ?? []) .map((tool) => { if (typeof tool !== 'string') { return tool; } - const toolDefinitions = req.app.locals.availableTools; const toolDef = toolDefinitions[tool]; if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { - return ( - Object.entries(toolDefinitions) - .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars - .map(([_, val]) => val) - ); + return Object.entries(toolDefinitions) + .filter(([key]) => key.startsWith(`${tool}_`)) + + .map(([_, val]) => val); } return toolDef; diff --git a/api/server/controllers/assistants/v2.js b/api/server/controllers/assistants/v2.js index 3bf83a626f..98441ba70a 100644 --- a/api/server/controllers/assistants/v2.js +++ b/api/server/controllers/assistants/v2.js @@ -1,10 +1,11 @@ +const { logger } = require('@librechat/data-schemas'); const { ToolCallTypes } = require('librechat-data-provider'); const validateAuthor = require('~/server/middleware/assistants/validateAuthor'); const { validateAndUpdateTool } = require('~/server/services/ActionService'); +const { getCachedTools } = require('~/server/services/Config'); const { updateAssistantDoc } = require('~/models/Assistant'); const { manifestToolMap } = require('~/app/clients/tools'); const { getOpenAIClient } = require('./helpers'); -const { logger } = require('~/config'); /** * Create an assistant. @@ -27,21 +28,20 @@ const createAssistant = async (req, res) => { delete assistantData.conversation_starters; delete assistantData.append_current_datetime; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); + assistantData.tools = tools .map((tool) => { if (typeof tool !== 'string') { return tool; } - const toolDefinitions = req.app.locals.availableTools; const toolDef = toolDefinitions[tool]; if (!toolDef && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { - return ( - Object.entries(toolDefinitions) - .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars - .map(([_, val]) => val) - ); + return Object.entries(toolDefinitions) + .filter(([key]) => key.startsWith(`${tool}_`)) + + .map(([_, val]) => val); } return toolDef; @@ -125,13 +125,13 @@ const updateAssistant = async ({ req, openai, assistant_id, updateData }) => { let hasFileSearch = false; for (const tool of updateData.tools ?? []) { - const toolDefinitions = req.app.locals.availableTools; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); let actualTool = typeof tool === 'string' ? toolDefinitions[tool] : tool; if (!actualTool && manifestToolMap[tool] && manifestToolMap[tool].toolkit === true) { actualTool = Object.entries(toolDefinitions) .filter(([key]) => key.startsWith(`${tool}_`)) - // eslint-disable-next-line no-unused-vars + .map(([_, val]) => val); } else if (!actualTool) { continue; diff --git a/api/server/index.js b/api/server/index.js index ed770f7703..8c7db3e226 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -1,22 +1,22 @@ require('dotenv').config(); +const fs = require('fs'); const path = require('path'); require('module-alias')({ base: path.resolve(__dirname, '..') }); const cors = require('cors'); const axios = require('axios'); const express = require('express'); -const compression = require('compression'); const passport = require('passport'); -const mongoSanitize = require('express-mongo-sanitize'); -const fs = require('fs'); +const compression = require('compression'); const cookieParser = require('cookie-parser'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const mongoSanitize = require('express-mongo-sanitize'); const { connectDb, indexSync } = require('~/db'); -const { jwtLogin, passportLogin } = require('~/strategies'); -const { isEnabled } = require('~/server/utils'); -const { ldapLogin } = require('~/strategies'); -const { logger } = require('~/config'); const validateImageRequest = require('./middleware/validateImageRequest'); +const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies'); const errorController = require('./controllers/ErrorController'); +const initializeMCP = require('./services/initializeMCP'); const configureSocialLogins = require('./socialLogins'); const AppService = require('./services/AppService'); const staticCache = require('./utils/staticCache'); @@ -39,7 +39,9 @@ const startServer = async () => { await connectDb(); logger.info('Connected to MongoDB'); - await indexSync(); + indexSync().catch((err) => { + logger.error('[indexSync] Background sync failed:', err); + }); app.disable('x-powered-by'); app.set('trust proxy', trusted_proxy); @@ -117,8 +119,9 @@ const startServer = async () => { app.use('/api/agents', routes.agents); app.use('/api/banner', routes.banner); app.use('/api/bedrock', routes.bedrock); - + app.use('/api/memories', routes.memories); app.use('/api/tags', routes.tags); + app.use('/api/mcp', routes.mcp); app.use((req, res) => { res.set({ @@ -142,6 +145,8 @@ const startServer = async () => { } else { logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`); } + + initializeMCP(app); }); }; @@ -184,5 +189,5 @@ process.on('uncaughtException', (err) => { process.exit(1); }); -// export app for easier testing purposes +/** Export app for easier testing purposes */ module.exports = app; diff --git a/api/server/middleware/roles/generateCheckAccess.js b/api/server/middleware/roles/access.js similarity index 100% rename from api/server/middleware/roles/generateCheckAccess.js rename to api/server/middleware/roles/access.js diff --git a/api/server/middleware/roles/checkAdmin.js b/api/server/middleware/roles/admin.js similarity index 100% rename from api/server/middleware/roles/checkAdmin.js rename to api/server/middleware/roles/admin.js diff --git a/api/server/middleware/roles/index.js b/api/server/middleware/roles/index.js index a9fc5b2a08..ebc0043f2f 100644 --- a/api/server/middleware/roles/index.js +++ b/api/server/middleware/roles/index.js @@ -1,5 +1,5 @@ -const checkAdmin = require('./checkAdmin'); -const { checkAccess, generateCheckAccess } = require('./generateCheckAccess'); +const checkAdmin = require('./admin'); +const { checkAccess, generateCheckAccess } = require('./access'); module.exports = { checkAdmin, diff --git a/api/server/middleware/validate/convoAccess.js b/api/server/middleware/validate/convoAccess.js index 43cca0097d..afd2aeacef 100644 --- a/api/server/middleware/validate/convoAccess.js +++ b/api/server/middleware/validate/convoAccess.js @@ -1,8 +1,8 @@ +const { isEnabled } = require('@librechat/api'); const { Constants, ViolationTypes, Time } = require('librechat-data-provider'); const { searchConversation } = require('~/models/Conversation'); const denyRequest = require('~/server/middleware/denyRequest'); const { logViolation, getLogStores } = require('~/cache'); -const { isEnabled } = require('~/server/utils'); const { USE_REDIS, CONVO_ACCESS_VIOLATION_SCORE: score = 0 } = process.env ?? {}; diff --git a/api/server/routes/actions.js b/api/server/routes/actions.js index dc474d1a67..9f94f617ce 100644 --- a/api/server/routes/actions.js +++ b/api/server/routes/actions.js @@ -1,8 +1,10 @@ const express = require('express'); const jwt = require('jsonwebtoken'); +const { getAccessToken } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); -const { getAccessToken } = require('~/server/services/TokenService'); -const { logger, getFlowStateManager } = require('~/config'); +const { findToken, updateToken, createToken } = require('~/models'); +const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); const router = express.Router(); @@ -28,18 +30,19 @@ router.get('/:action_id/oauth/callback', async (req, res) => { try { decodedState = jwt.verify(state, JWT_SECRET); } catch (err) { + logger.error('Error verifying state parameter:', err); await flowManager.failFlow(identifier, 'oauth', 'Invalid or expired state parameter'); - return res.status(400).send('Invalid or expired state parameter'); + return res.redirect('/oauth/error?error=invalid_state'); } if (decodedState.action_id !== action_id) { await flowManager.failFlow(identifier, 'oauth', 'Mismatched action ID in state parameter'); - return res.status(400).send('Mismatched action ID in state parameter'); + return res.redirect('/oauth/error?error=invalid_state'); } if (!decodedState.user) { await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter'); - return res.status(400).send('Invalid user ID in state parameter'); + return res.redirect('/oauth/error?error=invalid_state'); } identifier = `${decodedState.user}:${action_id}`; const flowState = await flowManager.getFlowState(identifier, 'oauth'); @@ -47,90 +50,34 @@ router.get('/:action_id/oauth/callback', async (req, res) => { throw new Error('OAuth flow not found'); } - const tokenData = await getAccessToken({ - code, - userId: decodedState.user, - identifier, - client_url: flowState.metadata.client_url, - redirect_uri: flowState.metadata.redirect_uri, - /** Encrypted values */ - encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id, - encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret, - }); + const tokenData = await getAccessToken( + { + code, + userId: decodedState.user, + identifier, + client_url: flowState.metadata.client_url, + redirect_uri: flowState.metadata.redirect_uri, + token_exchange_method: flowState.metadata.token_exchange_method, + /** Encrypted values */ + encrypted_oauth_client_id: flowState.metadata.encrypted_oauth_client_id, + encrypted_oauth_client_secret: flowState.metadata.encrypted_oauth_client_secret, + }, + { + findToken, + updateToken, + createToken, + }, + ); await flowManager.completeFlow(identifier, 'oauth', tokenData); - res.send(` - - - - Authentication Successful - - - - - -
-

Authentication Successful

-

- Your authentication was successful. This window will close in - 3 seconds. -

-
- - - - `); + + /** Redirect to React success page */ + const serverName = flowState.metadata?.action_name || `Action ${action_id}`; + const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`; + res.redirect(redirectUrl); } catch (error) { logger.error('Error in OAuth callback:', error); await flowManager.failFlow(identifier, 'oauth', error); - res.status(500).send('Authentication failed. Please try again.'); + res.redirect('/oauth/error?error=callback_failed'); } }); diff --git a/api/server/routes/config.js b/api/server/routes/config.js index a53a636d05..dd93037dd9 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,10 +1,11 @@ const express = require('express'); +const { logger } = require('@librechat/data-schemas'); const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider'); +const { getCustomConfig } = require('~/server/services/Config/getCustomConfig'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getProjectByName } = require('~/models/Project'); const { isEnabled } = require('~/server/utils'); const { getLogStores } = require('~/cache'); -const { logger } = require('~/config'); const router = express.Router(); const emailLoginEnabled = @@ -21,6 +22,7 @@ const publicSharedLinksEnabled = router.get('/', async function (req, res) { const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cachedStartupConfig = await cache.get(CacheKeys.STARTUP_CONFIG); if (cachedStartupConfig) { res.send(cachedStartupConfig); @@ -96,6 +98,18 @@ router.get('/', async function (req, res) { bundlerURL: process.env.SANDPACK_BUNDLER_URL, staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL, }; + + payload.mcpServers = {}; + const config = await getCustomConfig(); + if (config?.mcpServers != null) { + for (const serverName in config.mcpServers) { + const serverConfig = config.mcpServers[serverName]; + payload.mcpServers[serverName] = { + customUserVars: serverConfig?.customUserVars || {}, + }; + } + } + /** @type {TCustomConfig['webSearch']} */ const webSearchConfig = req.app.locals.webSearch; if ( diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 87bac6ed29..eb7e2c5c27 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -65,8 +65,14 @@ router.post('/gen_title', async (req, res) => { let title = await titleCache.get(key); if (!title) { - await sleep(2500); - title = await titleCache.get(key); + // Retry every 1s for up to 20s + for (let i = 0; i < 20; i++) { + await sleep(1000); + title = await titleCache.get(key); + if (title) { + break; + } + } } if (title) { diff --git a/api/server/routes/files/multer.js b/api/server/routes/files/multer.js index f23ecd2823..257c309fa2 100644 --- a/api/server/routes/files/multer.js +++ b/api/server/routes/files/multer.js @@ -2,8 +2,8 @@ const fs = require('fs'); const path = require('path'); const crypto = require('crypto'); const multer = require('multer'); +const { sanitizeFilename } = require('@librechat/api'); const { fileConfig: defaultFileConfig, mergeFileConfig } = require('librechat-data-provider'); -const { sanitizeFilename } = require('~/server/utils/handleText'); const { getCustomConfig } = require('~/server/services/Config'); const storage = multer.diskStorage({ diff --git a/api/server/routes/files/multer.spec.js b/api/server/routes/files/multer.spec.js new file mode 100644 index 0000000000..0324262a71 --- /dev/null +++ b/api/server/routes/files/multer.spec.js @@ -0,0 +1,571 @@ +/* eslint-disable no-unused-vars */ +/* eslint-disable jest/no-done-callback */ +const fs = require('fs'); +const os = require('os'); +const path = require('path'); +const crypto = require('crypto'); +const { createMulterInstance, storage, importFileFilter } = require('./multer'); + +// Mock only the config service that requires external dependencies +jest.mock('~/server/services/Config', () => ({ + getCustomConfig: jest.fn(() => + Promise.resolve({ + fileConfig: { + endpoints: { + openAI: { + supportedMimeTypes: ['image/jpeg', 'image/png', 'application/pdf'], + }, + default: { + supportedMimeTypes: ['image/jpeg', 'image/png', 'text/plain'], + }, + }, + serverFileSizeLimit: 10000000, // 10MB + }, + }), + ), +})); + +describe('Multer Configuration', () => { + let tempDir; + let mockReq; + let mockFile; + + beforeEach(() => { + // Create a temporary directory for each test + tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'multer-test-')); + + mockReq = { + user: { id: 'test-user-123' }, + app: { + locals: { + paths: { + uploads: tempDir, + }, + }, + }, + body: {}, + originalUrl: '/api/files/upload', + }; + + mockFile = { + originalname: 'test-file.jpg', + mimetype: 'image/jpeg', + size: 1024, + }; + + // Clear mocks + jest.clearAllMocks(); + }); + + afterEach(() => { + // Clean up temporary directory + if (fs.existsSync(tempDir)) { + fs.rmSync(tempDir, { recursive: true, force: true }); + } + }); + + describe('Storage Configuration', () => { + describe('destination function', () => { + it('should create the correct destination path', (done) => { + const cb = jest.fn((err, destination) => { + expect(err).toBeNull(); + expect(destination).toBe(path.join(tempDir, 'temp', 'test-user-123')); + expect(fs.existsSync(destination)).toBe(true); + done(); + }); + + storage.getDestination(mockReq, mockFile, cb); + }); + + it("should create directory recursively if it doesn't exist", (done) => { + const deepPath = path.join(tempDir, 'deep', 'nested', 'path'); + mockReq.app.locals.paths.uploads = deepPath; + + const cb = jest.fn((err, destination) => { + expect(err).toBeNull(); + expect(destination).toBe(path.join(deepPath, 'temp', 'test-user-123')); + expect(fs.existsSync(destination)).toBe(true); + done(); + }); + + storage.getDestination(mockReq, mockFile, cb); + }); + }); + + describe('filename function', () => { + it('should generate a UUID for req.file_id', (done) => { + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(mockReq.file_id).toBeDefined(); + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i, + ); + done(); + }); + + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should decode URI components in filename', (done) => { + const encodedFile = { + ...mockFile, + originalname: encodeURIComponent('test file with spaces.jpg'), + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(encodedFile.originalname).toBe('test file with spaces.jpg'); + done(); + }); + + storage.getFilename(mockReq, encodedFile, cb); + }); + + it('should call real sanitizeFilename with properly encoded filename', (done) => { + // Test with a properly URI-encoded filename that needs sanitization + const unsafeFile = { + ...mockFile, + originalname: encodeURIComponent('test@#$%^&*()file with spaces!.jpg'), + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + // The actual sanitizeFilename should have cleaned this up after decoding + expect(filename).not.toContain('@'); + expect(filename).not.toContain('#'); + expect(filename).not.toContain('*'); + expect(filename).not.toContain('!'); + // Should still preserve dots and hyphens + expect(filename).toContain('.jpg'); + done(); + }); + + storage.getFilename(mockReq, unsafeFile, cb); + }); + + it('should handle very long filenames with actual crypto', (done) => { + const longFile = { + ...mockFile, + originalname: 'a'.repeat(300) + '.jpg', + }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(filename.length).toBeLessThanOrEqual(255); + expect(filename).toMatch(/\.jpg$/); // Should still end with .jpg + // Should contain a hex suffix if truncated + if (filename.length === 255) { + expect(filename).toMatch(/-[a-f0-9]{6}\.jpg$/); + } + done(); + }); + + storage.getFilename(mockReq, longFile, cb); + }); + + it('should generate unique file_id for each call', (done) => { + let firstFileId; + + const firstCb = jest.fn((err, filename) => { + expect(err).toBeNull(); + firstFileId = mockReq.file_id; + + // Reset req for second call + delete mockReq.file_id; + + const secondCb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(mockReq.file_id).toBeDefined(); + expect(mockReq.file_id).not.toBe(firstFileId); + done(); + }); + + storage.getFilename(mockReq, mockFile, secondCb); + }); + + storage.getFilename(mockReq, mockFile, firstCb); + }); + }); + }); + + describe('Import File Filter', () => { + it('should accept JSON files by mimetype', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'application/json', + originalname: 'data.json', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + + it('should accept files with .json extension', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'data.json', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + + it('should reject non-JSON files', (done) => { + const textFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'document.txt', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeInstanceOf(Error); + expect(err.message).toBe('Only JSON files are allowed'); + expect(result).toBe(false); + done(); + }); + + importFileFilter(mockReq, textFile, cb); + }); + + it('should handle files with uppercase .JSON extension', (done) => { + const jsonFile = { + ...mockFile, + mimetype: 'text/plain', + originalname: 'DATA.JSON', + }; + + const cb = jest.fn((err, result) => { + expect(err).toBeNull(); + expect(result).toBe(true); + done(); + }); + + importFileFilter(mockReq, jsonFile, cb); + }); + }); + + describe('File Filter with Real defaultFileConfig', () => { + it('should use real fileConfig.checkType for validation', async () => { + // Test with actual librechat-data-provider functions + const { + fileConfig, + imageMimeTypes, + applicationMimeTypes, + } = require('librechat-data-provider'); + + // Test that the real checkType function works with regex patterns + expect(fileConfig.checkType('image/jpeg', [imageMimeTypes])).toBe(true); + expect(fileConfig.checkType('video/mp4', [imageMimeTypes])).toBe(false); + expect(fileConfig.checkType('application/pdf', [applicationMimeTypes])).toBe(true); + expect(fileConfig.checkType('application/pdf', [])).toBe(false); + }); + + it('should handle audio files for speech-to-text endpoint with real config', async () => { + mockReq.originalUrl = '/api/speech/stt'; + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + }); + + it('should reject unsupported file types using real config', async () => { + // Mock defaultFileConfig for this specific test + const originalCheckType = require('librechat-data-provider').fileConfig.checkType; + const mockCheckType = jest.fn().mockReturnValue(false); + require('librechat-data-provider').fileConfig.checkType = mockCheckType; + + try { + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + + // Test the actual file filter behavior would reject unsupported files + expect(mockCheckType).toBeDefined(); + } finally { + // Restore original function + require('librechat-data-provider').fileConfig.checkType = originalCheckType; + } + }); + + it('should use real mergeFileConfig function', async () => { + const { mergeFileConfig, mbToBytes } = require('librechat-data-provider'); + + // Test with actual merge function - note that it converts MB to bytes + const testConfig = { + serverFileSizeLimit: 5, // 5 MB + endpoints: { + custom: { + supportedMimeTypes: ['text/plain'], + }, + }, + }; + + const result = mergeFileConfig(testConfig); + + // The function converts MB to bytes, so 5 MB becomes 5 * 1024 * 1024 bytes + expect(result.serverFileSizeLimit).toBe(mbToBytes(5)); + expect(result.endpoints.custom.supportedMimeTypes).toBeDefined(); + // Should still have the default endpoints + expect(result.endpoints.default).toBeDefined(); + }); + }); + + describe('createMulterInstance with Real Functions', () => { + it('should create a multer instance with correct configuration', async () => { + const multerInstance = await createMulterInstance(); + + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + expect(typeof multerInstance.array).toBe('function'); + expect(typeof multerInstance.fields).toBe('function'); + }); + + it('should use real config merging', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + const multerInstance = await createMulterInstance(); + + expect(getCustomConfig).toHaveBeenCalled(); + expect(multerInstance).toBeDefined(); + }); + + it('should create multer instance with expected interface', async () => { + const multerInstance = await createMulterInstance(); + + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + expect(typeof multerInstance.array).toBe('function'); + expect(typeof multerInstance.fields).toBe('function'); + }); + }); + + describe('Real Crypto Integration', () => { + it('should use actual crypto.randomUUID()', (done) => { + // Spy on crypto.randomUUID to ensure it's called + const uuidSpy = jest.spyOn(crypto, 'randomUUID'); + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(uuidSpy).toHaveBeenCalled(); + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i, + ); + + uuidSpy.mockRestore(); + done(); + }); + + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should generate different UUIDs on subsequent calls', (done) => { + const uuids = []; + let callCount = 0; + const totalCalls = 5; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + uuids.push(mockReq.file_id); + callCount++; + + if (callCount === totalCalls) { + // Check that all UUIDs are unique + const uniqueUuids = new Set(uuids); + expect(uniqueUuids.size).toBe(totalCalls); + done(); + } else { + // Reset for next call + delete mockReq.file_id; + storage.getFilename(mockReq, mockFile, cb); + } + }); + + // Start the chain + storage.getFilename(mockReq, mockFile, cb); + }); + + it('should generate cryptographically secure UUIDs', (done) => { + const generatedUuids = new Set(); + let callCount = 0; + const totalCalls = 10; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + + // Verify UUID format and uniqueness + expect(mockReq.file_id).toMatch( + /^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$/i, + ); + + generatedUuids.add(mockReq.file_id); + callCount++; + + if (callCount === totalCalls) { + // All UUIDs should be unique + expect(generatedUuids.size).toBe(totalCalls); + done(); + } else { + // Reset for next call + delete mockReq.file_id; + storage.getFilename(mockReq, mockFile, cb); + } + }); + + // Start the chain + storage.getFilename(mockReq, mockFile, cb); + }); + }); + + describe('Error Handling', () => { + it('should handle CVE-2024-28870: empty field name DoS vulnerability', async () => { + // Test for the CVE where empty field name could cause unhandled exception + const multerInstance = await createMulterInstance(); + + // Create a mock request with empty field name (the vulnerability scenario) + const mockReqWithEmptyField = { + ...mockReq, + headers: { + 'content-type': 'multipart/form-data', + }, + }; + + const mockRes = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + end: jest.fn(), + }; + + // This should not crash or throw unhandled exceptions + const uploadMiddleware = multerInstance.single(''); // Empty field name + + const mockNext = jest.fn((err) => { + // If there's an error, it should be handled gracefully, not crash + if (err) { + expect(err).toBeInstanceOf(Error); + // The error should be handled, not crash the process + } + }); + + // This should complete without crashing the process + expect(() => { + uploadMiddleware(mockReqWithEmptyField, mockRes, mockNext); + }).not.toThrow(); + }); + + it('should handle file system errors when directory creation fails', (done) => { + // Test with a non-existent parent directory to simulate fs issues + const invalidPath = '/nonexistent/path/that/should/not/exist'; + mockReq.app.locals.paths.uploads = invalidPath; + + try { + // Call getDestination which should fail due to permission/path issues + storage.getDestination(mockReq, mockFile, (err, destination) => { + // If callback is reached, we didn't get the expected error + done(new Error('Expected mkdirSync to throw an error but callback was called')); + }); + // If we get here without throwing, something unexpected happened + done(new Error('Expected mkdirSync to throw an error but no error was thrown')); + } catch (error) { + // This is the expected behavior - mkdirSync throws synchronously for invalid paths + expect(error.code).toBe('EACCES'); + done(); + } + }); + + it('should handle malformed filenames with real sanitization', (done) => { + const malformedFile = { + ...mockFile, + originalname: null, // This should be handled gracefully + }; + + const cb = jest.fn((err, filename) => { + // The function should handle this gracefully + expect(typeof err === 'object' || err === null).toBe(true); + done(); + }); + + try { + storage.getFilename(mockReq, malformedFile, cb); + } catch (error) { + // If it throws, that's also acceptable behavior + done(); + } + }); + + it('should handle edge cases in filename sanitization', (done) => { + const edgeCaseFiles = [ + { originalname: '', expected: /_/ }, + { originalname: '.hidden', expected: /^_\.hidden/ }, + { originalname: '../../../etc/passwd', expected: /passwd/ }, + { originalname: 'file\x00name.txt', expected: /file_name\.txt/ }, + ]; + + let testCount = 0; + + const testNextFile = (fileData) => { + const fileToTest = { ...mockFile, originalname: fileData.originalname }; + + const cb = jest.fn((err, filename) => { + expect(err).toBeNull(); + expect(filename).toMatch(fileData.expected); + + testCount++; + if (testCount === edgeCaseFiles.length) { + done(); + } else { + testNextFile(edgeCaseFiles[testCount]); + } + }); + + storage.getFilename(mockReq, fileToTest, cb); + }; + + testNextFile(edgeCaseFiles[0]); + }); + }); + + describe('Real Configuration Testing', () => { + it('should handle missing custom config gracefully with real mergeFileConfig', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + // Mock getCustomConfig to return undefined + getCustomConfig.mockResolvedValueOnce(undefined); + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + expect(typeof multerInstance.single).toBe('function'); + }); + + it('should properly integrate real fileConfig with custom endpoints', async () => { + const { getCustomConfig } = require('~/server/services/Config'); + + // Mock a custom config with additional endpoints + getCustomConfig.mockResolvedValueOnce({ + fileConfig: { + endpoints: { + anthropic: { + supportedMimeTypes: ['text/plain', 'image/png'], + }, + }, + serverFileSizeLimit: 20, // 20 MB + }, + }); + + const multerInstance = await createMulterInstance(); + expect(multerInstance).toBeDefined(); + + // Verify that getCustomConfig was called (we can't spy on the actual merge function easily) + expect(getCustomConfig).toHaveBeenCalled(); + }); + }); +}); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 449759383d..7c1b5de0fa 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -4,6 +4,7 @@ const tokenizer = require('./tokenizer'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); +const memories = require('./memories'); const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); @@ -26,6 +27,7 @@ const edit = require('./edit'); const keys = require('./keys'); const user = require('./user'); const ask = require('./ask'); +const mcp = require('./mcp'); module.exports = { ask, @@ -51,9 +53,11 @@ module.exports = { presets, balance, messages, + memories, endpoints, tokenizer, assistants, categories, staticRoute, + mcp, }; diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js new file mode 100644 index 0000000000..3dfed4d240 --- /dev/null +++ b/api/server/routes/mcp.js @@ -0,0 +1,205 @@ +const { Router } = require('express'); +const { MCPOAuthHandler } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys } = require('librechat-data-provider'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getFlowStateManager } = require('~/config'); +const { getLogStores } = require('~/cache'); + +const router = Router(); + +/** + * Initiate OAuth flow + * This endpoint is called when the user clicks the auth link in the UI + */ +router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { + try { + const { serverName } = req.params; + const { userId, flowId } = req.query; + const user = req.user; + + // Verify the userId matches the authenticated user + if (userId !== user.id) { + return res.status(403).json({ error: 'User mismatch' }); + } + + logger.debug('[MCP OAuth] Initiate request', { serverName, userId, flowId }); + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + /** Flow state to retrieve OAuth config */ + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (!flowState) { + logger.error('[MCP OAuth] Flow state not found', { flowId }); + return res.status(404).json({ error: 'Flow not found' }); + } + + const { serverUrl, oauth: oauthConfig } = flowState.metadata || {}; + if (!serverUrl || !oauthConfig) { + logger.error('[MCP OAuth] Missing server URL or OAuth config in flow state'); + return res.status(400).json({ error: 'Invalid flow state' }); + } + + const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow( + serverName, + serverUrl, + userId, + oauthConfig, + ); + + logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl }); + + // Redirect user to the authorization URL + res.redirect(authorizationUrl); + } catch (error) { + logger.error('[MCP OAuth] Failed to initiate OAuth', error); + res.status(500).json({ error: 'Failed to initiate OAuth' }); + } +}); + +/** + * OAuth callback handler + * This handles the OAuth callback after the user has authorized the application + */ +router.get('/:serverName/oauth/callback', async (req, res) => { + try { + const { serverName } = req.params; + const { code, state, error: oauthError } = req.query; + + logger.debug('[MCP OAuth] Callback received', { + serverName, + code: code ? 'present' : 'missing', + state, + error: oauthError, + }); + + if (oauthError) { + logger.error('[MCP OAuth] OAuth error received', { error: oauthError }); + return res.redirect(`/oauth/error?error=${encodeURIComponent(String(oauthError))}`); + } + + if (!code || typeof code !== 'string') { + logger.error('[MCP OAuth] Missing or invalid code'); + return res.redirect('/oauth/error?error=missing_code'); + } + + if (!state || typeof state !== 'string') { + logger.error('[MCP OAuth] Missing or invalid state'); + return res.redirect('/oauth/error?error=missing_state'); + } + + // Extract flow ID from state + const flowId = state; + logger.debug('[MCP OAuth] Using flow ID from state', { flowId }); + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + logger.debug('[MCP OAuth] Getting flow state for flowId: ' + flowId); + const flowState = await MCPOAuthHandler.getFlowState(flowId, flowManager); + + if (!flowState) { + logger.error('[MCP OAuth] Flow state not found for flowId:', flowId); + return res.redirect('/oauth/error?error=invalid_state'); + } + + logger.debug('[MCP OAuth] Flow state details', { + serverName: flowState.serverName, + userId: flowState.userId, + hasMetadata: !!flowState.metadata, + hasClientInfo: !!flowState.clientInfo, + hasCodeVerifier: !!flowState.codeVerifier, + }); + + // Complete the OAuth flow + logger.debug('[MCP OAuth] Completing OAuth flow'); + const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager); + logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); + + // For system-level OAuth, we need to store the tokens and retry the connection + if (flowState.userId === 'system') { + logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`); + } + + /** ID of the flow that the tool/connection is waiting for */ + const toolFlowId = flowState.metadata?.toolFlowId; + if (toolFlowId) { + logger.debug('[MCP OAuth] Completing tool flow', { toolFlowId }); + await flowManager.completeFlow(toolFlowId, 'mcp_oauth', tokens); + } + + /** Redirect to success page with flowId and serverName */ + const redirectUrl = `/oauth/success?serverName=${encodeURIComponent(serverName)}`; + res.redirect(redirectUrl); + } catch (error) { + logger.error('[MCP OAuth] OAuth callback error', error); + res.redirect('/oauth/error?error=callback_failed'); + } +}); + +/** + * Get OAuth tokens for a completed flow + * This is primarily for user-level OAuth flows + */ +router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => { + try { + const { flowId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + // Allow system flows or user-owned flows + if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) { + return res.status(403).json({ error: 'Access denied' }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (!flowState) { + return res.status(404).json({ error: 'Flow not found' }); + } + + if (flowState.status !== 'COMPLETED') { + return res.status(400).json({ error: 'Flow not completed' }); + } + + res.json({ tokens: flowState.result }); + } catch (error) { + logger.error('[MCP OAuth] Failed to get tokens', error); + res.status(500).json({ error: 'Failed to get tokens' }); + } +}); + +/** + * Check OAuth flow status + * This endpoint can be used to poll the status of an OAuth flow + */ +router.get('/oauth/status/:flowId', async (req, res) => { + try { + const { flowId } = req.params; + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + + const flowState = await flowManager.getFlowState(flowId, 'mcp_oauth'); + if (!flowState) { + return res.status(404).json({ error: 'Flow not found' }); + } + + res.json({ + status: flowState.status, + completed: flowState.status === 'COMPLETED', + failed: flowState.status === 'FAILED', + error: flowState.error, + }); + } catch (error) { + logger.error('[MCP OAuth] Failed to get flow status', error); + res.status(500).json({ error: 'Failed to get flow status' }); + } +}); + +module.exports = router; diff --git a/api/server/routes/memories.js b/api/server/routes/memories.js new file mode 100644 index 0000000000..86065fecaa --- /dev/null +++ b/api/server/routes/memories.js @@ -0,0 +1,231 @@ +const express = require('express'); +const { Tokenizer } = require('@librechat/api'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + getAllUserMemories, + toggleUserMemories, + createMemory, + setMemory, + deleteMemory, +} = require('~/models'); +const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); + +const router = express.Router(); + +const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.READ, +]); +const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.CREATE, +]); +const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.UPDATE, +]); +const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.UPDATE, +]); +const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [ + Permissions.USE, + Permissions.OPT_OUT, +]); + +router.use(requireJwtAuth); + +/** + * GET /memories + * Returns all memories for the authenticated user, sorted by updated_at (newest first). + * Also includes memory usage percentage based on token limit. + */ +router.get('/', checkMemoryRead, async (req, res) => { + try { + const memories = await getAllUserMemories(req.user.id); + + const sortedMemories = memories.sort( + (a, b) => new Date(b.updated_at).getTime() - new Date(a.updated_at).getTime(), + ); + + const totalTokens = memories.reduce((sum, memory) => { + return sum + (memory.tokenCount || 0); + }, 0); + + const memoryConfig = req.app.locals?.memory; + const tokenLimit = memoryConfig?.tokenLimit; + + let usagePercentage = null; + if (tokenLimit && tokenLimit > 0) { + usagePercentage = Math.min(100, Math.round((totalTokens / tokenLimit) * 100)); + } + + res.json({ + memories: sortedMemories, + totalTokens, + tokenLimit: tokenLimit || null, + usagePercentage, + }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * POST /memories + * Creates a new memory entry for the authenticated user. + * Body: { key: string, value: string } + * Returns 201 and { created: true, memory: } when successful. + */ +router.post('/', checkMemoryCreate, async (req, res) => { + const { key, value } = req.body; + + if (typeof key !== 'string' || key.trim() === '') { + return res.status(400).json({ error: 'Key is required and must be a non-empty string.' }); + } + + if (typeof value !== 'string' || value.trim() === '') { + return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); + } + + try { + const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); + + const memories = await getAllUserMemories(req.user.id); + + // Check token limit + const memoryConfig = req.app.locals?.memory; + const tokenLimit = memoryConfig?.tokenLimit; + + if (tokenLimit) { + const currentTotalTokens = memories.reduce( + (sum, memory) => sum + (memory.tokenCount || 0), + 0, + ); + if (currentTotalTokens + tokenCount > tokenLimit) { + return res.status(400).json({ + error: `Adding this memory would exceed the token limit of ${tokenLimit}. Current usage: ${currentTotalTokens} tokens.`, + }); + } + } + + const result = await createMemory({ + userId: req.user.id, + key: key.trim(), + value: value.trim(), + tokenCount, + }); + + if (!result.ok) { + return res.status(500).json({ error: 'Failed to create memory.' }); + } + + const updatedMemories = await getAllUserMemories(req.user.id); + const newMemory = updatedMemories.find((m) => m.key === key.trim()); + + res.status(201).json({ created: true, memory: newMemory }); + } catch (error) { + if (error.message && error.message.includes('already exists')) { + return res.status(409).json({ error: 'Memory with this key already exists.' }); + } + res.status(500).json({ error: error.message }); + } +}); + +/** + * PATCH /memories/preferences + * Updates the user's memory preferences (e.g., enabling/disabling memories). + * Body: { memories: boolean } + * Returns 200 and { updated: true, preferences: { memories: boolean } } when successful. + */ +router.patch('/preferences', checkMemoryOptOut, async (req, res) => { + const { memories } = req.body; + + if (typeof memories !== 'boolean') { + return res.status(400).json({ error: 'memories must be a boolean value.' }); + } + + try { + const updatedUser = await toggleUserMemories(req.user.id, memories); + + if (!updatedUser) { + return res.status(404).json({ error: 'User not found.' }); + } + + res.json({ + updated: true, + preferences: { + memories: updatedUser.personalization?.memories ?? true, + }, + }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * PATCH /memories/:key + * Updates the value of an existing memory entry for the authenticated user. + * Body: { value: string } + * Returns 200 and { updated: true, memory: } when successful. + */ +router.patch('/:key', checkMemoryUpdate, async (req, res) => { + const { key } = req.params; + const { value } = req.body || {}; + + if (typeof value !== 'string' || value.trim() === '') { + return res.status(400).json({ error: 'Value is required and must be a non-empty string.' }); + } + + try { + const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base'); + + const memories = await getAllUserMemories(req.user.id); + const existingMemory = memories.find((m) => m.key === key); + + if (!existingMemory) { + return res.status(404).json({ error: 'Memory not found.' }); + } + + const result = await setMemory({ + userId: req.user.id, + key, + value, + tokenCount, + }); + + if (!result.ok) { + return res.status(500).json({ error: 'Failed to update memory.' }); + } + + const updatedMemories = await getAllUserMemories(req.user.id); + const updatedMemory = updatedMemories.find((m) => m.key === key); + + res.json({ updated: true, memory: updatedMemory }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +/** + * DELETE /memories/:key + * Deletes a memory entry for the authenticated user. + * Returns 200 and { deleted: true } when successful. + */ +router.delete('/:key', checkMemoryDelete, async (req, res) => { + const { key } = req.params; + + try { + const result = await deleteMemory({ userId: req.user.id, key }); + + if (!result.ok) { + return res.status(404).json({ error: 'Memory not found.' }); + } + + res.json({ deleted: true }); + } catch (error) { + res.status(500).json({ error: error.message }); + } +}); + +module.exports = router; diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index bc8d120ef5..afc4a05b75 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -47,7 +47,9 @@ const oauthHandler = async (req, res) => { router.get('/error', (req, res) => { // A single error message is pushed by passport when authentication fails. - logger.error('Error in OAuth authentication:', { message: req.session.messages.pop() }); + logger.error('Error in OAuth authentication:', { + message: req.session?.messages?.pop() || 'Unknown error', + }); // Redirect to login page with auth_failed parameter to prevent infinite redirect loops res.redirect(`${domains.client}/login?redirect=false`); diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js index 17768c7de6..aefbfcec0c 100644 --- a/api/server/routes/roles.js +++ b/api/server/routes/roles.js @@ -1,6 +1,7 @@ const express = require('express'); const { promptPermissionsSchema, + memoryPermissionsSchema, agentPermissionsSchema, PermissionTypes, roleDefaults, @@ -118,4 +119,43 @@ router.put('/:roleName/agents', checkAdmin, async (req, res) => { } }); +/** + * PUT /api/roles/:roleName/memories + * Update memory permissions for a specific role + */ +router.put('/:roleName/memories', checkAdmin, async (req, res) => { + const { roleName: _r } = req.params; + // TODO: TEMP, use a better parsing for roleName + const roleName = _r.toUpperCase(); + /** @type {TRole['permissions']['MEMORIES']} */ + const updates = req.body; + + try { + const parsedUpdates = memoryPermissionsSchema.partial().parse(updates); + + const role = await getRoleByName(roleName); + if (!role) { + return res.status(404).send({ message: 'Role not found' }); + } + + const currentPermissions = + role.permissions?.[PermissionTypes.MEMORIES] || role[PermissionTypes.MEMORIES] || {}; + + const mergedUpdates = { + permissions: { + ...role.permissions, + [PermissionTypes.MEMORIES]: { + ...currentPermissions, + ...parsedUpdates, + }, + }, + }; + + const updatedRole = await updateRoleByName(roleName, mergedUpdates); + res.status(200).send(updatedRole); + } catch (error) { + return res.status(400).send({ message: 'Invalid memory permissions.', error: error.errors }); + } +}); + module.exports = router; diff --git a/api/server/routes/share.js b/api/server/routes/share.js index e551f4a354..14c25271fc 100644 --- a/api/server/routes/share.js +++ b/api/server/routes/share.js @@ -1,15 +1,15 @@ const express = require('express'); - +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { - getSharedLink, getSharedMessages, createSharedLink, updateSharedLink, - getSharedLinks, deleteSharedLink, -} = require('~/models/Share'); + getSharedLinks, + getSharedLink, +} = require('~/models'); const requireJwtAuth = require('~/server/middleware/requireJwtAuth'); -const { isEnabled } = require('~/server/utils'); const router = express.Router(); /** @@ -35,6 +35,7 @@ if (allowSharedLinks) { res.status(404).end(); } } catch (error) { + logger.error('Error getting shared messages:', error); res.status(500).json({ message: 'Error getting shared messages' }); } }, @@ -54,9 +55,7 @@ router.get('/', requireJwtAuth, async (req, res) => { sortDirection: ['asc', 'desc'].includes(req.query.sortDirection) ? req.query.sortDirection : 'desc', - search: req.query.search - ? decodeURIComponent(req.query.search.trim()) - : undefined, + search: req.query.search ? decodeURIComponent(req.query.search.trim()) : undefined, }; const result = await getSharedLinks( @@ -75,7 +74,7 @@ router.get('/', requireJwtAuth, async (req, res) => { hasNextPage: result.hasNextPage, }); } catch (error) { - console.error('Error getting shared links:', error); + logger.error('Error getting shared links:', error); res.status(500).json({ message: 'Error getting shared links', error: error.message, @@ -93,6 +92,7 @@ router.get('/link/:conversationId', requireJwtAuth, async (req, res) => { conversationId: req.params.conversationId, }); } catch (error) { + logger.error('Error getting shared link:', error); res.status(500).json({ message: 'Error getting shared link' }); } }); @@ -106,6 +106,7 @@ router.post('/:conversationId', requireJwtAuth, async (req, res) => { res.status(404).end(); } } catch (error) { + logger.error('Error creating shared link:', error); res.status(500).json({ message: 'Error creating shared link' }); } }); @@ -119,6 +120,7 @@ router.patch('/:shareId', requireJwtAuth, async (req, res) => { res.status(404).end(); } } catch (error) { + logger.error('Error updating shared link:', error); res.status(500).json({ message: 'Error updating shared link' }); } }); @@ -133,7 +135,8 @@ router.delete('/:shareId', requireJwtAuth, async (req, res) => { return res.status(200).json(result); } catch (error) { - return res.status(400).json({ message: error.message }); + logger.error('Error deleting shared link:', error); + return res.status(400).json({ message: 'Error deleting shared link' }); } }); diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index a35c74ad74..b9555a752c 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -1,7 +1,15 @@ const jwt = require('jsonwebtoken'); const { nanoid } = require('nanoid'); const { tool } = require('@langchain/core/tools'); +const { logger } = require('@librechat/data-schemas'); const { GraphEvents, sleep } = require('@librechat/agents'); +const { + sendEvent, + encryptV2, + decryptV2, + logAxiosError, + refreshAccessToken, +} = require('@librechat/api'); const { Time, CacheKeys, @@ -12,14 +20,11 @@ const { isImageVisionTool, actionDomainSeparator, } = require('librechat-data-provider'); -const { refreshAccessToken } = require('~/server/services/TokenService'); -const { logger, getFlowStateManager, sendEvent } = require('~/config'); -const { encryptV2, decryptV2 } = require('~/server/utils/crypto'); +const { findToken, updateToken, createToken } = require('~/models'); const { getActions, deleteActions } = require('~/models/Action'); const { deleteAssistant } = require('~/models/Assistant'); -const { logAxiosError } = require('~/utils'); +const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); -const { findToken } = require('~/models'); const JWT_SECRET = process.env.JWT_SECRET; const toolNameRegex = /^[a-zA-Z0-9_-]+$/; @@ -208,6 +213,7 @@ async function createActionTool({ userId: userId, client_url: metadata.auth.client_url, redirect_uri: `${process.env.DOMAIN_SERVER}/api/actions/${action_id}/oauth/callback`, + token_exchange_method: metadata.auth.token_exchange_method, /** Encrypted values */ encrypted_oauth_client_id: encrypted.oauth_client_id, encrypted_oauth_client_secret: encrypted.oauth_client_secret, @@ -256,14 +262,22 @@ async function createActionTool({ try { const refresh_token = await decryptV2(refreshTokenData.token); const refreshTokens = async () => - await refreshAccessToken({ - userId, - identifier, - refresh_token, - client_url: metadata.auth.client_url, - encrypted_oauth_client_id: encrypted.oauth_client_id, - encrypted_oauth_client_secret: encrypted.oauth_client_secret, - }); + await refreshAccessToken( + { + userId, + identifier, + refresh_token, + client_url: metadata.auth.client_url, + encrypted_oauth_client_id: encrypted.oauth_client_id, + token_exchange_method: metadata.auth.token_exchange_method, + encrypted_oauth_client_secret: encrypted.oauth_client_secret, + }, + { + findToken, + updateToken, + createToken, + }, + ); const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); const refreshData = await flowManager.createFlowWithHandler( diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index 4bb8c51d00..6b7ff7417f 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -1,11 +1,12 @@ const { FileSources, loadOCRConfig, - processMCPEnv, EModelEndpoint, + loadMemoryConfig, getConfigDefaults, loadWebSearchConfig, } = require('librechat-data-provider'); +const { agentsConfigSetup } = require('@librechat/api'); const { checkHealth, checkConfig, @@ -24,10 +25,9 @@ const { azureConfigSetup } = require('./start/azureOpenAI'); const { processModelSpecs } = require('./start/modelSpecs'); const { initializeS3 } = require('./Files/S3/initialize'); const { loadAndFormatTools } = require('./ToolService'); -const { agentsConfigSetup } = require('./start/agents'); const { isEnabled } = require('~/server/utils'); const { initializeRoles } = require('~/models'); -const { getMCPManager } = require('~/config'); +const { setCachedTools } = require('./Config'); const paths = require('~/config/paths'); /** @@ -44,6 +44,7 @@ const AppService = async (app) => { const ocr = loadOCRConfig(config.ocr); const webSearch = loadWebSearchConfig(config.webSearch); checkWebSearchConfig(webSearch); + const memory = loadMemoryConfig(config.memory); const filteredTools = config.filteredTools; const includedTools = config.includedTools; const fileStrategy = config.fileStrategy ?? configDefaults.fileStrategy; @@ -74,11 +75,10 @@ const AppService = async (app) => { directory: paths.structuredTools, }); - if (config.mcpServers != null) { - const mcpManager = getMCPManager(); - await mcpManager.initializeMCP(config.mcpServers, processMCPEnv); - await mcpManager.mapAvailableTools(availableTools); - } + await setCachedTools(availableTools, { isGlobal: true }); + + // Store MCP config for later initialization + const mcpConfig = config.mcpServers || null; const socialLogins = config?.registration?.socialLogins ?? configDefaults?.registration?.socialLogins; @@ -88,20 +88,26 @@ const AppService = async (app) => { const defaultLocals = { ocr, paths, + memory, webSearch, fileStrategy, socialLogins, filteredTools, includedTools, - availableTools, imageOutputType, interfaceConfig, turnstileConfig, balance, + mcpConfig, }; + const agentsDefaults = agentsConfigSetup(config); + if (!Object.keys(config).length) { - app.locals = defaultLocals; + app.locals = { + ...defaultLocals, + [EModelEndpoint.agents]: agentsDefaults, + }; return; } @@ -136,9 +142,7 @@ const AppService = async (app) => { ); } - if (endpoints?.[EModelEndpoint.agents]) { - endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config); - } + endpointLocals[EModelEndpoint.agents] = agentsConfigSetup(config, agentsDefaults); const endpointKeys = [ EModelEndpoint.openAI, diff --git a/api/server/services/AppService.spec.js b/api/server/services/AppService.spec.js index 5e0837ce3b..7edccc2c0d 100644 --- a/api/server/services/AppService.spec.js +++ b/api/server/services/AppService.spec.js @@ -2,8 +2,10 @@ const { FileSources, EModelEndpoint, EImageOutputType, + AgentCapabilities, defaultSocialLogins, validateAzureGroups, + defaultAgentCapabilities, deprecatedAzureVariables, conflictingAzureVariables, } = require('librechat-data-provider'); @@ -30,6 +32,25 @@ jest.mock('~/models', () => ({ jest.mock('~/models/Role', () => ({ updateAccessPermissions: jest.fn(), })); +jest.mock('./Config', () => ({ + setCachedTools: jest.fn(), + getCachedTools: jest.fn().mockResolvedValue({ + ExampleTool: { + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + required: ['param1'], + }, + }, + }, + }), +})); jest.mock('./ToolService', () => ({ loadAndFormatTools: jest.fn().mockReturnValue({ ExampleTool: { @@ -119,22 +140,9 @@ describe('AppService', () => { sidePanel: true, presets: true, }), + mcpConfig: null, turnstileConfig: mockedTurnstileConfig, modelSpecs: undefined, - availableTools: { - ExampleTool: { - type: 'function', - function: expect.objectContaining({ - description: 'Example tool function', - name: 'exampleFunction', - parameters: expect.objectContaining({ - type: 'object', - properties: expect.any(Object), - required: expect.arrayContaining(['param1']), - }), - }), - }, - }, paths: expect.anything(), ocr: expect.anything(), imageOutputType: expect.any(String), @@ -151,6 +159,11 @@ describe('AppService', () => { safeSearch: 1, serperApiKey: '${SERPER_API_KEY}', }, + memory: undefined, + agents: { + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }, }); }); @@ -216,14 +229,41 @@ describe('AppService', () => { it('should load and format tools accurately with defined structure', async () => { const { loadAndFormatTools } = require('./ToolService'); + const { setCachedTools, getCachedTools } = require('./Config'); + await AppService(app); expect(loadAndFormatTools).toHaveBeenCalledWith({ + adminFilter: undefined, + adminIncluded: undefined, directory: expect.anything(), }); - expect(app.locals.availableTools.ExampleTool).toBeDefined(); - expect(app.locals.availableTools.ExampleTool).toEqual({ + // Verify setCachedTools was called with the tools + expect(setCachedTools).toHaveBeenCalledWith( + { + ExampleTool: { + type: 'function', + function: { + description: 'Example tool function', + name: 'exampleFunction', + parameters: { + type: 'object', + properties: { + param1: { type: 'string', description: 'An example parameter' }, + }, + required: ['param1'], + }, + }, + }, + }, + { isGlobal: true }, + ); + + // Verify we can retrieve the tools from cache + const cachedTools = await getCachedTools({ includeGlobal: true }); + expect(cachedTools.ExampleTool).toBeDefined(); + expect(cachedTools.ExampleTool).toEqual({ type: 'function', function: { description: 'Example tool function', @@ -268,6 +308,71 @@ describe('AppService', () => { ); }); + it('should correctly configure Agents endpoint based on custom config', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.agents]: { + disableBuilder: true, + recursionLimit: 10, + maxRecursionLimit: 20, + allowedProviders: ['openai', 'anthropic'], + capabilities: [AgentCapabilities.tools, AgentCapabilities.actions], + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: true, + recursionLimit: 10, + maxRecursionLimit: 20, + allowedProviders: expect.arrayContaining(['openai', 'anthropic']), + capabilities: expect.arrayContaining([AgentCapabilities.tools, AgentCapabilities.actions]), + }), + ); + }); + + it('should configure Agents endpoint with defaults when no config is provided', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve({})); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }), + ); + }); + + it('should configure Agents endpoint with defaults when endpoints exist but agents is not defined', async () => { + require('./Config/loadCustomConfig').mockImplementationOnce(() => + Promise.resolve({ + endpoints: { + [EModelEndpoint.openAI]: { + titleConvo: true, + }, + }, + }), + ); + + await AppService(app); + + expect(app.locals).toHaveProperty(EModelEndpoint.agents); + expect(app.locals[EModelEndpoint.agents]).toEqual( + expect.objectContaining({ + disableBuilder: false, + capabilities: expect.arrayContaining([...defaultAgentCapabilities]), + }), + ); + }); + it('should correctly configure minimum Azure OpenAI Assistant values', async () => { const assistantGroups = [azureGroups[0], { ...azureGroups[1], assistants: true }]; require('./Config/loadCustomConfig').mockImplementationOnce(() => @@ -463,7 +568,6 @@ describe('AppService updating app.locals and issuing warnings', () => { expect(app.locals).toBeDefined(); expect(app.locals.paths).toBeDefined(); - expect(app.locals.availableTools).toBeDefined(); expect(app.locals.fileStrategy).toEqual(FileSources.local); expect(app.locals.socialLogins).toEqual(defaultSocialLogins); expect(app.locals.balance).toEqual( @@ -496,7 +600,6 @@ describe('AppService updating app.locals and issuing warnings', () => { expect(app.locals).toBeDefined(); expect(app.locals.paths).toBeDefined(); - expect(app.locals.availableTools).toBeDefined(); expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy); expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins); expect(app.locals.balance).toEqual(customConfig.balance); diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index 11b37ac886..6061277437 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -1,5 +1,7 @@ const bcrypt = require('bcryptjs'); const { webcrypto } = require('node:crypto'); +const { isEnabled } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { SystemRoles, errorsToString } = require('librechat-data-provider'); const { findUser, @@ -17,11 +19,10 @@ const { deleteUserById, generateRefreshToken, } = require('~/models'); -const { isEnabled, checkEmailConfig, sendEmail } = require('~/server/utils'); const { isEmailDomainAllowed } = require('~/server/services/domains'); +const { checkEmailConfig, sendEmail } = require('~/server/utils'); const { getBalanceConfig } = require('~/server/services/Config'); const { registerSchema } = require('~/strategies/validators'); -const { logger } = require('~/config'); const domains = { client: process.env.DOMAIN_CLIENT, @@ -409,7 +410,9 @@ const setOpenIDAuthTokens = (tokenset, res) => { return; } const { REFRESH_TOKEN_EXPIRY } = process.env ?? {}; - const expiryInMilliseconds = eval(REFRESH_TOKEN_EXPIRY) ?? 1000 * 60 * 60 * 24 * 7; // 7 days default + const expiryInMilliseconds = REFRESH_TOKEN_EXPIRY + ? eval(REFRESH_TOKEN_EXPIRY) + : 1000 * 60 * 60 * 24 * 7; // 7 days default const expirationDate = new Date(Date.now() + expiryInMilliseconds); if (tokenset == null) { logger.error('[setOpenIDAuthTokens] No tokenset found in request'); diff --git a/api/server/services/Config/getCachedTools.js b/api/server/services/Config/getCachedTools.js new file mode 100644 index 0000000000..b3a4f0c869 --- /dev/null +++ b/api/server/services/Config/getCachedTools.js @@ -0,0 +1,258 @@ +const { CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); + +/** + * Cache key generators for different tool access patterns + * These will support future permission-based caching + */ +const ToolCacheKeys = { + /** Global tools available to all users */ + GLOBAL: 'tools:global', + /** Tools available to a specific user */ + USER: (userId) => `tools:user:${userId}`, + /** Tools available to a specific role */ + ROLE: (roleId) => `tools:role:${roleId}`, + /** Tools available to a specific group */ + GROUP: (groupId) => `tools:group:${groupId}`, + /** Combined effective tools for a user (computed from all sources) */ + EFFECTIVE: (userId) => `tools:effective:${userId}`, +}; + +/** + * Retrieves available tools from cache + * @function getCachedTools + * @param {Object} options - Options for retrieving tools + * @param {string} [options.userId] - User ID for user-specific tools + * @param {string[]} [options.roleIds] - Role IDs for role-based tools + * @param {string[]} [options.groupIds] - Group IDs for group-based tools + * @param {boolean} [options.includeGlobal=true] - Whether to include global tools + * @returns {Promise} The available tools object or null if not cached + */ +async function getCachedTools(options = {}) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const { userId, roleIds = [], groupIds = [], includeGlobal = true } = options; + + // For now, return global tools (current behavior) + // This will be expanded to merge tools from different sources + if (!userId && includeGlobal) { + return await cache.get(ToolCacheKeys.GLOBAL); + } + + // Future implementation will merge tools from multiple sources + // based on user permissions, roles, and groups + if (userId) { + // Check if we have pre-computed effective tools for this user + const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId)); + if (effectiveTools) { + return effectiveTools; + } + + // Otherwise, compute from individual sources + const toolSources = []; + + if (includeGlobal) { + const globalTools = await cache.get(ToolCacheKeys.GLOBAL); + if (globalTools) { + toolSources.push(globalTools); + } + } + + // User-specific tools + const userTools = await cache.get(ToolCacheKeys.USER(userId)); + if (userTools) { + toolSources.push(userTools); + } + + // Role-based tools + for (const roleId of roleIds) { + const roleTools = await cache.get(ToolCacheKeys.ROLE(roleId)); + if (roleTools) { + toolSources.push(roleTools); + } + } + + // Group-based tools + for (const groupId of groupIds) { + const groupTools = await cache.get(ToolCacheKeys.GROUP(groupId)); + if (groupTools) { + toolSources.push(groupTools); + } + } + + // Merge all tool sources (for now, simple merge - future will handle conflicts) + if (toolSources.length > 0) { + return mergeToolSources(toolSources); + } + } + + return null; +} + +/** + * Sets available tools in cache + * @function setCachedTools + * @param {Object} tools - The tools object to cache + * @param {Object} options - Options for caching tools + * @param {string} [options.userId] - User ID for user-specific tools + * @param {string} [options.roleId] - Role ID for role-based tools + * @param {string} [options.groupId] - Group ID for group-based tools + * @param {boolean} [options.isGlobal=false] - Whether these are global tools + * @param {number} [options.ttl] - Time to live in milliseconds + * @returns {Promise} Whether the operation was successful + */ +async function setCachedTools(tools, options = {}) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const { userId, roleId, groupId, isGlobal = false, ttl } = options; + + let cacheKey; + if (isGlobal || (!userId && !roleId && !groupId)) { + cacheKey = ToolCacheKeys.GLOBAL; + } else if (userId) { + cacheKey = ToolCacheKeys.USER(userId); + } else if (roleId) { + cacheKey = ToolCacheKeys.ROLE(roleId); + } else if (groupId) { + cacheKey = ToolCacheKeys.GROUP(groupId); + } + + if (!cacheKey) { + throw new Error('Invalid cache key options provided'); + } + + return await cache.set(cacheKey, tools, ttl); +} + +/** + * Invalidates cached tools + * @function invalidateCachedTools + * @param {Object} options - Options for invalidating tools + * @param {string} [options.userId] - User ID to invalidate + * @param {string} [options.roleId] - Role ID to invalidate + * @param {string} [options.groupId] - Group ID to invalidate + * @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools + * @param {boolean} [options.invalidateEffective=true] - Whether to invalidate effective tools + * @returns {Promise} + */ +async function invalidateCachedTools(options = {}) { + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const { userId, roleId, groupId, invalidateGlobal = false, invalidateEffective = true } = options; + + const keysToDelete = []; + + if (invalidateGlobal) { + keysToDelete.push(ToolCacheKeys.GLOBAL); + } + + if (userId) { + keysToDelete.push(ToolCacheKeys.USER(userId)); + if (invalidateEffective) { + keysToDelete.push(ToolCacheKeys.EFFECTIVE(userId)); + } + } + + if (roleId) { + keysToDelete.push(ToolCacheKeys.ROLE(roleId)); + // TODO: In future, invalidate all users with this role + } + + if (groupId) { + keysToDelete.push(ToolCacheKeys.GROUP(groupId)); + // TODO: In future, invalidate all users in this group + } + + await Promise.all(keysToDelete.map((key) => cache.delete(key))); +} + +/** + * Computes and caches effective tools for a user + * @function computeEffectiveTools + * @param {string} userId - The user ID + * @param {Object} context - Context containing user's roles and groups + * @param {string[]} [context.roleIds=[]] - User's role IDs + * @param {string[]} [context.groupIds=[]] - User's group IDs + * @param {number} [ttl] - Time to live for the computed result + * @returns {Promise} The computed effective tools + */ +async function computeEffectiveTools(userId, context = {}, ttl) { + const { roleIds = [], groupIds = [] } = context; + + // Get all tool sources + const tools = await getCachedTools({ + userId, + roleIds, + groupIds, + includeGlobal: true, + }); + + if (tools) { + // Cache the computed result + const cache = getLogStores(CacheKeys.CONFIG_STORE); + await cache.set(ToolCacheKeys.EFFECTIVE(userId), tools, ttl); + } + + return tools; +} + +/** + * Merges multiple tool sources into a single tools object + * @function mergeToolSources + * @param {Object[]} sources - Array of tool objects to merge + * @returns {Object} Merged tools object + */ +function mergeToolSources(sources) { + // For now, simple merge that combines all tools + // Future implementation will handle: + // - Permission precedence (deny > allow) + // - Tool property conflicts + // - Metadata merging + const merged = {}; + + for (const source of sources) { + if (!source || typeof source !== 'object') { + continue; + } + + for (const [toolId, toolConfig] of Object.entries(source)) { + // Simple last-write-wins for now + // Future: merge based on permission levels + merged[toolId] = toolConfig; + } + } + + return merged; +} + +/** + * Middleware-friendly function to get tools for a request + * @function getToolsForRequest + * @param {Object} req - Express request object + * @returns {Promise} Available tools for the request + */ +async function getToolsForRequest(req) { + const userId = req.user?.id; + + // For now, return global tools if no user + if (!userId) { + return getCachedTools({ includeGlobal: true }); + } + + // Future: Extract roles and groups from req.user + const roleIds = req.user?.roles || []; + const groupIds = req.user?.groups || []; + + return getCachedTools({ + userId, + roleIds, + groupIds, + includeGlobal: true, + }); +} + +module.exports = { + ToolCacheKeys, + getCachedTools, + setCachedTools, + getToolsForRequest, + invalidateCachedTools, + computeEffectiveTools, +}; diff --git a/api/server/services/Config/getCustomConfig.js b/api/server/services/Config/getCustomConfig.js index 74828789fc..d1ee5c3278 100644 --- a/api/server/services/Config/getCustomConfig.js +++ b/api/server/services/Config/getCustomConfig.js @@ -1,6 +1,10 @@ +const { logger } = require('@librechat/data-schemas'); +const { getUserMCPAuthMap } = require('@librechat/api'); const { CacheKeys, EModelEndpoint } = require('librechat-data-provider'); const { normalizeEndpointName, isEnabled } = require('~/server/utils'); const loadCustomConfig = require('./loadCustomConfig'); +const { getCachedTools } = require('./getCachedTools'); +const { findPluginAuthsByKeys } = require('~/models'); const getLogStores = require('~/cache/getLogStores'); /** @@ -50,4 +54,46 @@ const getCustomEndpointConfig = async (endpoint) => { ); }; -module.exports = { getCustomConfig, getBalanceConfig, getCustomEndpointConfig }; +async function createGetMCPAuthMap() { + const customConfig = await getCustomConfig(); + const mcpServers = customConfig?.mcpServers; + const hasCustomUserVars = Object.values(mcpServers ?? {}).some((server) => server.customUserVars); + if (!hasCustomUserVars) { + return; + } + + /** + * @param {Object} params + * @param {GenericTool[]} [params.tools] + * @param {string} params.userId + * @returns {Promise> | undefined>} + */ + return async function ({ tools, userId }) { + try { + if (!tools || tools.length === 0) { + return; + } + const appTools = await getCachedTools({ + userId, + }); + return await getUserMCPAuthMap({ + tools, + userId, + appTools, + findPluginAuthsByKeys, + }); + } catch (err) { + logger.error( + `[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`, + err, + ); + } + }; +} + +module.exports = { + getCustomConfig, + getBalanceConfig, + createGetMCPAuthMap, + getCustomEndpointConfig, +}; diff --git a/api/server/services/Config/index.js b/api/server/services/Config/index.js index 9d668da958..ad25e57998 100644 --- a/api/server/services/Config/index.js +++ b/api/server/services/Config/index.js @@ -1,4 +1,5 @@ const { config } = require('./EndpointService'); +const getCachedTools = require('./getCachedTools'); const getCustomConfig = require('./getCustomConfig'); const loadCustomConfig = require('./loadCustomConfig'); const loadConfigModels = require('./loadConfigModels'); @@ -14,6 +15,7 @@ module.exports = { loadDefaultModels, loadOverrideConfig, loadAsyncEndpoints, + ...getCachedTools, ...getCustomConfig, ...getEndpointsConfig, }; diff --git a/api/server/services/Endpoints/agents/agent.js b/api/server/services/Endpoints/agents/agent.js new file mode 100644 index 0000000000..e135401467 --- /dev/null +++ b/api/server/services/Endpoints/agents/agent.js @@ -0,0 +1,196 @@ +const { Providers } = require('@librechat/agents'); +const { primeResources, optionalChainWithEmptyCheck } = require('@librechat/api'); +const { + ErrorTypes, + EModelEndpoint, + EToolResources, + replaceSpecialVars, + providerEndpointMap, +} = require('librechat-data-provider'); +const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); +const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); +const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); +const initCustom = require('~/server/services/Endpoints/custom/initialize'); +const initGoogle = require('~/server/services/Endpoints/google/initialize'); +const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); +const { getCustomEndpointConfig } = require('~/server/services/Config'); +const { processFiles } = require('~/server/services/Files/process'); +const { getConvoFiles } = require('~/models/Conversation'); +const { getToolFilesByIds } = require('~/models/File'); +const { getModelMaxTokens } = require('~/utils'); +const { getFiles } = require('~/models/File'); + +const providerConfigMap = { + [Providers.XAI]: initCustom, + [Providers.OLLAMA]: initCustom, + [Providers.DEEPSEEK]: initCustom, + [Providers.OPENROUTER]: initCustom, + [EModelEndpoint.openAI]: initOpenAI, + [EModelEndpoint.google]: initGoogle, + [EModelEndpoint.azureOpenAI]: initOpenAI, + [EModelEndpoint.anthropic]: initAnthropic, + [EModelEndpoint.bedrock]: getBedrockOptions, +}; + +/** + * @param {object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {Agent} params.agent + * @param {string | null} [params.conversationId] + * @param {Array} [params.requestFiles] + * @param {typeof import('~/server/services/ToolService').loadAgentTools | undefined} [params.loadTools] + * @param {TEndpointOption} [params.endpointOption] + * @param {Set} [params.allowedProviders] + * @param {boolean} [params.isInitialAgent] + * @returns {Promise, toolContextMap: Record, maxContextTokens: number }>} + */ +const initializeAgent = async ({ + req, + res, + agent, + loadTools, + requestFiles, + conversationId, + endpointOption, + allowedProviders, + isInitialAgent = false, +}) => { + if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) { + throw new Error( + `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`, + ); + } + let currentFiles; + + const _modelOptions = structuredClone( + Object.assign( + { model: agent.model }, + agent.model_parameters ?? { model: agent.model }, + isInitialAgent === true ? endpointOption?.model_parameters : {}, + ), + ); + + const { resendFiles = true, ...modelOptions } = _modelOptions; + + if (isInitialAgent && conversationId != null && resendFiles) { + const fileIds = (await getConvoFiles(conversationId)) ?? []; + /** @type {Set} */ + const toolResourceSet = new Set(); + for (const tool of agent.tools) { + if (EToolResources[tool]) { + toolResourceSet.add(EToolResources[tool]); + } + } + const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet); + if (requestFiles.length || toolFiles.length) { + currentFiles = await processFiles(requestFiles.concat(toolFiles)); + } + } else if (isInitialAgent && requestFiles.length) { + currentFiles = await processFiles(requestFiles); + } + + const { attachments, tool_resources } = await primeResources({ + req, + getFiles, + attachments: currentFiles, + tool_resources: agent.tool_resources, + requestFileSet: new Set(requestFiles?.map((file) => file.file_id)), + }); + + const provider = agent.provider; + const { tools, toolContextMap } = + (await loadTools?.({ + req, + res, + provider, + agentId: agent.id, + tools: agent.tools, + model: agent.model, + tool_resources, + })) ?? {}; + + agent.endpoint = provider; + let getOptions = providerConfigMap[provider]; + if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { + agent.provider = provider.toLowerCase(); + getOptions = providerConfigMap[agent.provider]; + } else if (!getOptions) { + const customEndpointConfig = await getCustomEndpointConfig(provider); + if (!customEndpointConfig) { + throw new Error(`Provider ${provider} not supported`); + } + getOptions = initCustom; + agent.provider = Providers.OPENAI; + } + + const _endpointOption = + isInitialAgent === true + ? Object.assign({}, endpointOption, { model_parameters: modelOptions }) + : { model_parameters: modelOptions }; + + const options = await getOptions({ + req, + res, + optionsOnly: true, + overrideEndpoint: provider, + overrideModel: agent.model, + endpointOption: _endpointOption, + }); + + const tokensModel = + agent.provider === EModelEndpoint.azureOpenAI ? agent.model : modelOptions.model; + const maxTokens = optionalChainWithEmptyCheck( + modelOptions.maxOutputTokens, + modelOptions.maxTokens, + 0, + ); + const maxContextTokens = optionalChainWithEmptyCheck( + modelOptions.maxContextTokens, + modelOptions.max_context_tokens, + getModelMaxTokens(tokensModel, providerEndpointMap[provider]), + 4096, + ); + + if ( + agent.endpoint === EModelEndpoint.azureOpenAI && + options.llmConfig?.azureOpenAIApiInstanceName == null + ) { + agent.provider = Providers.OPENAI; + } + + if (options.provider != null) { + agent.provider = options.provider; + } + + /** @type {import('@librechat/agents').ClientOptions} */ + agent.model_parameters = { ...options.llmConfig }; + if (options.configOptions) { + agent.model_parameters.configuration = options.configOptions; + } + + if (agent.instructions && agent.instructions !== '') { + agent.instructions = replaceSpecialVars({ + text: agent.instructions, + user: req.user, + }); + } + + if (typeof agent.artifacts === 'string' && agent.artifacts !== '') { + agent.additional_instructions = generateArtifactsPrompt({ + endpoint: agent.provider, + artifacts: agent.artifacts, + }); + } + + return { + ...agent, + tools, + attachments, + resendFiles, + toolContextMap, + maxContextTokens: (maxContextTokens - maxTokens) * 0.9, + }; +}; + +module.exports = { initializeAgent }; diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index c9e363e815..e4ffcf4730 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -1,294 +1,41 @@ -const { createContentAggregator, Providers } = require('@librechat/agents'); -const { - Constants, - ErrorTypes, - EModelEndpoint, - EToolResources, - getResponseSender, - AgentCapabilities, - replaceSpecialVars, - providerEndpointMap, -} = require('librechat-data-provider'); +const { logger } = require('@librechat/data-schemas'); +const { createContentAggregator } = require('@librechat/agents'); +const { Constants, EModelEndpoint, getResponseSender } = require('librechat-data-provider'); const { getDefaultHandlers, createToolEndCallback, } = require('~/server/controllers/agents/callbacks'); -const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize'); -const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options'); -const initOpenAI = require('~/server/services/Endpoints/openAI/initialize'); -const initCustom = require('~/server/services/Endpoints/custom/initialize'); -const initGoogle = require('~/server/services/Endpoints/google/initialize'); -const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); -const { getCustomEndpointConfig } = require('~/server/services/Config'); -const { processFiles } = require('~/server/services/Files/process'); +const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { loadAgentTools } = require('~/server/services/ToolService'); const AgentClient = require('~/server/controllers/agents/client'); -const { getConvoFiles } = require('~/models/Conversation'); -const { getToolFilesByIds } = require('~/models/File'); -const { getModelMaxTokens } = require('~/utils'); const { getAgent } = require('~/models/Agent'); -const { getFiles } = require('~/models/File'); -const { logger } = require('~/config'); -const providerConfigMap = { - [Providers.XAI]: initCustom, - [Providers.OLLAMA]: initCustom, - [Providers.DEEPSEEK]: initCustom, - [Providers.OPENROUTER]: initCustom, - [EModelEndpoint.openAI]: initOpenAI, - [EModelEndpoint.google]: initGoogle, - [EModelEndpoint.azureOpenAI]: initOpenAI, - [EModelEndpoint.anthropic]: initAnthropic, - [EModelEndpoint.bedrock]: getBedrockOptions, -}; - -/** - * @param {Object} params - * @param {ServerRequest} params.req - * @param {Promise> | undefined} [params.attachments] - * @param {Set} params.requestFileSet - * @param {AgentToolResources | undefined} [params.tool_resources] - * @returns {Promise<{ attachments: Array | undefined, tool_resources: AgentToolResources | undefined }>} - */ -const primeResources = async ({ - req, - attachments: _attachments, - tool_resources: _tool_resources, - requestFileSet, -}) => { - try { - /** @type {Array | undefined} */ - let attachments; - const tool_resources = _tool_resources ?? {}; - const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes( - AgentCapabilities.ocr, - ); - if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) { - const context = await getFiles( - { - file_id: { $in: tool_resources.ocr.file_ids }, - }, - {}, - {}, - ); - attachments = (attachments ?? []).concat(context); +function createToolLoader() { + /** + * @param {object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {string} params.agentId + * @param {string[]} params.tools + * @param {string} params.provider + * @param {string} params.model + * @param {AgentToolResources} params.tool_resources + * @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record } | undefined>} + */ + return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) { + const agent = { id: agentId, tools, provider, model }; + try { + return await loadAgentTools({ + req, + res, + agent, + tool_resources, + }); + } catch (error) { + logger.error('Error loading tools for agent ' + agentId, error); } - if (!_attachments) { - return { attachments, tool_resources }; - } - /** @type {Array | undefined} */ - const files = await _attachments; - if (!attachments) { - /** @type {Array} */ - attachments = []; - } - - for (const file of files) { - if (!file) { - continue; - } - if (file.metadata?.fileIdentifier) { - const execute_code = tool_resources[EToolResources.execute_code] ?? {}; - if (!execute_code.files) { - tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] }; - } - tool_resources[EToolResources.execute_code].files.push(file); - } else if (file.embedded === true) { - const file_search = tool_resources[EToolResources.file_search] ?? {}; - if (!file_search.files) { - tool_resources[EToolResources.file_search] = { ...file_search, files: [] }; - } - tool_resources[EToolResources.file_search].files.push(file); - } else if ( - requestFileSet.has(file.file_id) && - file.type.startsWith('image') && - file.height && - file.width - ) { - const image_edit = tool_resources[EToolResources.image_edit] ?? {}; - if (!image_edit.files) { - tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] }; - } - tool_resources[EToolResources.image_edit].files.push(file); - } - - attachments.push(file); - } - return { attachments, tool_resources }; - } catch (error) { - logger.error('Error priming resources', error); - return { attachments: _attachments, tool_resources: _tool_resources }; - } -}; - -/** - * @param {...string | number} values - * @returns {string | number | undefined} - */ -function optionalChainWithEmptyCheck(...values) { - for (const value of values) { - if (value !== undefined && value !== null && value !== '') { - return value; - } - } - return values[values.length - 1]; -} - -/** - * @param {object} params - * @param {ServerRequest} params.req - * @param {ServerResponse} params.res - * @param {Agent} params.agent - * @param {Set} [params.allowedProviders] - * @param {object} [params.endpointOption] - * @param {boolean} [params.isInitialAgent] - * @returns {Promise} - */ -const initializeAgentOptions = async ({ - req, - res, - agent, - endpointOption, - allowedProviders, - isInitialAgent = false, -}) => { - if (allowedProviders.size > 0 && !allowedProviders.has(agent.provider)) { - throw new Error( - `{ "type": "${ErrorTypes.INVALID_AGENT_PROVIDER}", "info": "${agent.provider}" }`, - ); - } - let currentFiles; - /** @type {Array} */ - const requestFiles = req.body.files ?? []; - if ( - isInitialAgent && - req.body.conversationId != null && - (agent.model_parameters?.resendFiles ?? true) === true - ) { - const fileIds = (await getConvoFiles(req.body.conversationId)) ?? []; - /** @type {Set} */ - const toolResourceSet = new Set(); - for (const tool of agent.tools) { - if (EToolResources[tool]) { - toolResourceSet.add(EToolResources[tool]); - } - } - const toolFiles = await getToolFilesByIds(fileIds, toolResourceSet); - if (requestFiles.length || toolFiles.length) { - currentFiles = await processFiles(requestFiles.concat(toolFiles)); - } - } else if (isInitialAgent && requestFiles.length) { - currentFiles = await processFiles(requestFiles); - } - - const { attachments, tool_resources } = await primeResources({ - req, - attachments: currentFiles, - tool_resources: agent.tool_resources, - requestFileSet: new Set(requestFiles.map((file) => file.file_id)), - }); - - const provider = agent.provider; - const { tools, toolContextMap } = await loadAgentTools({ - req, - res, - agent: { - id: agent.id, - tools: agent.tools, - provider, - model: agent.model, - }, - tool_resources, - }); - - agent.endpoint = provider; - let getOptions = providerConfigMap[provider]; - if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) { - agent.provider = provider.toLowerCase(); - getOptions = providerConfigMap[agent.provider]; - } else if (!getOptions) { - const customEndpointConfig = await getCustomEndpointConfig(provider); - if (!customEndpointConfig) { - throw new Error(`Provider ${provider} not supported`); - } - getOptions = initCustom; - agent.provider = Providers.OPENAI; - } - const model_parameters = Object.assign( - {}, - agent.model_parameters ?? { model: agent.model }, - isInitialAgent === true ? endpointOption?.model_parameters : {}, - ); - const _endpointOption = - isInitialAgent === true - ? Object.assign({}, endpointOption, { model_parameters }) - : { model_parameters }; - - const options = await getOptions({ - req, - res, - optionsOnly: true, - overrideEndpoint: provider, - overrideModel: agent.model, - endpointOption: _endpointOption, - }); - - if ( - agent.endpoint === EModelEndpoint.azureOpenAI && - options.llmConfig?.azureOpenAIApiInstanceName == null - ) { - agent.provider = Providers.OPENAI; - } - - if (options.provider != null) { - agent.provider = options.provider; - } - - /** @type {import('@librechat/agents').ClientOptions} */ - agent.model_parameters = Object.assign(model_parameters, options.llmConfig); - if (options.configOptions) { - agent.model_parameters.configuration = options.configOptions; - } - - if (!agent.model_parameters.model) { - agent.model_parameters.model = agent.model; - } - - if (agent.instructions && agent.instructions !== '') { - agent.instructions = replaceSpecialVars({ - text: agent.instructions, - user: req.user, - }); - } - - if (typeof agent.artifacts === 'string' && agent.artifacts !== '') { - agent.additional_instructions = generateArtifactsPrompt({ - endpoint: agent.provider, - artifacts: agent.artifacts, - }); - } - - const tokensModel = - agent.provider === EModelEndpoint.azureOpenAI ? agent.model : agent.model_parameters.model; - const maxTokens = optionalChainWithEmptyCheck( - agent.model_parameters.maxOutputTokens, - agent.model_parameters.maxTokens, - 0, - ); - const maxContextTokens = optionalChainWithEmptyCheck( - agent.model_parameters.maxContextTokens, - agent.max_context_tokens, - getModelMaxTokens(tokensModel, providerEndpointMap[provider]), - 4096, - ); - return { - ...agent, - tools, - attachments, - toolContextMap, - maxContextTokens: (maxContextTokens - maxTokens) * 0.9, }; -}; +} const initializeClient = async ({ req, res, endpointOption }) => { if (!endpointOption) { @@ -313,7 +60,6 @@ const initializeClient = async ({ req, res, endpointOption }) => { throw new Error('No agent promise provided'); } - // Initialize primary agent const primaryAgent = await endpointOption.agent; if (!primaryAgent) { throw new Error('Agent not found'); @@ -323,10 +69,18 @@ const initializeClient = async ({ req, res, endpointOption }) => { /** @type {Set} */ const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders); - // Handle primary agent - const primaryConfig = await initializeAgentOptions({ + const loadTools = createToolLoader(); + /** @type {Array} */ + const requestFiles = req.body.files ?? []; + /** @type {string} */ + const conversationId = req.body.conversationId; + + const primaryConfig = await initializeAgent({ req, res, + loadTools, + requestFiles, + conversationId, agent: primaryAgent, endpointOption, allowedProviders, @@ -340,10 +94,13 @@ const initializeClient = async ({ req, res, endpointOption }) => { if (!agent) { throw new Error(`Agent ${agentId} not found`); } - const config = await initializeAgentOptions({ + const config = await initializeAgent({ req, res, agent, + loadTools, + requestFiles, + conversationId, endpointOption, allowedProviders, }); @@ -373,8 +130,8 @@ const initializeClient = async ({ req, res, endpointOption }) => { iconURL: endpointOption.iconURL, attachments: primaryConfig.attachments, endpointType: endpointOption.endpointType, + resendFiles: primaryConfig.resendFiles ?? true, maxContextTokens: primaryConfig.maxContextTokens, - resendFiles: primaryConfig.model_parameters?.resendFiles ?? true, endpoint: primaryConfig.id === Constants.EPHEMERAL_AGENT_ID ? primaryConfig.endpoint diff --git a/api/server/services/Endpoints/anthropic/llm.js b/api/server/services/Endpoints/anthropic/llm.js index 9f20b8e61d..66496f00fd 100644 --- a/api/server/services/Endpoints/anthropic/llm.js +++ b/api/server/services/Endpoints/anthropic/llm.js @@ -1,4 +1,4 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); +const { ProxyAgent } = require('undici'); const { anthropicSettings, removeNullishValues } = require('librechat-data-provider'); const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers'); @@ -67,7 +67,10 @@ function getLLMConfig(apiKey, options = {}) { } if (options.proxy) { - requestOptions.clientOptions.httpAgent = new HttpsProxyAgent(options.proxy); + const proxyAgent = new ProxyAgent(options.proxy); + requestOptions.clientOptions.fetchOptions = { + dispatcher: proxyAgent, + }; } if (options.reverseProxyUrl) { diff --git a/api/server/services/Endpoints/anthropic/llm.spec.js b/api/server/services/Endpoints/anthropic/llm.spec.js index 9c453efb92..f3f77ee897 100644 --- a/api/server/services/Endpoints/anthropic/llm.spec.js +++ b/api/server/services/Endpoints/anthropic/llm.spec.js @@ -21,8 +21,12 @@ describe('getLLMConfig', () => { proxy: 'http://proxy:8080', }); - expect(result.llmConfig.clientOptions).toHaveProperty('httpAgent'); - expect(result.llmConfig.clientOptions.httpAgent).toHaveProperty('proxy', 'http://proxy:8080'); + expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions'); + expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher'); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined(); + expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe( + 'ProxyAgent', + ); }); it('should include reverse proxy URL when provided', () => { diff --git a/api/server/services/Endpoints/azureAssistants/initialize.js b/api/server/services/Endpoints/azureAssistants/initialize.js index fc8024af07..88acef23e5 100644 --- a/api/server/services/Endpoints/azureAssistants/initialize.js +++ b/api/server/services/Endpoints/azureAssistants/initialize.js @@ -1,5 +1,6 @@ const OpenAI = require('openai'); const { HttpsProxyAgent } = require('https-proxy-agent'); +const { constructAzureURL, isUserProvided } = require('@librechat/api'); const { ErrorTypes, EModelEndpoint, @@ -12,8 +13,6 @@ const { checkUserKeyExpiry, } = require('~/server/services/UserService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); -const { isUserProvided } = require('~/server/utils'); -const { constructAzureURL } = require('~/utils'); class Files { constructor(client) { diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js index da332060e9..fc5536abbf 100644 --- a/api/server/services/Endpoints/bedrock/options.js +++ b/api/server/services/Endpoints/bedrock/options.js @@ -1,4 +1,5 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); +const { createHandleLLMNewToken } = require('@librechat/api'); const { AuthType, Constants, @@ -8,7 +9,6 @@ const { removeNullishValues, } = require('librechat-data-provider'); const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); const getOptions = async ({ req, overrideModel, endpointOption }) => { const { diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js index 39def8d0d5..754abef5a8 100644 --- a/api/server/services/Endpoints/custom/initialize.js +++ b/api/server/services/Endpoints/custom/initialize.js @@ -6,10 +6,9 @@ const { extractEnvVariable, } = require('librechat-data-provider'); const { Providers } = require('@librechat/agents'); +const { getOpenAIConfig, createHandleLLMNewToken } = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm'); const { getCustomEndpointConfig } = require('~/server/services/Config'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); const { fetchModels } = require('~/server/services/ModelService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); const { isUserProvided } = require('~/server/utils'); @@ -144,7 +143,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid clientOptions, ); clientOptions.modelOptions.user = req.user.id; - const options = getLLMConfig(apiKey, clientOptions, endpoint); + const options = getOpenAIConfig(apiKey, clientOptions, endpoint); if (!customOptions.streamRate) { return options; } diff --git a/api/server/services/Endpoints/google/initialize.js b/api/server/services/Endpoints/google/initialize.js index b7419a8a87..b6bc2d6a79 100644 --- a/api/server/services/Endpoints/google/initialize.js +++ b/api/server/services/Endpoints/google/initialize.js @@ -25,9 +25,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio const credentials = isUserProvided ? userKey : { - [AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey, - [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, - }; + [AuthKeys.GOOGLE_SERVICE_KEY]: serviceKey, + [AuthKeys.GOOGLE_API_KEY]: GOOGLE_KEY, + }; let clientOptions = {}; diff --git a/api/server/services/Endpoints/google/llm.js b/api/server/services/Endpoints/google/llm.js index a64b33480b..235e1e3df9 100644 --- a/api/server/services/Endpoints/google/llm.js +++ b/api/server/services/Endpoints/google/llm.js @@ -94,7 +94,7 @@ function getLLMConfig(credentials, options = {}) { // Extract from credentials const serviceKeyRaw = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {}; const serviceKey = - typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : serviceKeyRaw ?? {}; + typeof serviceKeyRaw === 'string' ? JSON.parse(serviceKeyRaw) : (serviceKeyRaw ?? {}); const project_id = serviceKey?.project_id ?? null; const apiKey = creds[AuthKeys.GOOGLE_API_KEY] ?? null; @@ -156,10 +156,6 @@ function getLLMConfig(credentials, options = {}) { } if (authHeader) { - /** - * NOTE: NOT SUPPORTED BY LANGCHAIN GENAI CLIENT, - * REQUIRES PR IN https://github.com/langchain-ai/langchainjs - */ llmConfig.customHeaders = { Authorization: `Bearer ${apiKey}`, }; diff --git a/api/server/services/Endpoints/gptPlugins/initialize.js b/api/server/services/Endpoints/gptPlugins/initialize.js index 7bfb43f004..d2af6c757e 100644 --- a/api/server/services/Endpoints/gptPlugins/initialize.js +++ b/api/server/services/Endpoints/gptPlugins/initialize.js @@ -1,11 +1,10 @@ const { EModelEndpoint, - mapModelToAzureConfig, resolveHeaders, + mapModelToAzureConfig, } = require('librechat-data-provider'); +const { isEnabled, isUserProvided, getAzureCredentials } = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { isEnabled, isUserProvided } = require('~/server/utils'); -const { getAzureCredentials } = require('~/utils'); const { PluginsClient } = require('~/app'); const initializeClient = async ({ req, res, endpointOption }) => { diff --git a/api/server/services/Endpoints/gptPlugins/initialize.spec.js b/api/server/services/Endpoints/gptPlugins/initialize.spec.js index 02199c9397..f9cb2750a4 100644 --- a/api/server/services/Endpoints/gptPlugins/initialize.spec.js +++ b/api/server/services/Endpoints/gptPlugins/initialize.spec.js @@ -114,11 +114,11 @@ describe('gptPlugins/initializeClient', () => { test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => { process.env.AZURE_API_KEY = 'test-azure-api-key'; (process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_VERSION = 'some-value'), - (process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'), - (process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'), - (process.env.PLUGINS_USE_AZURE = 'true'); + (process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'), + (process.env.AZURE_OPENAI_API_VERSION = 'some-value'), + (process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'), + (process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'), + (process.env.PLUGINS_USE_AZURE = 'true'); process.env.DEBUG_PLUGINS = 'false'; process.env.OPENAI_SUMMARIZE = 'false'; diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index 714ed5a1e6..bc0907b3de 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -4,12 +4,15 @@ const { resolveHeaders, mapModelToAzureConfig, } = require('librechat-data-provider'); +const { + isEnabled, + isUserProvided, + getOpenAIConfig, + getAzureCredentials, + createHandleLLMNewToken, +} = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); -const { getLLMConfig } = require('~/server/services/Endpoints/openAI/llm'); -const { createHandleLLMNewToken } = require('~/app/clients/generators'); -const { isEnabled, isUserProvided } = require('~/server/utils'); const OpenAIClient = require('~/app/clients/OpenAIClient'); -const { getAzureCredentials } = require('~/utils'); const initializeClient = async ({ req, @@ -140,7 +143,7 @@ const initializeClient = async ({ modelOptions.model = modelName; clientOptions = Object.assign({ modelOptions }, clientOptions); clientOptions.modelOptions.user = req.user.id; - const options = getLLMConfig(apiKey, clientOptions); + const options = getOpenAIConfig(apiKey, clientOptions); const streamRate = clientOptions.streamRate; if (!streamRate) { return options; diff --git a/api/server/services/Endpoints/openAI/llm.js b/api/server/services/Endpoints/openAI/llm.js deleted file mode 100644 index c1fd090b28..0000000000 --- a/api/server/services/Endpoints/openAI/llm.js +++ /dev/null @@ -1,170 +0,0 @@ -const { HttpsProxyAgent } = require('https-proxy-agent'); -const { KnownEndpoints } = require('librechat-data-provider'); -const { sanitizeModelName, constructAzureURL } = require('~/utils'); -const { isEnabled } = require('~/server/utils'); - -/** - * Generates configuration options for creating a language model (LLM) instance. - * @param {string} apiKey - The API key for authentication. - * @param {Object} options - Additional options for configuring the LLM. - * @param {Object} [options.modelOptions] - Model-specific options. - * @param {string} [options.modelOptions.model] - The name of the model to use. - * @param {string} [options.modelOptions.user] - The user ID - * @param {number} [options.modelOptions.temperature] - Controls randomness in output generation (0-2). - * @param {number} [options.modelOptions.top_p] - Controls diversity via nucleus sampling (0-1). - * @param {number} [options.modelOptions.frequency_penalty] - Reduces repetition of token sequences (-2 to 2). - * @param {number} [options.modelOptions.presence_penalty] - Encourages discussing new topics (-2 to 2). - * @param {number} [options.modelOptions.max_tokens] - The maximum number of tokens to generate. - * @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens. - * @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used. - * @param {boolean} [options.useOpenRouter] - Flag to use OpenRouter API. - * @param {Object} [options.headers] - Additional headers for API requests. - * @param {string} [options.proxy] - Proxy server URL. - * @param {Object} [options.azure] - Azure-specific configurations. - * @param {boolean} [options.streaming] - Whether to use streaming mode. - * @param {Object} [options.addParams] - Additional parameters to add to the model options. - * @param {string[]} [options.dropParams] - Parameters to remove from the model options. - * @param {string|null} [endpoint=null] - The endpoint name - * @returns {Object} Configuration options for creating an LLM instance. - */ -function getLLMConfig(apiKey, options = {}, endpoint = null) { - let { - modelOptions = {}, - reverseProxyUrl, - defaultQuery, - headers, - proxy, - azure, - streaming = true, - addParams, - dropParams, - } = options; - - /** @type {OpenAIClientOptions} */ - let llmConfig = { - streaming, - }; - - Object.assign(llmConfig, modelOptions); - - if (addParams && typeof addParams === 'object') { - Object.assign(llmConfig, addParams); - } - /** Note: OpenAI Web Search models do not support any known parameters besdies `max_tokens` */ - if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) { - const searchExcludeParams = [ - 'frequency_penalty', - 'presence_penalty', - 'temperature', - 'top_p', - 'top_k', - 'stop', - 'logit_bias', - 'seed', - 'response_format', - 'n', - 'logprobs', - 'user', - ]; - - dropParams = dropParams || []; - dropParams = [...new Set([...dropParams, ...searchExcludeParams])]; - } - - if (dropParams && Array.isArray(dropParams)) { - dropParams.forEach((param) => { - if (llmConfig[param]) { - llmConfig[param] = undefined; - } - }); - } - - let useOpenRouter; - /** @type {OpenAIClientOptions['configuration']} */ - const configOptions = {}; - if ( - (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) || - (endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter)) - ) { - useOpenRouter = true; - llmConfig.include_reasoning = true; - configOptions.baseURL = reverseProxyUrl; - configOptions.defaultHeaders = Object.assign( - { - 'HTTP-Referer': 'https://librechat.ai', - 'X-Title': 'LibreChat', - }, - headers, - ); - } else if (reverseProxyUrl) { - configOptions.baseURL = reverseProxyUrl; - if (headers) { - configOptions.defaultHeaders = headers; - } - } - - if (defaultQuery) { - configOptions.defaultQuery = defaultQuery; - } - - if (proxy) { - const proxyAgent = new HttpsProxyAgent(proxy); - Object.assign(configOptions, { - httpAgent: proxyAgent, - httpsAgent: proxyAgent, - }); - } - - if (azure) { - const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME); - azure.azureOpenAIApiDeploymentName = useModelName - ? sanitizeModelName(llmConfig.model) - : azure.azureOpenAIApiDeploymentName; - - if (process.env.AZURE_OPENAI_DEFAULT_MODEL) { - llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL; - } - - if (configOptions.baseURL) { - const azureURL = constructAzureURL({ - baseURL: configOptions.baseURL, - azureOptions: azure, - }); - azure.azureOpenAIBasePath = azureURL.split(`/${azure.azureOpenAIApiDeploymentName}`)[0]; - } - - Object.assign(llmConfig, azure); - llmConfig.model = llmConfig.azureOpenAIApiDeploymentName; - } else { - llmConfig.apiKey = apiKey; - // Object.assign(llmConfig, { - // configuration: { apiKey }, - // }); - } - - if (process.env.OPENAI_ORGANIZATION && this.azure) { - llmConfig.organization = process.env.OPENAI_ORGANIZATION; - } - - if (useOpenRouter && llmConfig.reasoning_effort != null) { - llmConfig.reasoning = { - effort: llmConfig.reasoning_effort, - }; - delete llmConfig.reasoning_effort; - } - - if (llmConfig?.['max_tokens'] != null) { - /** @type {number} */ - llmConfig.maxTokens = llmConfig['max_tokens']; - delete llmConfig['max_tokens']; - } - - return { - /** @type {OpenAIClientOptions} */ - llmConfig, - /** @type {OpenAIClientOptions['configuration']} */ - configOptions, - }; -} - -module.exports = { getLLMConfig }; diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index d6c8cc4146..49a800336b 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -2,9 +2,9 @@ const axios = require('axios'); const fs = require('fs').promises; const FormData = require('form-data'); const { Readable } = require('stream'); +const { genAzureEndpoint } = require('@librechat/api'); const { extractEnvVariable, STTProviders } = require('librechat-data-provider'); const { getCustomConfig } = require('~/server/services/Config'); -const { genAzureEndpoint } = require('~/utils'); const { logger } = require('~/config'); /** diff --git a/api/server/services/Files/Audio/TTSService.js b/api/server/services/Files/Audio/TTSService.js index cd718fdfc1..34d8202156 100644 --- a/api/server/services/Files/Audio/TTSService.js +++ b/api/server/services/Files/Audio/TTSService.js @@ -1,8 +1,8 @@ const axios = require('axios'); +const { genAzureEndpoint } = require('@librechat/api'); const { extractEnvVariable, TTSProviders } = require('librechat-data-provider'); const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio'); const { getCustomConfig } = require('~/server/services/Config'); -const { genAzureEndpoint } = require('~/utils'); const { logger } = require('~/config'); /** diff --git a/api/server/services/Files/Azure/images.js b/api/server/services/Files/Azure/images.js index 80163bee05..80d5e76290 100644 --- a/api/server/services/Files/Azure/images.js +++ b/api/server/services/Files/Azure/images.js @@ -91,15 +91,28 @@ async function prepareAzureImageURL(req, file) { * @param {Buffer} params.buffer - The avatar image buffer. * @param {string} params.userId - The user's id. * @param {string} params.manual - Flag to indicate manual update. + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @param {string} [params.basePath='images'] - The base folder within the container. * @param {string} [params.containerName] - The Azure Blob container name. * @returns {Promise} The URL of the avatar. */ -async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', containerName }) { +async function processAzureAvatar({ + buffer, + userId, + manual, + agentId, + basePath = 'images', + containerName, +}) { try { const metadata = await sharp(buffer).metadata(); const extension = metadata.format === 'gif' ? 'gif' : 'png'; - const fileName = `avatar.${extension}`; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; const downloadURL = await saveBufferToAzure({ userId, @@ -110,9 +123,12 @@ async function processAzureAvatar({ buffer, userId, manual, basePath = 'images', }); const isManual = manual === 'true'; const url = `${downloadURL}?manual=${isManual}`; - if (isManual) { + + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } + return url; } catch (error) { logger.error('[processAzureAvatar] Error uploading profile picture to Azure:', error); diff --git a/api/server/services/Files/Code/crud.js b/api/server/services/Files/Code/crud.js index caea9ab30a..c696eae0c4 100644 --- a/api/server/services/Files/Code/crud.js +++ b/api/server/services/Files/Code/crud.js @@ -1,7 +1,6 @@ const FormData = require('form-data'); const { getCodeBaseURL } = require('@librechat/agents'); -const { createAxiosInstance } = require('~/config'); -const { logAxiosError } = require('~/utils'); +const { createAxiosInstance, logAxiosError } = require('@librechat/api'); const axios = createAxiosInstance(); diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index c92e628589..cf65154983 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -1,6 +1,8 @@ const path = require('path'); const { v4 } = require('uuid'); const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { getCodeBaseURL } = require('@librechat/agents'); const { Tools, @@ -12,8 +14,6 @@ const { const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); const { createFile, getFiles, updateFile } = require('~/models/File'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Process OpenAI image files, convert to target format, save and return file metadata. diff --git a/api/server/services/Files/Firebase/images.js b/api/server/services/Files/Firebase/images.js index 80bff28d89..8b0866b5d0 100644 --- a/api/server/services/Files/Firebase/images.js +++ b/api/server/services/Files/Firebase/images.js @@ -82,14 +82,20 @@ async function prepareImageURL(req, file) { * @param {Buffer} params.buffer - The Buffer containing the avatar image. * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processFirebaseAvatar({ buffer, userId, manual }) { +async function processFirebaseAvatar({ buffer, userId, manual, agentId }) { try { const metadata = await sharp(buffer).metadata(); const extension = metadata.format === 'gif' ? 'gif' : 'png'; - const fileName = `avatar.${extension}`; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; const downloadURL = await saveBufferToFirebase({ userId, @@ -98,10 +104,10 @@ async function processFirebaseAvatar({ buffer, userId, manual }) { }); const isManual = manual === 'true'; - const url = `${downloadURL}?manual=${isManual}`; - if (isManual) { + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index 783230f2f6..7df528c5e1 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -201,6 +201,10 @@ const unlinkFile = async (filepath) => { */ const deleteLocalFile = async (req, file) => { const { publicPath, uploads } = req.app.locals.paths; + + /** Filepath stripped of query parameters (e.g., ?manual=true) */ + const cleanFilepath = file.filepath.split('?')[0]; + if (file.embedded && process.env.RAG_API_URL) { const jwtToken = req.headers.authorization.split(' ')[1]; axios.delete(`${process.env.RAG_API_URL}/documents`, { @@ -213,32 +217,32 @@ const deleteLocalFile = async (req, file) => { }); } - if (file.filepath.startsWith(`/uploads/${req.user.id}`)) { + if (cleanFilepath.startsWith(`/uploads/${req.user.id}`)) { const userUploadDir = path.join(uploads, req.user.id); - const basePath = file.filepath.split(`/uploads/${req.user.id}/`)[1]; + const basePath = cleanFilepath.split(`/uploads/${req.user.id}/`)[1]; if (!basePath) { - throw new Error(`Invalid file path: ${file.filepath}`); + throw new Error(`Invalid file path: ${cleanFilepath}`); } const filepath = path.join(userUploadDir, basePath); const rel = path.relative(userUploadDir, filepath); if (rel.startsWith('..') || path.isAbsolute(rel) || rel.includes(`..${path.sep}`)) { - throw new Error(`Invalid file path: ${file.filepath}`); + throw new Error(`Invalid file path: ${cleanFilepath}`); } await unlinkFile(filepath); return; } - const parts = file.filepath.split(path.sep); + const parts = cleanFilepath.split(path.sep); const subfolder = parts[1]; if (!subfolder && parts[0] === EModelEndpoint.agents) { logger.warn(`Agent File ${file.file_id} is missing filepath, may have been deleted already`); return; } - const filepath = path.join(publicPath, file.filepath); + const filepath = path.join(publicPath, cleanFilepath); if (!isValidPath(req, publicPath, subfolder, filepath)) { throw new Error('Invalid file path'); diff --git a/api/server/services/Files/Local/images.js b/api/server/services/Files/Local/images.js index fc344cea88..ea3af87c70 100644 --- a/api/server/services/Files/Local/images.js +++ b/api/server/services/Files/Local/images.js @@ -112,10 +112,11 @@ async function prepareImagesLocal(req, file) { * @param {Buffer} params.buffer - The Buffer containing the avatar image. * @param {string} params.userId - The user ID. * @param {string} params.manual - A string flag indicating whether the update is manual ('true' or 'false'). + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @returns {Promise} - A promise that resolves with the URL of the uploaded avatar. * @throws {Error} - Throws an error if Firebase is not initialized or if there is an error in uploading. */ -async function processLocalAvatar({ buffer, userId, manual }) { +async function processLocalAvatar({ buffer, userId, manual, agentId }) { const userDir = path.resolve( __dirname, '..', @@ -132,7 +133,11 @@ async function processLocalAvatar({ buffer, userId, manual }) { const metadata = await sharp(buffer).metadata(); const extension = metadata.format === 'gif' ? 'gif' : 'png'; - const fileName = `avatar-${new Date().getTime()}.${extension}`; + const timestamp = new Date().getTime(); + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; const urlRoute = `/images/${userId}/${fileName}`; const avatarPath = path.join(userDir, fileName); @@ -142,7 +147,8 @@ async function processLocalAvatar({ buffer, userId, manual }) { const isManual = manual === 'true'; let url = `${urlRoute}?manual=${isManual}`; - if (isManual) { + // Only update user record if this is a user avatar (manual === 'true') + if (isManual && !agentId) { await updateUser(userId, { avatar: url }); } diff --git a/api/server/services/Files/MistralOCR/crud.js b/api/server/services/Files/MistralOCR/crud.js deleted file mode 100644 index 02fdb85461..0000000000 --- a/api/server/services/Files/MistralOCR/crud.js +++ /dev/null @@ -1,238 +0,0 @@ -// ~/server/services/Files/MistralOCR/crud.js -const fs = require('fs'); -const path = require('path'); -const FormData = require('form-data'); -const { - FileSources, - envVarRegex, - extractEnvVariable, - extractVariableName, -} = require('librechat-data-provider'); -const { loadAuthValues } = require('~/server/services/Tools/credentials'); -const { logger, createAxiosInstance } = require('~/config'); -const { logAxiosError } = require('~/utils/axios'); - -const axios = createAxiosInstance(); - -/** - * Uploads a document to Mistral API using file streaming to avoid loading the entire file into memory - * - * @param {Object} params Upload parameters - * @param {string} params.filePath The path to the file on disk - * @param {string} [params.fileName] Optional filename to use (defaults to the name from filePath) - * @param {string} params.apiKey Mistral API key - * @param {string} [params.baseURL=https://api.mistral.ai/v1] Mistral API base URL - * @returns {Promise} The response from Mistral API - */ -async function uploadDocumentToMistral({ - filePath, - fileName = '', - apiKey, - baseURL = 'https://api.mistral.ai/v1', -}) { - const form = new FormData(); - form.append('purpose', 'ocr'); - const actualFileName = fileName || path.basename(filePath); - const fileStream = fs.createReadStream(filePath); - form.append('file', fileStream, { filename: actualFileName }); - - return axios - .post(`${baseURL}/files`, form, { - headers: { - Authorization: `Bearer ${apiKey}`, - ...form.getHeaders(), - }, - maxBodyLength: Infinity, - maxContentLength: Infinity, - }) - .then((res) => res.data) - .catch((error) => { - throw error; - }); -} - -async function getSignedUrl({ - apiKey, - fileId, - expiry = 24, - baseURL = 'https://api.mistral.ai/v1', -}) { - return axios - .get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, { - headers: { - Authorization: `Bearer ${apiKey}`, - }, - }) - .then((res) => res.data) - .catch((error) => { - logger.error('Error fetching signed URL:', error.message); - throw error; - }); -} - -/** - * @param {Object} params - * @param {string} params.apiKey - * @param {string} params.url - The document or image URL - * @param {string} [params.documentType='document_url'] - 'document_url' or 'image_url' - * @param {string} [params.model] - * @param {string} [params.baseURL] - * @returns {Promise} - */ -async function performOCR({ - apiKey, - url, - documentType = 'document_url', - model = 'mistral-ocr-latest', - baseURL = 'https://api.mistral.ai/v1', -}) { - const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url'; - return axios - .post( - `${baseURL}/ocr`, - { - model, - image_limit: 0, - include_image_base64: false, - document: { - type: documentType, - [documentKey]: url, - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${apiKey}`, - }, - }, - ) - .then((res) => res.data) - .catch((error) => { - logger.error('Error performing OCR:', error.message); - throw error; - }); -} - -/** - * Uploads a file to the Mistral OCR API and processes the OCR result. - * - * @param {Object} params - The params object. - * @param {ServerRequest} params.req - The request object from Express. It should have a `user` property with an `id` - * representing the user - * @param {Express.Multer.File} params.file - The file object, which is part of the request. The file object should - * have a `mimetype` property that tells us the file type - * @param {string} params.file_id - The file ID. - * @param {string} [params.entity_id] - The entity ID, not used here but passed for consistency. - * @returns {Promise<{ filepath: string, bytes: number }>} - The result object containing the processed `text` and `images` (not currently used), - * along with the `filename` and `bytes` properties. - */ -const uploadMistralOCR = async ({ req, file, file_id, entity_id }) => { - try { - /** @type {TCustomConfig['ocr']} */ - const ocrConfig = req.app.locals?.ocr; - - const apiKeyConfig = ocrConfig.apiKey || ''; - const baseURLConfig = ocrConfig.baseURL || ''; - - const isApiKeyEnvVar = envVarRegex.test(apiKeyConfig); - const isBaseURLEnvVar = envVarRegex.test(baseURLConfig); - - const isApiKeyEmpty = !apiKeyConfig.trim(); - const isBaseURLEmpty = !baseURLConfig.trim(); - - let apiKey, baseURL; - - if (isApiKeyEnvVar || isBaseURLEnvVar || isApiKeyEmpty || isBaseURLEmpty) { - const apiKeyVarName = isApiKeyEnvVar ? extractVariableName(apiKeyConfig) : 'OCR_API_KEY'; - const baseURLVarName = isBaseURLEnvVar ? extractVariableName(baseURLConfig) : 'OCR_BASEURL'; - - const authValues = await loadAuthValues({ - userId: req.user.id, - authFields: [baseURLVarName, apiKeyVarName], - optional: new Set([baseURLVarName]), - }); - - apiKey = authValues[apiKeyVarName]; - baseURL = authValues[baseURLVarName]; - } else { - apiKey = apiKeyConfig; - baseURL = baseURLConfig; - } - - const mistralFile = await uploadDocumentToMistral({ - filePath: file.path, - fileName: file.originalname, - apiKey, - baseURL, - }); - - const modelConfig = ocrConfig.mistralModel || ''; - const model = envVarRegex.test(modelConfig) - ? extractEnvVariable(modelConfig) - : modelConfig.trim() || 'mistral-ocr-latest'; - - const signedUrlResponse = await getSignedUrl({ - apiKey, - baseURL, - fileId: mistralFile.id, - }); - - const mimetype = (file.mimetype || '').toLowerCase(); - const originalname = file.originalname || ''; - const isImage = - mimetype.startsWith('image') || /\.(png|jpe?g|gif|bmp|webp|tiff?)$/i.test(originalname); - const documentType = isImage ? 'image_url' : 'document_url'; - - const ocrResult = await performOCR({ - apiKey, - baseURL, - model, - url: signedUrlResponse.url, - documentType, - }); - - let aggregatedText = ''; - const images = []; - ocrResult.pages.forEach((page, index) => { - if (ocrResult.pages.length > 1) { - aggregatedText += `# PAGE ${index + 1}\n`; - } - - aggregatedText += page.markdown + '\n\n'; - - if (page.images && page.images.length > 0) { - page.images.forEach((image) => { - if (image.image_base64) { - images.push(image.image_base64); - } - }); - } - }); - - return { - filename: file.originalname, - bytes: aggregatedText.length * 4, - filepath: FileSources.mistral_ocr, - text: aggregatedText, - images, - }; - } catch (error) { - let message = 'Error uploading document to Mistral OCR API'; - const detail = error?.response?.data?.detail; - if (detail && detail !== '') { - message = detail; - } - - const responseMessage = error?.response?.data?.message; - throw new Error( - `${logAxiosError({ error, message })}${responseMessage && responseMessage !== '' ? ` - ${responseMessage}` : ''}`, - ); - } -}; - -module.exports = { - uploadDocumentToMistral, - uploadMistralOCR, - getSignedUrl, - performOCR, -}; diff --git a/api/server/services/Files/MistralOCR/crud.spec.js b/api/server/services/Files/MistralOCR/crud.spec.js deleted file mode 100644 index 72d6be7cb0..0000000000 --- a/api/server/services/Files/MistralOCR/crud.spec.js +++ /dev/null @@ -1,848 +0,0 @@ -const fs = require('fs'); - -const mockAxios = { - interceptors: { - request: { use: jest.fn(), eject: jest.fn() }, - response: { use: jest.fn(), eject: jest.fn() }, - }, - create: jest.fn().mockReturnValue({ - defaults: { - proxy: null, - }, - get: jest.fn().mockResolvedValue({ data: {} }), - post: jest.fn().mockResolvedValue({ data: {} }), - put: jest.fn().mockResolvedValue({ data: {} }), - delete: jest.fn().mockResolvedValue({ data: {} }), - }), - get: jest.fn().mockResolvedValue({ data: {} }), - post: jest.fn().mockResolvedValue({ data: {} }), - put: jest.fn().mockResolvedValue({ data: {} }), - delete: jest.fn().mockResolvedValue({ data: {} }), - reset: jest.fn().mockImplementation(function () { - this.get.mockClear(); - this.post.mockClear(); - this.put.mockClear(); - this.delete.mockClear(); - this.create.mockClear(); - }), -}; - -jest.mock('axios', () => mockAxios); -jest.mock('fs'); -jest.mock('~/config', () => ({ - logger: { - error: jest.fn(), - }, - createAxiosInstance: () => mockAxios, -})); -jest.mock('~/server/services/Tools/credentials', () => ({ - loadAuthValues: jest.fn(), -})); - -const { uploadDocumentToMistral, uploadMistralOCR, getSignedUrl, performOCR } = require('./crud'); - -describe('MistralOCR Service', () => { - afterEach(() => { - mockAxios.reset(); - jest.clearAllMocks(); - }); - - describe('uploadDocumentToMistral', () => { - beforeEach(() => { - // Create a more complete mock for file streams that FormData can work with - const mockReadStream = { - on: jest.fn().mockImplementation(function (event, handler) { - // Simulate immediate 'end' event to make FormData complete processing - if (event === 'end') { - handler(); - } - return this; - }), - pipe: jest.fn().mockImplementation(function () { - return this; - }), - pause: jest.fn(), - resume: jest.fn(), - emit: jest.fn(), - once: jest.fn(), - destroy: jest.fn(), - }; - - fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); - - // Mock FormData's append to avoid actual stream processing - jest.mock('form-data', () => { - const mockFormData = function () { - return { - append: jest.fn(), - getHeaders: jest - .fn() - .mockReturnValue({ 'content-type': 'multipart/form-data; boundary=---boundary' }), - getBuffer: jest.fn().mockReturnValue(Buffer.from('mock-form-data')), - getLength: jest.fn().mockReturnValue(100), - }; - }; - return mockFormData; - }); - }); - - it('should upload a document to Mistral API using file streaming', async () => { - const mockResponse = { data: { id: 'file-123', purpose: 'ocr' } }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await uploadDocumentToMistral({ - filePath: '/path/to/test.pdf', - fileName: 'test.pdf', - apiKey: 'test-api-key', - }); - - // Check that createReadStream was called with the correct file path - expect(fs.createReadStream).toHaveBeenCalledWith('/path/to/test.pdf'); - - // Since we're mocking FormData, we'll just check that axios was called correctly - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/files', - expect.anything(), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer test-api-key', - }), - maxBodyLength: Infinity, - maxContentLength: Infinity, - }), - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors during document upload', async () => { - const errorMessage = 'API error'; - mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - uploadDocumentToMistral({ - filePath: '/path/to/test.pdf', - fileName: 'test.pdf', - apiKey: 'test-api-key', - }), - ).rejects.toThrow(errorMessage); - }); - }); - - describe('getSignedUrl', () => { - it('should fetch signed URL from Mistral API', async () => { - const mockResponse = { data: { url: 'https://document-url.com' } }; - mockAxios.get.mockResolvedValueOnce(mockResponse); - - const result = await getSignedUrl({ - fileId: 'file-123', - apiKey: 'test-api-key', - }); - - expect(mockAxios.get).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/files/file-123/url?expiry=24', - { - headers: { - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors when fetching signed URL', async () => { - const errorMessage = 'API error'; - mockAxios.get.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - getSignedUrl({ - fileId: 'file-123', - apiKey: 'test-api-key', - }), - ).rejects.toThrow(); - - const { logger } = require('~/config'); - expect(logger.error).toHaveBeenCalledWith('Error fetching signed URL:', errorMessage); - }); - }); - - describe('performOCR', () => { - it('should perform OCR using Mistral API (document_url)', async () => { - const mockResponse = { - data: { - pages: [{ markdown: 'Page 1 content' }, { markdown: 'Page 2 content' }], - }, - }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await performOCR({ - apiKey: 'test-api-key', - url: 'https://document-url.com', - model: 'mistral-ocr-latest', - documentType: 'document_url', - }); - - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - { - model: 'mistral-ocr-latest', - include_image_base64: false, - image_limit: 0, - document: { - type: 'document_url', - document_url: 'https://document-url.com', - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should perform OCR using Mistral API (image_url)', async () => { - const mockResponse = { - data: { - pages: [{ markdown: 'Image OCR content' }], - }, - }; - mockAxios.post.mockResolvedValueOnce(mockResponse); - - const result = await performOCR({ - apiKey: 'test-api-key', - url: 'https://image-url.com/image.png', - model: 'mistral-ocr-latest', - documentType: 'image_url', - }); - - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - { - model: 'mistral-ocr-latest', - include_image_base64: false, - image_limit: 0, - document: { - type: 'image_url', - image_url: 'https://image-url.com/image.png', - }, - }, - { - headers: { - 'Content-Type': 'application/json', - Authorization: 'Bearer test-api-key', - }, - }, - ); - expect(result).toEqual(mockResponse.data); - }); - - it('should handle errors during OCR processing', async () => { - const errorMessage = 'OCR processing error'; - mockAxios.post.mockRejectedValueOnce(new Error(errorMessage)); - - await expect( - performOCR({ - apiKey: 'test-api-key', - url: 'https://document-url.com', - }), - ).rejects.toThrow(); - - const { logger } = require('~/config'); - expect(logger.error).toHaveBeenCalledWith('Error performing OCR:', errorMessage); - }); - }); - - describe('uploadMistralOCR', () => { - beforeEach(() => { - const mockReadStream = { - on: jest.fn().mockImplementation(function (event, handler) { - if (event === 'end') { - handler(); - } - return this; - }), - pipe: jest.fn().mockImplementation(function () { - return this; - }), - pause: jest.fn(), - resume: jest.fn(), - emit: jest.fn(), - once: jest.fn(), - destroy: jest.fn(), - }; - - fs.createReadStream = jest.fn().mockReturnValue(mockReadStream); - }); - - it('should process OCR for a file with standard configuration', async () => { - // Setup mocks - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', - }); - - // Mock file upload response - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', purpose: 'ocr' }, - }); - - // Mock signed URL response - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - - // Mock OCR response with text and images - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [ - { - markdown: 'Page 1 content', - images: [{ image_base64: 'base64image1' }], - }, - { - markdown: 'Page 2 content', - images: [{ image_base64: 'base64image2' }], - }, - ], - }, - }); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Use environment variable syntax to ensure loadAuthValues is called - apiKey: '${OCR_API_KEY}', - baseURL: '${OCR_BASEURL}', - mistralModel: 'mistral-medium', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - mimetype: 'application/pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Verify OCR result - expect(result).toEqual({ - filename: 'document.pdf', - bytes: expect.any(Number), - filepath: 'mistral_ocr', - text: expect.stringContaining('# PAGE 1'), - images: ['base64image1', 'base64image2'], - }); - }); - - it('should process OCR for an image file and use image_url type', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', - }); - - // Mock file upload response - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-456', purpose: 'ocr' }, - }); - - // Mock signed URL response - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com/image.png' }, - }); - - // Mock OCR response for image - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [ - { - markdown: 'Image OCR result', - images: [{ image_base64: 'imgbase64' }], - }, - ], - }, - }); - - const req = { - user: { id: 'user456' }, - app: { - locals: { - ocr: { - apiKey: '${OCR_API_KEY}', - baseURL: '${OCR_BASEURL}', - mistralModel: 'mistral-medium', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/image.png', - originalname: 'image.png', - mimetype: 'image/png', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file456', - entity_id: 'entity456', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/image.png'); - - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user456', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Check that the OCR API was called with image_url type - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://api.mistral.ai/v1/ocr', - expect.objectContaining({ - document: expect.objectContaining({ - type: 'image_url', - image_url: 'https://signed-url.com/image.png', - }), - }), - expect.any(Object), - ); - - expect(result).toEqual({ - filename: 'image.png', - bytes: expect.any(Number), - filepath: 'mistral_ocr', - text: expect.stringContaining('Image OCR result'), - images: ['imgbase64'], - }); - }); - - it('should process variable references in configuration', async () => { - // Setup mocks with environment variables - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - CUSTOM_API_KEY: 'custom-api-key', - CUSTOM_BASEURL: 'https://custom-api.mistral.ai/v1', - }); - - // Mock API responses - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', purpose: 'ocr' }, - }); - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [{ markdown: 'Content from custom API' }], - }, - }); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: '${CUSTOM_API_KEY}', - baseURL: '${CUSTOM_BASEURL}', - mistralModel: '${CUSTOM_MODEL}', - }, - }, - }, - }; - - // Set environment variable for model - process.env.CUSTOM_MODEL = 'mistral-large'; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify that custom environment variables were extracted and used - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['CUSTOM_BASEURL', 'CUSTOM_API_KEY'], - optional: expect.any(Set), - }); - - // Check that mistral-large was used in the OCR API call - expect(mockAxios.post).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - model: 'mistral-large', - }), - expect.anything(), - ); - - expect(result.text).toEqual('Content from custom API\n\n'); - }); - - it('should fall back to default values when variables are not properly formatted', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'default-api-key', - OCR_BASEURL: undefined, // Testing optional parameter - }); - - mockAxios.post.mockResolvedValueOnce({ - data: { id: 'file-123', purpose: 'ocr' }, - }); - mockAxios.get.mockResolvedValueOnce({ - data: { url: 'https://signed-url.com' }, - }); - mockAxios.post.mockResolvedValueOnce({ - data: { - pages: [{ markdown: 'Default API result' }], - }, - }); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Use environment variable syntax to ensure loadAuthValues is called - apiKey: '${INVALID_FORMAT}', // Using valid env var format but with an invalid name - baseURL: '${OCR_BASEURL}', // Using valid env var format - mistralModel: 'mistral-ocr-latest', // Plain string value - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Should use the default values - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'INVALID_FORMAT'], - optional: expect.any(Set), - }); - - // Should use the default model when not using environment variable format - expect(mockAxios.post).toHaveBeenCalledWith( - expect.anything(), - expect.objectContaining({ - model: 'mistral-ocr-latest', - }), - expect.anything(), - ); - }); - - it('should handle API errors during OCR process', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - }); - - // Mock file upload to fail - mockAxios.post.mockRejectedValueOnce(new Error('Upload failed')); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: 'OCR_API_KEY', - baseURL: 'OCR_BASEURL', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'document.pdf', - }; - - await expect( - uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }), - ).rejects.toThrow('Error uploading document to Mistral OCR API'); - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - }); - - it('should handle single page documents without page numbering', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'test-api-key', - OCR_BASEURL: 'https://api.mistral.ai/v1', // Make sure this is included - }); - - // Clear all previous mocks - mockAxios.post.mockClear(); - mockAxios.get.mockClear(); - - // 1. First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Single page content' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - apiKey: 'OCR_API_KEY', - baseURL: 'OCR_BASEURL', - mistralModel: 'mistral-ocr-latest', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'single-page.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify that single page documents don't include page numbering - expect(result.text).not.toContain('# PAGE'); - expect(result.text).toEqual('Single page content\n\n'); - }); - - it('should use literal values in configuration when provided directly', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - // We'll still mock this but it should not be used for literal values - loadAuthValues.mockResolvedValue({}); - - // Clear all previous mocks - mockAxios.post.mockClear(); - mockAxios.get.mockClear(); - - // 1. First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Processed with literal config values' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Direct values that should be used as-is, without variable substitution - apiKey: 'actual-api-key-value', - baseURL: 'https://direct-api-url.mistral.ai/v1', - mistralModel: 'mistral-direct-model', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'direct-values.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify the correct URL was used with the direct baseURL value - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://direct-api-url.mistral.ai/v1/files', - expect.any(Object), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer actual-api-key-value', - }), - }), - ); - - // Check the OCR call was made with the direct model value - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://direct-api-url.mistral.ai/v1/ocr', - expect.objectContaining({ - model: 'mistral-direct-model', - }), - expect.any(Object), - ); - - // Verify the result - expect(result.text).toEqual('Processed with literal config values\n\n'); - - // Verify loadAuthValues was never called since we used direct values - expect(loadAuthValues).not.toHaveBeenCalled(); - }); - - it('should handle empty configuration values and use defaults', async () => { - const { loadAuthValues } = require('~/server/services/Tools/credentials'); - // Set up the mock values to be returned by loadAuthValues - loadAuthValues.mockResolvedValue({ - OCR_API_KEY: 'default-from-env-key', - OCR_BASEURL: 'https://default-from-env.mistral.ai/v1', - }); - - // Clear all previous mocks - mockAxios.post.mockClear(); - mockAxios.get.mockClear(); - - // 1. First mock: File upload response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ data: { id: 'file-123', purpose: 'ocr' } }), - ); - - // 2. Second mock: Signed URL response - mockAxios.get.mockImplementationOnce(() => - Promise.resolve({ data: { url: 'https://signed-url.com' } }), - ); - - // 3. Third mock: OCR response - mockAxios.post.mockImplementationOnce(() => - Promise.resolve({ - data: { - pages: [{ markdown: 'Content from default configuration' }], - }, - }), - ); - - const req = { - user: { id: 'user123' }, - app: { - locals: { - ocr: { - // Empty string values - should fall back to defaults - apiKey: '', - baseURL: '', - mistralModel: '', - }, - }, - }, - }; - - const file = { - path: '/tmp/upload/file.pdf', - originalname: 'empty-config.pdf', - }; - - const result = await uploadMistralOCR({ - req, - file, - file_id: 'file123', - entity_id: 'entity123', - }); - - expect(fs.createReadStream).toHaveBeenCalledWith('/tmp/upload/file.pdf'); - - // Verify loadAuthValues was called with the default variable names - expect(loadAuthValues).toHaveBeenCalledWith({ - userId: 'user123', - authFields: ['OCR_BASEURL', 'OCR_API_KEY'], - optional: expect.any(Set), - }); - - // Verify the API calls used the default values from loadAuthValues - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://default-from-env.mistral.ai/v1/files', - expect.any(Object), - expect.objectContaining({ - headers: expect.objectContaining({ - Authorization: 'Bearer default-from-env-key', - }), - }), - ); - - // Verify the OCR model defaulted to mistral-ocr-latest - expect(mockAxios.post).toHaveBeenCalledWith( - 'https://default-from-env.mistral.ai/v1/ocr', - expect.objectContaining({ - model: 'mistral-ocr-latest', - }), - expect.any(Object), - ); - - // Check result - expect(result.text).toEqual('Content from default configuration\n\n'); - }); - }); -}); diff --git a/api/server/services/Files/MistralOCR/index.js b/api/server/services/Files/MistralOCR/index.js deleted file mode 100644 index a6223d1ee5..0000000000 --- a/api/server/services/Files/MistralOCR/index.js +++ /dev/null @@ -1,5 +0,0 @@ -const crud = require('./crud'); - -module.exports = { - ...crud, -}; diff --git a/api/server/services/Files/S3/images.js b/api/server/services/Files/S3/images.js index 07faec1765..688d5eb68b 100644 --- a/api/server/services/Files/S3/images.js +++ b/api/server/services/Files/S3/images.js @@ -94,19 +94,28 @@ async function prepareImageURLS3(req, file) { * @param {Buffer} params.buffer - Avatar image buffer. * @param {string} params.userId - User's unique identifier. * @param {string} params.manual - 'true' or 'false' flag for manual update. + * @param {string} [params.agentId] - Optional agent ID if this is an agent avatar. * @param {string} [params.basePath='images'] - Base path in the bucket. * @returns {Promise} Signed URL of the uploaded avatar. */ -async function processS3Avatar({ buffer, userId, manual, basePath = defaultBasePath }) { +async function processS3Avatar({ buffer, userId, manual, agentId, basePath = defaultBasePath }) { try { const metadata = await sharp(buffer).metadata(); const extension = metadata.format === 'gif' ? 'gif' : 'png'; - const fileName = `avatar.${extension}`; + const timestamp = new Date().getTime(); + + /** Unique filename with timestamp and optional agent ID */ + const fileName = agentId + ? `agent-${agentId}-avatar-${timestamp}.${extension}` + : `avatar-${timestamp}.${extension}`; const downloadURL = await saveBufferToS3({ userId, buffer, fileName, basePath }); - if (manual === 'true') { + + // Only update user record if this is a user avatar (manual === 'true') + if (manual === 'true' && !agentId) { await updateUser(userId, { avatar: downloadURL }); } + return downloadURL; } catch (error) { logger.error('[processS3Avatar] Error processing S3 avatar:', error.message); diff --git a/api/server/services/Files/VectorDB/crud.js b/api/server/services/Files/VectorDB/crud.js index 37a1e81487..1aeabc6c46 100644 --- a/api/server/services/Files/VectorDB/crud.js +++ b/api/server/services/Files/VectorDB/crud.js @@ -1,9 +1,9 @@ const fs = require('fs'); const axios = require('axios'); const FormData = require('form-data'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { FileSources } = require('librechat-data-provider'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Deletes a file from the vector database. This function takes a file object, constructs the full path, and diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 154941fd89..e87654b378 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -1,4 +1,5 @@ const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); const { FileSources, VisionModes, @@ -7,8 +8,6 @@ const { EModelEndpoint, } = require('librechat-data-provider'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); /** * Converts a readable stream to a base64 encoded string. diff --git a/api/server/services/Files/process.js b/api/server/services/Files/process.js index 94b1bc4dad..8910163047 100644 --- a/api/server/services/Files/process.js +++ b/api/server/services/Files/process.js @@ -522,7 +522,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { throw new Error('OCR capability is not enabled for Agents'); } - const { handleFileUpload: uploadMistralOCR } = getStrategyFunctions( + const { handleFileUpload: uploadOCR } = getStrategyFunctions( req.app.locals?.ocr?.strategy ?? FileSources.mistral_ocr, ); const { file_id, temp_file_id } = metadata; @@ -534,7 +534,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => { images, filename, filepath: ocrFileURL, - } = await uploadMistralOCR({ req, file, file_id, entity_id: agent_id, basePath }); + } = await uploadOCR({ req, file, loadAuthValues }); const fileInfo = removeNullishValues({ text, diff --git a/api/server/services/Files/strategies.js b/api/server/services/Files/strategies.js index c6cfe77069..41dcd5518a 100644 --- a/api/server/services/Files/strategies.js +++ b/api/server/services/Files/strategies.js @@ -1,4 +1,5 @@ const { FileSources } = require('librechat-data-provider'); +const { uploadMistralOCR, uploadAzureMistralOCR } = require('@librechat/api'); const { getFirebaseURL, prepareImageURL, @@ -46,7 +47,6 @@ const { const { uploadOpenAIFile, deleteOpenAIFile, getOpenAIFileStream } = require('./OpenAI'); const { getCodeOutputDownloadStream, uploadCodeEnvFile } = require('./Code'); const { uploadVectors, deleteVectors } = require('./VectorDB'); -const { uploadMistralOCR } = require('./MistralOCR'); /** * Firebase Storage Strategy Functions @@ -202,6 +202,26 @@ const mistralOCRStrategy = () => ({ handleFileUpload: uploadMistralOCR, }); +const azureMistralOCRStrategy = () => ({ + /** @type {typeof saveFileFromURL | null} */ + saveURL: null, + /** @type {typeof getLocalFileURL | null} */ + getFileURL: null, + /** @type {typeof saveLocalBuffer | null} */ + saveBuffer: null, + /** @type {typeof processLocalAvatar | null} */ + processAvatar: null, + /** @type {typeof uploadLocalImage | null} */ + handleImageUpload: null, + /** @type {typeof prepareImagesLocal | null} */ + prepareImagePayload: null, + /** @type {typeof deleteLocalFile | null} */ + deleteFile: null, + /** @type {typeof getLocalFileStream | null} */ + getDownloadStream: null, + handleFileUpload: uploadAzureMistralOCR, +}); + // Strategy Selector const getStrategyFunctions = (fileSource) => { if (fileSource === FileSources.firebase) { @@ -222,6 +242,8 @@ const getStrategyFunctions = (fileSource) => { return codeOutputStrategy(); } else if (fileSource === FileSources.mistral_ocr) { return mistralOCRStrategy(); + } else if (fileSource === FileSources.azure_mistral_ocr) { + return azureMistralOCRStrategy(); } else { throw new Error('Invalid file source'); } diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index b9baef462e..527fe2d514 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,27 +1,111 @@ const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); -const { normalizeServerName } = require('librechat-mcp'); -const { Constants: AgentConstants, Providers } = require('@librechat/agents'); +const { logger } = require('@librechat/data-schemas'); +const { Time, CacheKeys, StepTypes } = require('librechat-data-provider'); +const { sendEvent, normalizeServerName, MCPOAuthHandler } = require('@librechat/api'); +const { Constants: AgentConstants, Providers, GraphEvents } = require('@librechat/agents'); const { Constants, ContentTypes, isAssistantsEndpoint, convertJsonSchemaToZod, } = require('librechat-data-provider'); -const { logger, getMCPManager } = require('~/config'); +const { getMCPManager, getFlowStateManager } = require('~/config'); +const { findToken, createToken, updateToken } = require('~/models'); +const { getCachedTools } = require('./Config'); +const { getLogStores } = require('~/cache'); + +/** + * @param {object} params + * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {string} params.stepId - The ID of the step in the flow. + * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. + * @param {string} params.loginFlowId - The ID of the login flow. + * @param {FlowStateManager} params.flowManager - The flow manager instance. + */ +function createOAuthStart({ res, stepId, toolCall, loginFlowId, flowManager, signal }) { + /** + * Creates a function to handle OAuth login requests. + * @param {string} authURL - The URL to redirect the user for OAuth authentication. + * @returns {Promise} Returns true to indicate the event was sent successfully. + */ + return async function (authURL) { + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const data = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + auth: authURL, + expires_at: Date.now() + Time.TWO_MINUTES, + }, + }; + /** Used to ensure the handler (use of `sendEvent`) is only invoked once */ + await flowManager.createFlowWithHandler( + loginFlowId, + 'oauth_login', + async () => { + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); + logger.debug('Sent OAuth login request to client'); + return true; + }, + signal, + ); + }; +} + +/** + * @param {object} params + * @param {ServerResponse} params.res - The Express response object for sending events. + * @param {string} params.stepId - The ID of the step in the flow. + * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. + * @param {string} params.loginFlowId - The ID of the login flow. + * @param {FlowStateManager} params.flowManager - The flow manager instance. + */ +function createOAuthEnd({ res, stepId, toolCall }) { + return async function () { + /** @type {{ id: string; delta: AgentToolCallDelta }} */ + const data = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall }], + }, + }; + sendEvent(res, { event: GraphEvents.ON_RUN_STEP_DELTA, data }); + logger.debug('Sent OAuth login success to client'); + }; +} + +/** + * @param {object} params + * @param {string} params.userId - The ID of the user. + * @param {string} params.serverName - The name of the server. + * @param {string} params.toolName - The name of the tool. + * @param {FlowStateManager} params.flowManager - The flow manager instance. + */ +function createAbortHandler({ userId, serverName, toolName, flowManager }) { + return function () { + logger.info(`[MCP][User: ${userId}][${serverName}][${toolName}] Tool call aborted`); + const flowId = MCPOAuthHandler.generateFlowId(userId, serverName); + flowManager.failFlow(flowId, 'mcp_oauth', new Error('Tool call aborted')); + }; +} /** * Creates a general tool for an entire action set. * * @param {Object} params - The parameters for loading action sets. * @param {ServerRequest} params.req - The Express request object, containing user/request info. + * @param {ServerResponse} params.res - The Express response object for sending events. * @param {string} params.toolKey - The toolKey for the tool. * @param {import('@librechat/agents').Providers | EModelEndpoint} params.provider - The provider for the tool. * @param {string} params.model - The model for the tool. * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ -async function createMCPTool({ req, toolKey, provider: _provider }) { - const toolDefinition = req.app.locals.availableTools[toolKey]?.function; +async function createMCPTool({ req, res, toolKey, provider: _provider }) { + const availableTools = await getCachedTools({ includeGlobal: true }); + const toolDefinition = availableTools?.[toolKey]?.function; if (!toolDefinition) { logger.error(`Tool ${toolKey} not found in available tools`); return null; @@ -50,19 +134,61 @@ async function createMCPTool({ req, toolKey, provider: _provider }) { /** @type {(toolArguments: Object | string, config?: GraphRunnableConfig) => Promise} */ const _call = async (toolArguments, config) => { + const userId = config?.configurable?.user?.id || config?.configurable?.user_id; + /** @type {ReturnType} */ + let abortHandler = null; + /** @type {AbortSignal} */ + let derivedSignal = null; + try { - const derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; - const mcpManager = getMCPManager(config?.configurable?.user_id); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + derivedSignal = config?.signal ? AbortSignal.any([config.signal]) : undefined; + const mcpManager = getMCPManager(userId); const provider = (config?.metadata?.provider || _provider)?.toLowerCase(); + + const { args: _args, stepId, ...toolCall } = config.toolCall ?? {}; + const loginFlowId = `${serverName}:oauth_login:${config.metadata.thread_id}:${config.metadata.run_id}`; + const oauthStart = createOAuthStart({ + res, + stepId, + toolCall, + loginFlowId, + flowManager, + signal: derivedSignal, + }); + const oauthEnd = createOAuthEnd({ + res, + stepId, + toolCall, + }); + + if (derivedSignal) { + abortHandler = createAbortHandler({ userId, serverName, toolName, flowManager }); + derivedSignal.addEventListener('abort', abortHandler, { once: true }); + } + + const customUserVars = + config?.configurable?.userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; + const result = await mcpManager.callTool({ serverName, toolName, provider, toolArguments, options: { - userId: config?.configurable?.user_id, signal: derivedSignal, }, + user: config?.configurable?.user, + customUserVars, + flowManager, + tokenMethods: { + findToken, + createToken, + updateToken, + }, + oauthStart, + oauthEnd, }); if (isAssistantsEndpoint(provider) && Array.isArray(result)) { @@ -74,12 +200,31 @@ async function createMCPTool({ req, toolKey, provider: _provider }) { return result; } catch (error) { logger.error( - `[MCP][User: ${config?.configurable?.user_id}][${serverName}] Error calling "${toolName}" MCP tool:`, + `[MCP][User: ${userId}][${serverName}] Error calling "${toolName}" MCP tool:`, error, ); + + /** OAuth error, provide a helpful message */ + const isOAuthError = + error.message?.includes('401') || + error.message?.includes('OAuth') || + error.message?.includes('authentication') || + error.message?.includes('Non-200 status code (401)'); + + if (isOAuthError) { + throw new Error( + `OAuth authentication required for ${serverName}. Please check the server logs for the authentication URL.`, + ); + } + throw new Error( `"${toolKey}" tool call failed${error?.message ? `: ${error?.message}` : '.'}`, ); + } finally { + // Clean up abort handler to prevent memory leaks + if (abortHandler && derivedSignal) { + derivedSignal.removeEventListener('abort', abortHandler); + } } }; diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index a1ccd7643b..0db13ec318 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -1,12 +1,13 @@ const axios = require('axios'); const { Providers } = require('@librechat/agents'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { EModelEndpoint, defaultModels, CacheKeys } = require('librechat-data-provider'); -const { inputSchema, logAxiosError, extractBaseURL, processModelData } = require('~/utils'); +const { inputSchema, extractBaseURL, processModelData } = require('~/utils'); const { OllamaClient } = require('~/app/clients/OllamaClient'); const { isUserProvided } = require('~/server/utils'); const getLogStores = require('~/cache/getLogStores'); -const { logger } = require('~/config'); /** * Splits a string by commas and trims each resulting value. diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index fb4481f840..33ab9a7aaf 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -1,6 +1,6 @@ const axios = require('axios'); +const { logger } = require('@librechat/data-schemas'); const { EModelEndpoint, defaultModels } = require('librechat-data-provider'); -const { logger } = require('~/config'); const { fetchModels, @@ -28,7 +28,8 @@ jest.mock('~/cache/getLogStores', () => set: jest.fn().mockResolvedValue(true), })), ); -jest.mock('~/config', () => ({ +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), logger: { error: jest.fn(), }, diff --git a/api/server/services/PluginService.js b/api/server/services/PluginService.js index 7463e0814e..af42e0471c 100644 --- a/api/server/services/PluginService.js +++ b/api/server/services/PluginService.js @@ -1,6 +1,6 @@ -const { encrypt, decrypt } = require('~/server/utils/crypto'); -const { PluginAuth } = require('~/db/models'); -const { logger } = require('~/config'); +const { logger } = require('@librechat/data-schemas'); +const { encrypt, decrypt } = require('@librechat/api'); +const { findOnePluginAuth, updatePluginAuth, deletePluginAuth } = require('~/models'); /** * Asynchronously retrieves and decrypts the authentication value for a user's plugin, based on a specified authentication field. @@ -25,7 +25,7 @@ const { logger } = require('~/config'); */ const getUserPluginAuthValue = async (userId, authField, throwError = true) => { try { - const pluginAuth = await PluginAuth.findOne({ userId, authField }).lean(); + const pluginAuth = await findOnePluginAuth({ userId, authField }); if (!pluginAuth) { throw new Error(`No plugin auth ${authField} found for user ${userId}`); } @@ -79,23 +79,12 @@ const getUserPluginAuthValue = async (userId, authField, throwError = true) => { const updateUserPluginAuth = async (userId, authField, pluginKey, value) => { try { const encryptedValue = await encrypt(value); - const pluginAuth = await PluginAuth.findOne({ userId, authField }).lean(); - if (pluginAuth) { - return await PluginAuth.findOneAndUpdate( - { userId, authField }, - { $set: { value: encryptedValue } }, - { new: true, upsert: true }, - ).lean(); - } else { - const newPluginAuth = await new PluginAuth({ - userId, - authField, - value: encryptedValue, - pluginKey, - }); - await newPluginAuth.save(); - return newPluginAuth.toObject(); - } + return await updatePluginAuth({ + userId, + authField, + pluginKey, + value: encryptedValue, + }); } catch (err) { logger.error('[updateUserPluginAuth]', err); return err; @@ -105,26 +94,25 @@ const updateUserPluginAuth = async (userId, authField, pluginKey, value) => { /** * @async * @param {string} userId - * @param {string} authField - * @param {boolean} [all] + * @param {string | null} authField - The specific authField to delete, or null if `all` is true. + * @param {boolean} [all=false] - Whether to delete all auths for the user (or for a specific pluginKey if provided). + * @param {string} [pluginKey] - Optional. If `all` is true and `pluginKey` is provided, delete all auths for this user and pluginKey. * @returns {Promise} * @throws {Error} */ -const deleteUserPluginAuth = async (userId, authField, all = false) => { - if (all) { - try { - const response = await PluginAuth.deleteMany({ userId }); - return response; - } catch (err) { - logger.error('[deleteUserPluginAuth]', err); - return err; - } - } - +const deleteUserPluginAuth = async (userId, authField, all = false, pluginKey) => { try { - return await PluginAuth.deleteOne({ userId, authField }); + return await deletePluginAuth({ + userId, + authField, + pluginKey, + all, + }); } catch (err) { - logger.error('[deleteUserPluginAuth]', err); + logger.error( + `[deleteUserPluginAuth] Error deleting ${all ? 'all' : 'single'} auth(s) for userId: ${userId}${pluginKey ? ` and pluginKey: ${pluginKey}` : ''}`, + err, + ); return err; } }; diff --git a/api/server/services/Runs/methods.js b/api/server/services/Runs/methods.js index 3c18e9969b..167b9cc2ba 100644 --- a/api/server/services/Runs/methods.js +++ b/api/server/services/Runs/methods.js @@ -1,6 +1,6 @@ const axios = require('axios'); +const { logAxiosError } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); -const { logAxiosError } = require('~/utils'); /** * @typedef {Object} RetrieveOptions diff --git a/api/server/services/TokenService.js b/api/server/services/TokenService.js deleted file mode 100644 index 3dd2e79ffa..0000000000 --- a/api/server/services/TokenService.js +++ /dev/null @@ -1,172 +0,0 @@ -const axios = require('axios'); -const { handleOAuthToken } = require('~/models/Token'); -const { decryptV2 } = require('~/server/utils/crypto'); -const { logAxiosError } = require('~/utils'); -const { logger } = require('~/config'); - -/** - * Processes the access tokens and stores them in the database. - * @param {object} tokenData - * @param {string} tokenData.access_token - * @param {number} tokenData.expires_in - * @param {string} [tokenData.refresh_token] - * @param {number} [tokenData.refresh_token_expires_in] - * @param {object} metadata - * @param {string} metadata.userId - * @param {string} metadata.identifier - * @returns {Promise} - */ -async function processAccessTokens(tokenData, { userId, identifier }) { - const { access_token, expires_in = 3600, refresh_token, refresh_token_expires_in } = tokenData; - if (!access_token) { - logger.error('Access token not found: ', tokenData); - throw new Error('Access token not found'); - } - await handleOAuthToken({ - identifier, - token: access_token, - expiresIn: expires_in, - userId, - }); - - if (refresh_token != null) { - logger.debug('Processing refresh token'); - await handleOAuthToken({ - token: refresh_token, - type: 'oauth_refresh', - userId, - identifier: `${identifier}:refresh`, - expiresIn: refresh_token_expires_in ?? null, - }); - } - logger.debug('Access tokens processed'); -} - -/** - * Refreshes the access token using the refresh token. - * @param {object} fields - * @param {string} fields.userId - The ID of the user. - * @param {string} fields.client_url - The URL of the OAuth provider. - * @param {string} fields.identifier - The identifier for the token. - * @param {string} fields.refresh_token - The refresh token to use. - * @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider. - * @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider. - * @returns {Promise<{ - * access_token: string, - * expires_in: number, - * refresh_token?: string, - * refresh_token_expires_in?: number, - * }>} - */ -const refreshAccessToken = async ({ - userId, - client_url, - identifier, - refresh_token, - encrypted_oauth_client_id, - encrypted_oauth_client_secret, -}) => { - try { - const oauth_client_id = await decryptV2(encrypted_oauth_client_id); - const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret); - const params = new URLSearchParams({ - client_id: oauth_client_id, - client_secret: oauth_client_secret, - grant_type: 'refresh_token', - refresh_token, - }); - - const response = await axios({ - method: 'POST', - url: client_url, - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, - data: params.toString(), - }); - await processAccessTokens(response.data, { - userId, - identifier, - }); - logger.debug(`Access token refreshed successfully for ${identifier}`); - return response.data; - } catch (error) { - const message = 'Error refreshing OAuth tokens'; - throw new Error( - logAxiosError({ - message, - error, - }), - ); - } -}; - -/** - * Handles the OAuth callback and exchanges the authorization code for tokens. - * @param {object} fields - * @param {string} fields.code - The authorization code returned by the provider. - * @param {string} fields.userId - The ID of the user. - * @param {string} fields.identifier - The identifier for the token. - * @param {string} fields.client_url - The URL of the OAuth provider. - * @param {string} fields.redirect_uri - The redirect URI for the OAuth provider. - * @param {string} fields.encrypted_oauth_client_id - The client ID for the OAuth provider. - * @param {string} fields.encrypted_oauth_client_secret - The client secret for the OAuth provider. - * @returns {Promise<{ - * access_token: string, - * expires_in: number, - * refresh_token?: string, - * refresh_token_expires_in?: number, - * }>} - */ -const getAccessToken = async ({ - code, - userId, - identifier, - client_url, - redirect_uri, - encrypted_oauth_client_id, - encrypted_oauth_client_secret, -}) => { - const oauth_client_id = await decryptV2(encrypted_oauth_client_id); - const oauth_client_secret = await decryptV2(encrypted_oauth_client_secret); - const params = new URLSearchParams({ - code, - client_id: oauth_client_id, - client_secret: oauth_client_secret, - grant_type: 'authorization_code', - redirect_uri, - }); - - try { - const response = await axios({ - method: 'POST', - url: client_url, - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - Accept: 'application/json', - }, - data: params.toString(), - }); - - await processAccessTokens(response.data, { - userId, - identifier, - }); - logger.debug(`Access tokens successfully created for ${identifier}`); - return response.data; - } catch (error) { - const message = 'Error exchanging OAuth code'; - throw new Error( - logAxiosError({ - message, - error, - }), - ); - } -}; - -module.exports = { - getAccessToken, - refreshAccessToken, -}; diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 8dd2fbf865..f1567a3783 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -1,5 +1,7 @@ const fs = require('fs'); const path = require('path'); +const { sleep } = require('@librechat/agents'); +const { logger } = require('@librechat/data-schemas'); const { zodToJsonSchema } = require('zod-to-json-schema'); const { Calculator } = require('@langchain/community/tools/calculator'); const { tool: toolFn, Tool, DynamicStructuredTool } = require('@langchain/core/tools'); @@ -31,14 +33,12 @@ const { toolkits, } = require('~/app/clients/tools'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); +const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config'); const { createOnSearchResults } = require('~/server/services/Tools/search'); const { isActionDomainAllowed } = require('~/server/services/domains'); -const { getEndpointsConfig } = require('~/server/services/Config'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); -const { sleep } = require('~/server/utils'); -const { logger } = require('~/config'); /** * @param {string} toolName @@ -226,7 +226,7 @@ async function processRequiredActions(client, requiredActions) { `[required actions] user: ${client.req.user.id} | thread_id: ${requiredActions[0].thread_id} | run_id: ${requiredActions[0].run_id}`, requiredActions, ); - const toolDefinitions = client.req.app.locals.availableTools; + const toolDefinitions = await getCachedTools({ includeGlobal: true }); const seenToolkits = new Set(); const tools = requiredActions .map((action) => { @@ -500,6 +500,8 @@ async function processRequiredActions(client, requiredActions) { async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) { if (!agent.tools || agent.tools.length === 0) { return {}; + } else if (agent.tools && agent.tools.length === 1 && agent.tools[0] === AgentCapabilities.ocr) { + return {}; } const endpointsConfig = await getEndpointsConfig(req); @@ -551,6 +553,7 @@ async function loadAgentTools({ req, res, agent, tool_resources, openAIApiKey }) tools: _agentTools, options: { req, + res, openAIApiKey, tool_resources, processFileURL, diff --git a/api/server/services/UserService.js b/api/server/services/UserService.js index b729607f69..7cf2f832a3 100644 --- a/api/server/services/UserService.js +++ b/api/server/services/UserService.js @@ -1,6 +1,6 @@ const { logger } = require('@librechat/data-schemas'); +const { encrypt, decrypt } = require('@librechat/api'); const { ErrorTypes } = require('librechat-data-provider'); -const { encrypt, decrypt } = require('~/server/utils/crypto'); const { updateUser } = require('~/models'); const { Key } = require('~/db/models'); @@ -70,6 +70,7 @@ const getUserKeyValues = async ({ userId, name }) => { try { userValues = JSON.parse(userValues); } catch (e) { + logger.error('[getUserKeyValues]', e); throw new Error( JSON.stringify({ type: ErrorTypes.INVALID_USER_KEY, diff --git a/api/server/services/initializeMCP.js b/api/server/services/initializeMCP.js new file mode 100644 index 0000000000..d7c5ab7d8a --- /dev/null +++ b/api/server/services/initializeMCP.js @@ -0,0 +1,54 @@ +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys, processMCPEnv } = require('librechat-data-provider'); +const { getMCPManager, getFlowStateManager } = require('~/config'); +const { getCachedTools, setCachedTools } = require('./Config'); +const { getLogStores } = require('~/cache'); +const { findToken, updateToken, createToken, deleteTokens } = require('~/models'); + +/** + * Initialize MCP servers + * @param {import('express').Application} app - Express app instance + */ +async function initializeMCP(app) { + const mcpServers = app.locals.mcpConfig; + if (!mcpServers) { + return; + } + + logger.info('Initializing MCP servers...'); + const mcpManager = getMCPManager(); + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null; + + try { + await mcpManager.initializeMCP({ + mcpServers, + flowManager, + tokenMethods: { + findToken, + updateToken, + createToken, + deleteTokens, + }, + processMCPEnv, + }); + + delete app.locals.mcpConfig; + const availableTools = await getCachedTools(); + + if (!availableTools) { + logger.warn('No available tools found in cache during MCP initialization'); + return; + } + + const toolsCopy = { ...availableTools }; + await mcpManager.mapAvailableTools(toolsCopy, flowManager); + await setCachedTools(toolsCopy, { isGlobal: true }); + + logger.info('MCP servers initialized successfully'); + } catch (error) { + logger.error('Failed to initialize MCP servers:', error); + } +} + +module.exports = initializeMCP; diff --git a/api/server/services/start/agents.js b/api/server/services/start/agents.js deleted file mode 100644 index 10653f3fb6..0000000000 --- a/api/server/services/start/agents.js +++ /dev/null @@ -1,14 +0,0 @@ -const { EModelEndpoint, agentsEndpointSChema } = require('librechat-data-provider'); - -/** - * Sets up the Agents configuration from the config (`librechat.yaml`) file. - * @param {TCustomConfig} config - The loaded custom configuration. - * @returns {Partial} The Agents endpoint configuration. - */ -function agentsConfigSetup(config) { - const agentsConfig = config.endpoints[EModelEndpoint.agents]; - const parsedConfig = agentsEndpointSChema.parse(agentsConfig); - return parsedConfig; -} - -module.exports = { agentsConfigSetup }; diff --git a/api/server/services/start/interface.js b/api/server/services/start/interface.js index 7578c036b2..c98fdb60bc 100644 --- a/api/server/services/start/interface.js +++ b/api/server/services/start/interface.js @@ -2,6 +2,7 @@ const { SystemRoles, Permissions, PermissionTypes, + isMemoryEnabled, removeNullishValues, } = require('librechat-data-provider'); const { updateAccessPermissions } = require('~/models/Role'); @@ -20,6 +21,14 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol const hasModelSpecs = config?.modelSpecs?.list?.length > 0; const includesAddedEndpoints = config?.modelSpecs?.addedEndpoints?.length > 0; + const memoryConfig = config?.memory; + const memoryEnabled = isMemoryEnabled(memoryConfig); + /** Only disable memories if memory config is present but disabled/invalid */ + const shouldDisableMemories = memoryConfig && !memoryEnabled; + /** Check if personalization is enabled (defaults to true if memory is configured and enabled) */ + const isPersonalizationEnabled = + memoryConfig && memoryEnabled && memoryConfig.personalize !== false; + /** @type {TCustomConfig['interface']} */ const loadedInterface = removeNullishValues({ endpointsMenu: @@ -33,6 +42,7 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol privacyPolicy: interfaceConfig?.privacyPolicy ?? defaults.privacyPolicy, termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService, bookmarks: interfaceConfig?.bookmarks ?? defaults.bookmarks, + memories: shouldDisableMemories ? false : (interfaceConfig?.memories ?? defaults.memories), prompts: interfaceConfig?.prompts ?? defaults.prompts, multiConvo: interfaceConfig?.multiConvo ?? defaults.multiConvo, agents: interfaceConfig?.agents ?? defaults.agents, @@ -45,6 +55,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol await updateAccessPermissions(roleName, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + [PermissionTypes.MEMORIES]: { + [Permissions.USE]: loadedInterface.memories, + [Permissions.OPT_OUT]: isPersonalizationEnabled, + }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, @@ -54,6 +68,10 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol await updateAccessPermissions(SystemRoles.ADMIN, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: loadedInterface.prompts }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: loadedInterface.bookmarks }, + [PermissionTypes.MEMORIES]: { + [Permissions.USE]: loadedInterface.memories, + [Permissions.OPT_OUT]: isPersonalizationEnabled, + }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: loadedInterface.multiConvo }, [PermissionTypes.AGENTS]: { [Permissions.USE]: loadedInterface.agents }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: loadedInterface.temporaryChat }, diff --git a/api/server/services/start/interface.spec.js b/api/server/services/start/interface.spec.js index d0dcfaf55f..1a05c9cf12 100644 --- a/api/server/services/start/interface.spec.js +++ b/api/server/services/start/interface.spec.js @@ -12,6 +12,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: true, multiConvo: true, agents: true, temporaryChat: true, @@ -26,6 +27,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -39,6 +41,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: false, bookmarks: false, + memories: false, multiConvo: false, agents: false, temporaryChat: false, @@ -53,6 +56,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: false }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: false }, @@ -70,6 +74,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -83,6 +88,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: undefined, bookmarks: undefined, + memories: undefined, multiConvo: undefined, agents: undefined, temporaryChat: undefined, @@ -97,6 +103,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -110,6 +117,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: undefined, agents: true, temporaryChat: undefined, @@ -124,6 +132,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -138,6 +147,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: true, multiConvo: true, agents: true, temporaryChat: true, @@ -151,6 +161,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: true }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -168,6 +179,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -185,6 +197,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -202,6 +215,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: undefined }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: undefined }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: undefined }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: undefined }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -215,6 +229,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: true, agents: false, temporaryChat: true, @@ -228,6 +243,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, @@ -242,6 +258,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: true, + memories: false, multiConvo: false, agents: undefined, temporaryChat: undefined, @@ -255,6 +272,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: false }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: false }, [PermissionTypes.AGENTS]: { [Permissions.USE]: undefined }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: undefined }, @@ -268,6 +286,7 @@ describe('loadDefaultInterface', () => { interface: { prompts: true, bookmarks: false, + memories: true, multiConvo: true, agents: false, temporaryChat: true, @@ -281,6 +300,7 @@ describe('loadDefaultInterface', () => { expect(updateAccessPermissions).toHaveBeenCalledWith(SystemRoles.USER, { [PermissionTypes.PROMPTS]: { [Permissions.USE]: true }, [PermissionTypes.BOOKMARKS]: { [Permissions.USE]: false }, + [PermissionTypes.MEMORIES]: { [Permissions.USE]: true }, [PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true }, [PermissionTypes.AGENTS]: { [Permissions.USE]: false }, [PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true }, diff --git a/api/server/services/twoFactorService.js b/api/server/services/twoFactorService.js index 0274842367..4ac86a5549 100644 --- a/api/server/services/twoFactorService.js +++ b/api/server/services/twoFactorService.js @@ -1,5 +1,5 @@ const { webcrypto } = require('node:crypto'); -const { hashBackupCode, decryptV3, decryptV2 } = require('~/server/utils/crypto'); +const { hashBackupCode, decryptV3, decryptV2 } = require('@librechat/api'); const { updateUser } = require('~/models'); // Base32 alphabet for TOTP secret encoding. diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 86c17f1dda..680da5da44 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -1,5 +1,3 @@ -const path = require('path'); -const crypto = require('crypto'); const { Capabilities, EModelEndpoint, @@ -218,38 +216,6 @@ function normalizeEndpointName(name = '') { return name.toLowerCase() === Providers.OLLAMA ? Providers.OLLAMA : name; } -/** - * Sanitize a filename by removing any directory components, replacing non-alphanumeric characters - * @param {string} inputName - * @returns {string} - */ -function sanitizeFilename(inputName) { - // Remove any directory components - let name = path.basename(inputName); - - // Replace any non-alphanumeric characters except for '.' and '-' - name = name.replace(/[^a-zA-Z0-9.-]/g, '_'); - - // Ensure the name doesn't start with a dot (hidden file in Unix-like systems) - if (name.startsWith('.') || name === '') { - name = '_' + name; - } - - // Limit the length of the filename - const MAX_LENGTH = 255; - if (name.length > MAX_LENGTH) { - const ext = path.extname(name); - const nameWithoutExt = path.basename(name, ext); - name = - nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) + - '-' + - crypto.randomBytes(3).toString('hex') + - ext; - } - - return name; -} - module.exports = { isEnabled, handleText, @@ -260,6 +226,5 @@ module.exports = { generateConfig, addSpaceIfNeeded, createOnProgress, - sanitizeFilename, normalizeEndpointName, }; diff --git a/api/server/utils/handleText.spec.js b/api/server/utils/handleText.spec.js deleted file mode 100644 index 2cd6c51f91..0000000000 --- a/api/server/utils/handleText.spec.js +++ /dev/null @@ -1,103 +0,0 @@ -const { isEnabled, sanitizeFilename } = require('./handleText'); - -describe('isEnabled', () => { - test('should return true when input is "true"', () => { - expect(isEnabled('true')).toBe(true); - }); - - test('should return true when input is "TRUE"', () => { - expect(isEnabled('TRUE')).toBe(true); - }); - - test('should return true when input is true', () => { - expect(isEnabled(true)).toBe(true); - }); - - test('should return false when input is "false"', () => { - expect(isEnabled('false')).toBe(false); - }); - - test('should return false when input is false', () => { - expect(isEnabled(false)).toBe(false); - }); - - test('should return false when input is null', () => { - expect(isEnabled(null)).toBe(false); - }); - - test('should return false when input is undefined', () => { - expect(isEnabled()).toBe(false); - }); - - test('should return false when input is an empty string', () => { - expect(isEnabled('')).toBe(false); - }); - - test('should return false when input is a whitespace string', () => { - expect(isEnabled(' ')).toBe(false); - }); - - test('should return false when input is a number', () => { - expect(isEnabled(123)).toBe(false); - }); - - test('should return false when input is an object', () => { - expect(isEnabled({})).toBe(false); - }); - - test('should return false when input is an array', () => { - expect(isEnabled([])).toBe(false); - }); -}); - -jest.mock('crypto', () => { - const actualModule = jest.requireActual('crypto'); - return { - ...actualModule, - randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')), - }; -}); - -describe('sanitizeFilename', () => { - test('removes directory components (1/2)', () => { - expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt'); - }); - - test('removes directory components (2/2)', () => { - expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt'); - }); - - test('replaces non-alphanumeric characters', () => { - expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt'); - }); - - test('preserves dots and hyphens', () => { - expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt'); - }); - - test('prepends underscore to filenames starting with a dot', () => { - expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile'); - }); - - test('truncates long filenames', () => { - const longName = 'a'.repeat(300) + '.txt'; - const result = sanitizeFilename(longName); - expect(result.length).toBe(255); - expect(result).toMatch(/^a+-abc123\.txt$/); - }); - - test('handles filenames with no extension', () => { - const longName = 'a'.repeat(300); - const result = sanitizeFilename(longName); - expect(result.length).toBe(255); - expect(result).toMatch(/^a+-abc123$/); - }); - - test('handles empty input', () => { - expect(sanitizeFilename('')).toBe('_'); - }); - - test('handles input with only special characters', () => { - expect(sanitizeFilename('@#$%^&*')).toBe('_______'); - }); -}); diff --git a/api/server/utils/index.js b/api/server/utils/index.js index b79b42f00d..2661ff75e1 100644 --- a/api/server/utils/index.js +++ b/api/server/utils/index.js @@ -3,7 +3,6 @@ const removePorts = require('./removePorts'); const countTokens = require('./countTokens'); const handleText = require('./handleText'); const sendEmail = require('./sendEmail'); -const cryptoUtils = require('./crypto'); const queue = require('./queue'); const files = require('./files'); const math = require('./math'); @@ -13,18 +12,24 @@ const math = require('./math'); * @returns {Boolean} */ function checkEmailConfig() { - return ( + // Check if Mailgun is configured + const hasMailgunConfig = + !!process.env.MAILGUN_API_KEY && !!process.env.MAILGUN_DOMAIN && !!process.env.EMAIL_FROM; + + // Check if SMTP is configured + const hasSMTPConfig = (!!process.env.EMAIL_SERVICE || !!process.env.EMAIL_HOST) && !!process.env.EMAIL_USERNAME && !!process.env.EMAIL_PASSWORD && - !!process.env.EMAIL_FROM - ); + !!process.env.EMAIL_FROM; + + // Return true if either Mailgun or SMTP is properly configured + return hasMailgunConfig || hasSMTPConfig; } module.exports = { ...streamResponse, checkEmailConfig, - ...cryptoUtils, ...handleText, countTokens, removePorts, diff --git a/api/server/utils/sendEmail.js b/api/server/utils/sendEmail.js index 59d75830f4..c0afd0eebe 100644 --- a/api/server/utils/sendEmail.js +++ b/api/server/utils/sendEmail.js @@ -1,9 +1,69 @@ const fs = require('fs'); const path = require('path'); +const axios = require('axios'); +const FormData = require('form-data'); const nodemailer = require('nodemailer'); const handlebars = require('handlebars'); +const { logAxiosError } = require('@librechat/api'); +const { logger } = require('@librechat/data-schemas'); const { isEnabled } = require('~/server/utils/handleText'); -const logger = require('~/config/winston'); + +/** + * Sends an email using Mailgun API. + * + * @async + * @function sendEmailViaMailgun + * @param {Object} params - The parameters for sending the email. + * @param {string} params.to - The recipient's email address. + * @param {string} params.from - The sender's email address. + * @param {string} params.subject - The subject of the email. + * @param {string} params.html - The HTML content of the email. + * @returns {Promise} - A promise that resolves to the response from Mailgun API. + */ +const sendEmailViaMailgun = async ({ to, from, subject, html }) => { + const mailgunApiKey = process.env.MAILGUN_API_KEY; + const mailgunDomain = process.env.MAILGUN_DOMAIN; + const mailgunHost = process.env.MAILGUN_HOST || 'https://api.mailgun.net'; + + if (!mailgunApiKey || !mailgunDomain) { + throw new Error('Mailgun API key and domain are required'); + } + + const formData = new FormData(); + formData.append('from', from); + formData.append('to', to); + formData.append('subject', subject); + formData.append('html', html); + formData.append('o:tracking-clicks', 'no'); + + try { + const response = await axios.post(`${mailgunHost}/v3/${mailgunDomain}/messages`, formData, { + headers: { + ...formData.getHeaders(), + Authorization: `Basic ${Buffer.from(`api:${mailgunApiKey}`).toString('base64')}`, + }, + }); + + return response.data; + } catch (error) { + throw new Error(logAxiosError({ error, message: 'Failed to send email via Mailgun' })); + } +}; + +/** + * Sends an email using SMTP via Nodemailer. + * + * @async + * @function sendEmailViaSMTP + * @param {Object} params - The parameters for sending the email. + * @param {Object} params.transporterOptions - The transporter configuration options. + * @param {Object} params.mailOptions - The email options. + * @returns {Promise} - A promise that resolves to the info object of the sent email. + */ +const sendEmailViaSMTP = async ({ transporterOptions, mailOptions }) => { + const transporter = nodemailer.createTransport(transporterOptions); + return await transporter.sendMail(mailOptions); +}; /** * Sends an email using the specified template, subject, and payload. @@ -34,6 +94,30 @@ const logger = require('~/config/winston'); */ const sendEmail = async ({ email, subject, payload, template, throwError = true }) => { try { + // Read and compile the email template + const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8'); + const compiledTemplate = handlebars.compile(source); + const html = compiledTemplate(payload); + + // Prepare common email data + const fromName = process.env.EMAIL_FROM_NAME || process.env.APP_TITLE; + const fromEmail = process.env.EMAIL_FROM; + const fromAddress = `"${fromName}" <${fromEmail}>`; + const toAddress = `"${payload.name}" <${email}>`; + + // Check if Mailgun is configured + if (process.env.MAILGUN_API_KEY && process.env.MAILGUN_DOMAIN) { + logger.debug('[sendEmail] Using Mailgun provider'); + return await sendEmailViaMailgun({ + from: fromAddress, + to: toAddress, + subject: subject, + html: html, + }); + } + + // Default to SMTP + logger.debug('[sendEmail] Using SMTP provider'); const transporterOptions = { // Use STARTTLS by default instead of obligatory TLS secure: process.env.EMAIL_ENCRYPTION === 'tls', @@ -62,30 +146,21 @@ const sendEmail = async ({ email, subject, payload, template, throwError = true transporterOptions.port = process.env.EMAIL_PORT ?? 25; } - const transporter = nodemailer.createTransport(transporterOptions); - - const source = fs.readFileSync(path.join(__dirname, 'emails', template), 'utf8'); - const compiledTemplate = handlebars.compile(source); - const options = () => { - return { - // Header address should contain name-addr - from: - `"${process.env.EMAIL_FROM_NAME || process.env.APP_TITLE}"` + - `<${process.env.EMAIL_FROM}>`, - to: `"${payload.name}" <${email}>`, - envelope: { - // Envelope from should contain addr-spec - // Mistake in the Nodemailer documentation? - from: process.env.EMAIL_FROM, - to: email, - }, - subject: subject, - html: compiledTemplate(payload), - }; + const mailOptions = { + // Header address should contain name-addr + from: fromAddress, + to: toAddress, + envelope: { + // Envelope from should contain addr-spec + // Mistake in the Nodemailer documentation? + from: fromEmail, + to: email, + }, + subject: subject, + html: html, }; - // Send email - return await transporter.sendMail(options()); + return await sendEmailViaSMTP({ transporterOptions, mailOptions }); } catch (error) { if (throwError) { throw error; diff --git a/api/strategies/localStrategy.js b/api/strategies/localStrategy.js index edc749ee9e..bc84e7c6b5 100644 --- a/api/strategies/localStrategy.js +++ b/api/strategies/localStrategy.js @@ -29,6 +29,12 @@ async function passportLogin(req, email, password, done) { return done(null, false, { message: 'Email does not exist.' }); } + if (!user.password) { + logError('Passport Local Strategy - User has no password', { email }); + logger.error(`[Login] [Login failed] [Username: ${email}] [Request-IP: ${req.ip}]`); + return done(null, false, { message: 'Email does not exist.' }); + } + const isMatch = await comparePassword(user, password); if (!isMatch) { logError('Passport Local Strategy - Password does not match', { isMatch }); diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js index 1d0a6bc5e6..2449872a9d 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -1,3 +1,4 @@ +const undici = require('undici'); const fetch = require('node-fetch'); const passport = require('passport'); const client = require('openid-client'); @@ -6,17 +7,87 @@ const { CacheKeys } = require('librechat-data-provider'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { hashToken, logger } = require('@librechat/data-schemas'); const { Strategy: OpenIDStrategy } = require('openid-client/passport'); +const { isEnabled, safeStringify, logHeaders } = require('@librechat/api'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { findUser, createUser, updateUser } = require('~/models'); const { getBalanceConfig } = require('~/server/services/Config'); const getLogStores = require('~/cache/getLogStores'); -const { isEnabled } = require('~/server/utils'); /** * @typedef {import('openid-client').ClientMetadata} ClientMetadata * @typedef {import('openid-client').Configuration} Configuration **/ +/** + * @param {string} url + * @param {client.CustomFetchOptions} options + */ +async function customFetch(url, options) { + const urlStr = url.toString(); + logger.debug(`[openidStrategy] Request to: ${urlStr}`); + const debugOpenId = isEnabled(process.env.DEBUG_OPENID_REQUESTS); + if (debugOpenId) { + logger.debug(`[openidStrategy] Request method: ${options.method || 'GET'}`); + logger.debug(`[openidStrategy] Request headers: ${logHeaders(options.headers)}`); + if (options.body) { + let bodyForLogging = ''; + if (options.body instanceof URLSearchParams) { + bodyForLogging = options.body.toString(); + } else if (typeof options.body === 'string') { + bodyForLogging = options.body; + } else { + bodyForLogging = safeStringify(options.body); + } + logger.debug(`[openidStrategy] Request body: ${bodyForLogging}`); + } + } + + try { + /** @type {undici.RequestInit} */ + let fetchOptions = options; + if (process.env.PROXY) { + logger.info(`[openidStrategy] proxy agent configured: ${process.env.PROXY}`); + fetchOptions = { + ...options, + dispatcher: new HttpsProxyAgent(process.env.PROXY), + }; + } + + const response = await undici.fetch(url, fetchOptions); + + if (debugOpenId) { + logger.debug(`[openidStrategy] Response status: ${response.status} ${response.statusText}`); + logger.debug(`[openidStrategy] Response headers: ${logHeaders(response.headers)}`); + } + + if (response.status === 200 && response.headers.has('www-authenticate')) { + const wwwAuth = response.headers.get('www-authenticate'); + logger.warn(`[openidStrategy] Non-standard WWW-Authenticate header found in successful response (200 OK): ${wwwAuth}. +This violates RFC 7235 and may cause issues with strict OAuth clients. Removing header for compatibility.`); + + /** Cloned response without the WWW-Authenticate header */ + const responseBody = await response.arrayBuffer(); + const newHeaders = new Headers(); + for (const [key, value] of response.headers.entries()) { + if (key.toLowerCase() !== 'www-authenticate') { + newHeaders.append(key, value); + } + } + + return new Response(responseBody, { + status: response.status, + statusText: response.statusText, + headers: newHeaders, + }); + } + + return response; + } catch (error) { + logger.error(`[openidStrategy] Fetch error: ${error.message}`); + throw error; + } +} + /** @typedef {Configuration | null} */ let openidConfig = null; @@ -208,14 +279,12 @@ async function setupOpenId() { new URL(process.env.OPENID_ISSUER), process.env.OPENID_CLIENT_ID, clientMetadata, + undefined, + { + [client.customFetch]: customFetch, + }, ); - if (process.env.PROXY) { - const proxyAgent = new HttpsProxyAgent(process.env.PROXY); - openidConfig[client.customFetch] = (...args) => { - return fetch(args[0], { ...args[1], agent: proxyAgent }); - }; - logger.info(`[openidStrategy] proxy agent added: ${process.env.PROXY}`); - } + const requiredRole = process.env.OPENID_REQUIRED_ROLE; const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 3e52ad01f1..1e6750384e 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -21,19 +21,18 @@ jest.mock('~/models', () => ({ createUser: jest.fn(), updateUser: jest.fn(), })); -jest.mock('~/server/utils/crypto', () => ({ - hashToken: jest.fn().mockResolvedValue('hashed-token'), -})); -jest.mock('~/server/utils', () => ({ +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), isEnabled: jest.fn(() => false), })); -jest.mock('~/config', () => ({ +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/api'), logger: { info: jest.fn(), debug: jest.fn(), error: jest.fn(), - warn: jest.fn(), }, + hashToken: jest.fn().mockResolvedValue('hashed-token'), })); jest.mock('~/cache/getLogStores', () => jest.fn(() => ({ diff --git a/api/strategies/process.js b/api/strategies/process.js index 774d8d015e..1f7e7c81d2 100644 --- a/api/strategies/process.js +++ b/api/strategies/process.js @@ -31,7 +31,7 @@ const handleExistingUser = async (oldUser, avatarUrl) => { input: avatarUrl, }); const { processAvatar } = getStrategyFunctions(fileStrategy); - updatedAvatar = await processAvatar({ buffer: resizedBuffer, userId }); + updatedAvatar = await processAvatar({ buffer: resizedBuffer, userId, manual: 'false' }); } if (updatedAvatar) { @@ -90,7 +90,11 @@ const createSocialUser = async ({ input: avatarUrl, }); const { processAvatar } = getStrategyFunctions(fileStrategy); - const avatar = await processAvatar({ buffer: resizedBuffer, userId: newUserId }); + const avatar = await processAvatar({ + buffer: resizedBuffer, + userId: newUserId, + manual: 'false', + }); await updateUser(newUserId, { avatar }); } diff --git a/api/strategies/samlStrategy.spec.js b/api/strategies/samlStrategy.spec.js index 675bdc998b..fc8329a31a 100644 --- a/api/strategies/samlStrategy.spec.js +++ b/api/strategies/samlStrategy.spec.js @@ -1,15 +1,17 @@ -const fs = require('fs'); -const path = require('path'); -const fetch = require('node-fetch'); -const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); -const { findUser, createUser, updateUser } = require('~/models'); -const { setupSaml, getCertificateContent } = require('./samlStrategy'); - // --- Mocks --- +jest.mock('tiktoken'); jest.mock('fs'); jest.mock('path'); jest.mock('node-fetch'); jest.mock('@node-saml/passport-saml'); +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, + hashToken: jest.fn().mockResolvedValue('hashed-token'), +})); jest.mock('~/models', () => ({ findUser: jest.fn(), createUser: jest.fn(), @@ -29,26 +31,26 @@ jest.mock('~/server/services/Config', () => ({ jest.mock('~/server/services/Config/EndpointService', () => ({ config: {}, })); -jest.mock('~/server/utils', () => ({ - isEnabled: jest.fn(() => false), - isUserProvided: jest.fn(() => false), -})); jest.mock('~/server/services/Files/strategies', () => ({ getStrategyFunctions: jest.fn(() => ({ saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), })), })); -jest.mock('~/server/utils/crypto', () => ({ - hashToken: jest.fn().mockResolvedValue('hashed-token'), -})); -jest.mock('~/config', () => ({ - logger: { - info: jest.fn(), - debug: jest.fn(), - error: jest.fn(), - }, +jest.mock('~/config/paths', () => ({ + root: '/fake/root/path', })); +const fs = require('fs'); +const path = require('path'); +const fetch = require('node-fetch'); +const { Strategy: SamlStrategy } = require('@node-saml/passport-saml'); +const { setupSaml, getCertificateContent } = require('./samlStrategy'); + +// Configure fs mock +jest.mocked(fs).existsSync = jest.fn(); +jest.mocked(fs).statSync = jest.fn(); +jest.mocked(fs).readFileSync = jest.fn(); + // To capture the verify callback from the strategy, we grab it from the mock constructor let verifyCallback; SamlStrategy.mockImplementation((options, verify) => { diff --git a/api/test/__mocks__/logger.js b/api/test/__mocks__/logger.js index f9f6d78c87..56fb28cbab 100644 --- a/api/test/__mocks__/logger.js +++ b/api/test/__mocks__/logger.js @@ -41,10 +41,7 @@ jest.mock('winston-daily-rotate-file', () => { }); jest.mock('~/config', () => { - const actualModule = jest.requireActual('~/config'); return { - sendEvent: actualModule.sendEvent, - createAxiosInstance: actualModule.createAxiosInstance, logger: { info: jest.fn(), warn: jest.fn(), diff --git a/api/typedefs.js b/api/typedefs.js index 8da5b34809..58cd802425 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -476,11 +476,18 @@ * @memberof typedefs */ +/** + * @exports ToolCallChunk + * @typedef {import('librechat-data-provider').Agents.ToolCallChunk} ToolCallChunk + * @memberof typedefs + */ + /** * @exports MessageContentImageUrl * @typedef {import('librechat-data-provider').Agents.MessageContentImageUrl} MessageContentImageUrl * @memberof typedefs */ + /** Web Search */ /** @@ -1073,7 +1080,7 @@ /** * @exports MCPServers - * @typedef {import('librechat-mcp').MCPServers} MCPServers + * @typedef {import('@librechat/api').MCPServers} MCPServers * @memberof typedefs */ @@ -1085,31 +1092,31 @@ /** * @exports MCPManager - * @typedef {import('librechat-mcp').MCPManager} MCPManager + * @typedef {import('@librechat/api').MCPManager} MCPManager * @memberof typedefs */ /** * @exports FlowStateManager - * @typedef {import('librechat-mcp').FlowStateManager} FlowStateManager + * @typedef {import('@librechat/api').FlowStateManager} FlowStateManager * @memberof typedefs */ /** * @exports LCAvailableTools - * @typedef {import('librechat-mcp').LCAvailableTools} LCAvailableTools + * @typedef {import('@librechat/api').LCAvailableTools} LCAvailableTools * @memberof typedefs */ /** * @exports LCTool - * @typedef {import('librechat-mcp').LCTool} LCTool + * @typedef {import('@librechat/api').LCTool} LCTool * @memberof typedefs */ /** * @exports FormattedContent - * @typedef {import('librechat-mcp').FormattedContent} FormattedContent + * @typedef {import('@librechat/api').FormattedContent} FormattedContent * @memberof typedefs */ @@ -1232,7 +1239,7 @@ * @typedef {Object} AgentClientOptions * @property {Agent} agent - The agent configuration object * @property {string} endpoint - The endpoint identifier for the agent - * @property {Object} req - The request object + * @property {ServerRequest} req - The request object * @property {string} [name] - The username * @property {string} [modelLabel] - The label for the model being used * @property {number} [maxContextTokens] - Maximum number of tokens allowed in context diff --git a/api/utils/axios.js b/api/utils/axios.js deleted file mode 100644 index 91c1fbb223..0000000000 --- a/api/utils/axios.js +++ /dev/null @@ -1,46 +0,0 @@ -const { logger } = require('~/config'); - -/** - * Logs Axios errors based on the error object and a custom message. - * - * @param {Object} options - The options object. - * @param {string} options.message - The custom message to be logged. - * @param {import('axios').AxiosError} options.error - The Axios error object. - * @returns {string} The log message. - */ -const logAxiosError = ({ message, error }) => { - let logMessage = message; - try { - const stack = error.stack || 'No stack trace available'; - - if (error.response?.status) { - const { status, headers, data } = error.response; - logMessage = `${message} The server responded with status ${status}: ${error.message}`; - logger.error(logMessage, { - status, - headers, - data, - stack, - }); - } else if (error.request) { - const { method, url } = error.config || {}; - logMessage = `${message} No response received for ${method ? method.toUpperCase() : ''} ${url || ''}: ${error.message}`; - logger.error(logMessage, { - requestInfo: { method, url }, - stack, - }); - } else if (error?.message?.includes("Cannot read properties of undefined (reading 'status')")) { - logMessage = `${message} It appears the request timed out or was unsuccessful: ${error.message}`; - logger.error(logMessage, { stack }); - } else { - logMessage = `${message} An error occurred while setting up the request: ${error.message}`; - logger.error(logMessage, { stack }); - } - } catch (err) { - logMessage = `Error in logAxiosError: ${err.message}`; - logger.error(logMessage, { stack: err.stack || 'No stack trace available' }); - } - return logMessage; -}; - -module.exports = { logAxiosError }; diff --git a/api/utils/index.js b/api/utils/index.js index 62d61586bf..50b8c46d99 100644 --- a/api/utils/index.js +++ b/api/utils/index.js @@ -1,7 +1,5 @@ const loadYaml = require('./loadYaml'); -const axiosHelpers = require('./axios'); const tokenHelpers = require('./tokens'); -const azureUtils = require('./azureUtils'); const deriveBaseURL = require('./deriveBaseURL'); const extractBaseURL = require('./extractBaseURL'); const findMessageContent = require('./findMessageContent'); @@ -10,8 +8,6 @@ module.exports = { loadYaml, deriveBaseURL, extractBaseURL, - ...azureUtils, - ...axiosHelpers, ...tokenHelpers, findMessageContent, }; diff --git a/client/package.json b/client/package.json index 7cb983d218..67cbec2820 100644 --- a/client/package.json +++ b/client/package.json @@ -65,6 +65,7 @@ "export-from-json": "^1.7.2", "filenamify": "^6.0.0", "framer-motion": "^11.5.4", + "heic-to": "^1.1.14", "html-to-image": "^1.11.11", "i18next": "^24.2.2", "i18next-browser-languagedetector": "^8.0.3", @@ -74,6 +75,7 @@ "lodash": "^4.17.21", "lucide-react": "^0.394.0", "match-sorter": "^6.3.4", + "micromark-extension-llm-math": "^3.1.0", "qrcode.react": "^4.2.0", "rc-input-number": "^7.4.2", "react": "^18.2.0", diff --git a/client/src/Providers/AgentPanelContext.tsx b/client/src/Providers/AgentPanelContext.tsx new file mode 100644 index 0000000000..2cc64ba3ed --- /dev/null +++ b/client/src/Providers/AgentPanelContext.tsx @@ -0,0 +1,97 @@ +import React, { createContext, useContext, useState } from 'react'; +import { Constants, EModelEndpoint } from 'librechat-data-provider'; +import type { TPlugin, AgentToolType, Action, MCP } from 'librechat-data-provider'; +import type { AgentPanelContextType } from '~/common'; +import { useAvailableToolsQuery, useGetActionsQuery } from '~/data-provider'; +import { useLocalize } from '~/hooks'; +import { Panel } from '~/common'; + +const AgentPanelContext = createContext(undefined); + +export function useAgentPanelContext() { + const context = useContext(AgentPanelContext); + if (context === undefined) { + throw new Error('useAgentPanelContext must be used within an AgentPanelProvider'); + } + return context; +} + +/** Houses relevant state for the Agent Form Panels (formerly 'commonProps') */ +export function AgentPanelProvider({ children }: { children: React.ReactNode }) { + const localize = useLocalize(); + const [mcp, setMcp] = useState(undefined); + const [mcps, setMcps] = useState(undefined); + const [action, setAction] = useState(undefined); + const [activePanel, setActivePanel] = useState(Panel.builder); + const [agent_id, setCurrentAgentId] = useState(undefined); + + const { data: actions } = useGetActionsQuery(EModelEndpoint.agents, { + enabled: !!agent_id, + }); + + const { data: pluginTools } = useAvailableToolsQuery(EModelEndpoint.agents, { + enabled: !!agent_id, + }); + + const tools = + pluginTools?.map((tool) => ({ + tool_id: tool.pluginKey, + metadata: tool as TPlugin, + agent_id: agent_id || '', + })) || []; + + const groupedTools = + tools?.reduce( + (acc, tool) => { + if (tool.tool_id.includes(Constants.mcp_delimiter)) { + const [_toolName, serverName] = tool.tool_id.split(Constants.mcp_delimiter); + const groupKey = `${serverName.toLowerCase()}`; + if (!acc[groupKey]) { + acc[groupKey] = { + tool_id: groupKey, + metadata: { + name: `${serverName}`, + pluginKey: groupKey, + description: `${localize('com_ui_tool_collection_prefix')} ${serverName}`, + icon: tool.metadata.icon || '', + } as TPlugin, + agent_id: agent_id || '', + tools: [], + }; + } + acc[groupKey].tools?.push({ + tool_id: tool.tool_id, + metadata: tool.metadata, + agent_id: agent_id || '', + }); + } else { + acc[tool.tool_id] = { + tool_id: tool.tool_id, + metadata: tool.metadata, + agent_id: agent_id || '', + }; + } + return acc; + }, + {} as Record, + ) || {}; + + const value = { + action, + setAction, + mcp, + setMcp, + mcps, + setMcps, + activePanel, + setActivePanel, + setCurrentAgentId, + agent_id, + groupedTools, + /** Query data for actions and tools */ + actions, + tools, + }; + + return {children}; +} diff --git a/client/src/Providers/AgentsContext.tsx b/client/src/Providers/AgentsContext.tsx index e793a3f087..a90a53ecb5 100644 --- a/client/src/Providers/AgentsContext.tsx +++ b/client/src/Providers/AgentsContext.tsx @@ -1,8 +1,8 @@ import { useForm, FormProvider } from 'react-hook-form'; import { createContext, useContext } from 'react'; -import { defaultAgentFormValues } from 'librechat-data-provider'; import type { UseFormReturn } from 'react-hook-form'; import type { AgentForm } from '~/common'; +import { getDefaultAgentFormValues } from '~/utils'; type AgentsContextType = UseFormReturn; @@ -20,7 +20,7 @@ export function useAgentsContext() { export default function AgentsProvider({ children }) { const methods = useForm({ - defaultValues: defaultAgentFormValues, + defaultValues: getDefaultAgentFormValues(), }); return {children}; diff --git a/client/src/Providers/index.ts b/client/src/Providers/index.ts index 00191318e0..41c9cdceb3 100644 --- a/client/src/Providers/index.ts +++ b/client/src/Providers/index.ts @@ -1,6 +1,7 @@ -export { default as ToastProvider } from './ToastContext'; export { default as AssistantsProvider } from './AssistantsContext'; export { default as AgentsProvider } from './AgentsContext'; +export { default as ToastProvider } from './ToastContext'; +export * from './AgentPanelContext'; export * from './ChatContext'; export * from './ShareContext'; export * from './ToastContext'; diff --git a/client/src/common/mcp.ts b/client/src/common/mcp.ts new file mode 100644 index 0000000000..b4f44a1f94 --- /dev/null +++ b/client/src/common/mcp.ts @@ -0,0 +1,26 @@ +import { + AuthorizationTypeEnum, + AuthTypeEnum, + TokenExchangeMethodEnum, +} from 'librechat-data-provider'; +import { MCPForm } from '~/common/types'; + +export const defaultMCPFormValues: MCPForm = { + type: AuthTypeEnum.None, + saved_auth_fields: false, + api_key: '', + authorization_type: AuthorizationTypeEnum.Basic, + custom_auth_header: '', + oauth_client_id: '', + oauth_client_secret: '', + authorization_url: '', + client_url: '', + scope: '', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, + name: '', + description: '', + url: '', + tools: [], + icon: '', + trust: false, +}; diff --git a/client/src/common/types.ts b/client/src/common/types.ts index 0ac6387c33..214dc349b5 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -143,6 +143,7 @@ export enum Panel { actions = 'actions', model = 'model', version = 'version', + mcp = 'mcp', } export type FileSetter = @@ -166,6 +167,15 @@ export type ActionAuthForm = { token_exchange_method: t.TokenExchangeMethodEnum; }; +export type MCPForm = ActionAuthForm & { + name?: string; + description?: string; + url?: string; + tools?: string[]; + icon?: string; + trust?: boolean; +}; + export type ActionWithNullableMetadata = Omit & { metadata: t.ActionMetadata | null; }; @@ -188,16 +198,35 @@ export type AgentPanelProps = { index?: number; agent_id?: string; activePanel?: string; + mcp?: t.MCP; + mcps?: t.MCP[]; action?: t.Action; actions?: t.Action[]; createMutation: UseMutationResult; setActivePanel: React.Dispatch>; + setMcp: React.Dispatch>; setAction: React.Dispatch>; endpointsConfig?: t.TEndpointsConfig; setCurrentAgentId: React.Dispatch>; agentsConfig?: t.TAgentsEndpoint | null; }; +export type AgentPanelContextType = { + action?: t.Action; + actions?: t.Action[]; + setAction: React.Dispatch>; + mcp?: t.MCP; + mcps?: t.MCP[]; + setMcp: React.Dispatch>; + setMcps: React.Dispatch>; + groupedTools: Record; + tools: t.AgentToolType[]; + activePanel?: string; + setActivePanel: React.Dispatch>; + setCurrentAgentId: React.Dispatch>; + agent_id?: string; +}; + export type AgentModelPanelProps = { agent_id?: string; providers: Option[]; diff --git a/client/src/components/Artifacts/Artifact.tsx b/client/src/components/Artifacts/Artifact.tsx index 2b06a2ccc0..902ac9191a 100644 --- a/client/src/components/Artifacts/Artifact.tsx +++ b/client/src/components/Artifacts/Artifact.tsx @@ -40,7 +40,7 @@ const defaultType = 'unknown'; const defaultIdentifier = 'lc-no-identifier'; export function Artifact({ - node, + node: _node, ...props }: Artifact & { children: React.ReactNode | { props: { children: React.ReactNode } }; @@ -95,7 +95,7 @@ export function Artifact({ setArtifacts((prevArtifacts) => { if ( prevArtifacts?.[artifactKey] != null && - prevArtifacts[artifactKey].content === content + prevArtifacts[artifactKey]?.content === content ) { return prevArtifacts; } diff --git a/client/src/components/Artifacts/useDebounceCodeBlock.ts b/client/src/components/Artifacts/useDebounceCodeBlock.ts deleted file mode 100644 index 27aaf5bc83..0000000000 --- a/client/src/components/Artifacts/useDebounceCodeBlock.ts +++ /dev/null @@ -1,37 +0,0 @@ -// client/src/hooks/useDebounceCodeBlock.ts -import { useCallback, useEffect } from 'react'; -import debounce from 'lodash/debounce'; -import { useSetRecoilState } from 'recoil'; -import { codeBlocksState, codeBlockIdsState } from '~/store/artifacts'; -import type { CodeBlock } from '~/common'; - -export function useDebounceCodeBlock() { - const setCodeBlocks = useSetRecoilState(codeBlocksState); - const setCodeBlockIds = useSetRecoilState(codeBlockIdsState); - - const updateCodeBlock = useCallback((codeBlock: CodeBlock) => { - console.log('Updating code block:', codeBlock); - setCodeBlocks((prev) => ({ - ...prev, - [codeBlock.id]: codeBlock, - })); - setCodeBlockIds((prev) => - prev.includes(codeBlock.id) ? prev : [...prev, codeBlock.id], - ); - }, [setCodeBlocks, setCodeBlockIds]); - - const debouncedUpdateCodeBlock = useCallback( - debounce((codeBlock: CodeBlock) => { - updateCodeBlock(codeBlock); - }, 25), - [updateCodeBlock], - ); - - useEffect(() => { - return () => { - debouncedUpdateCodeBlock.cancel(); - }; - }, [debouncedUpdateCodeBlock]); - - return debouncedUpdateCodeBlock; -} diff --git a/client/src/components/Bookmarks/DeleteBookmarkButton.tsx b/client/src/components/Bookmarks/DeleteBookmarkButton.tsx index e9dcf0e4d1..911659de74 100644 --- a/client/src/components/Bookmarks/DeleteBookmarkButton.tsx +++ b/client/src/components/Bookmarks/DeleteBookmarkButton.tsx @@ -1,11 +1,10 @@ import { useCallback, useState } from 'react'; import type { FC } from 'react'; -import { Label, OGDialog, OGDialogTrigger, TooltipAnchor } from '~/components/ui'; +import { Button, TrashIcon, Label, OGDialog, OGDialogTrigger, TooltipAnchor } from '~/components'; import { useDeleteConversationTagMutation } from '~/data-provider'; import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; import { NotificationSeverity } from '~/common'; import { useToastContext } from '~/Providers'; -import { TrashIcon } from '~/components/svg'; import { useLocalize } from '~/hooks'; const DeleteBookmarkButton: FC<{ @@ -36,31 +35,26 @@ const DeleteBookmarkButton: FC<{ await deleteBookmarkMutation.mutateAsync(bookmark); }, [bookmark, deleteBookmarkMutation]); - const handleKeyDown = (event: React.KeyboardEvent) => { - if (event.key === 'Enter' || event.key === ' ') { - event.preventDefault(); - event.stopPropagation(); - setOpen(!open); - } - }; - return ( <> setOpen(!open)} - onKeyDown={handleKeyDown} - > - - + render={ + + } + /> ) => { - if (event.key === 'Enter' || event.key === ' ') { - setOpen(!open); - } - }; - return ( setOpen(!open)} - className="flex size-7 items-center justify-center rounded-lg transition-colors duration-200 hover:bg-surface-hover" - onKeyDown={handleKeyDown} - > - - + render={ + + } + /> ); diff --git a/client/src/components/Chat/Header.tsx b/client/src/components/Chat/Header.tsx index 8a9bd80c23..93a265f4a2 100644 --- a/client/src/components/Chat/Header.tsx +++ b/client/src/components/Chat/Header.tsx @@ -39,7 +39,7 @@ export default function Header() {
diff --git a/client/src/components/Chat/Input/MCPSelect.tsx b/client/src/components/Chat/Input/MCPSelect.tsx index 0cb0206bcd..ebe56c8024 100644 --- a/client/src/components/Chat/Input/MCPSelect.tsx +++ b/client/src/components/Chat/Input/MCPSelect.tsx @@ -1,13 +1,31 @@ -import React, { memo, useRef, useMemo, useEffect, useCallback } from 'react'; +import React, { memo, useRef, useMemo, useEffect, useCallback, useState } from 'react'; import { useRecoilState } from 'recoil'; +import { Settings2 } from 'lucide-react'; +import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query'; import { Constants, EModelEndpoint, LocalStorageKeys } from 'librechat-data-provider'; +import type { TPlugin, TPluginAuthConfig, TUpdateUserPlugins } from 'librechat-data-provider'; +import MCPConfigDialog, { type ConfigFieldDetail } from '~/components/ui/MCPConfigDialog'; import { useAvailableToolsQuery } from '~/data-provider'; import useLocalStorage from '~/hooks/useLocalStorageAlt'; import MultiSelect from '~/components/ui/MultiSelect'; import { ephemeralAgentByConvoId } from '~/store'; +import { useToastContext } from '~/Providers'; import MCPIcon from '~/components/ui/MCPIcon'; import { useLocalize } from '~/hooks'; +interface McpServerInfo { + name: string; + pluginKey: string; + authConfig?: TPluginAuthConfig[]; + authenticated?: boolean; +} + +// Helper function to extract mcp_serverName from a full pluginKey like action_mcp_serverName +const getBaseMCPPluginKey = (fullPluginKey: string): string => { + const parts = fullPluginKey.split(Constants.mcp_delimiter); + return Constants.mcp_prefix + parts[parts.length - 1]; +}; + const storageCondition = (value: unknown, rawCurrentValue?: string | null) => { if (rawCurrentValue) { try { @@ -24,20 +42,45 @@ const storageCondition = (value: unknown, rawCurrentValue?: string | null) => { function MCPSelect({ conversationId }: { conversationId?: string | null }) { const localize = useLocalize(); + const { showToast } = useToastContext(); const key = conversationId ?? Constants.NEW_CONVO; const hasSetFetched = useRef(null); + const [isConfigModalOpen, setIsConfigModalOpen] = useState(false); + const [selectedToolForConfig, setSelectedToolForConfig] = useState(null); - const { data: mcpServerSet, isFetched } = useAvailableToolsQuery(EModelEndpoint.agents, { - select: (data) => { - const serverNames = new Set(); + const { data: mcpToolDetails, isFetched } = useAvailableToolsQuery(EModelEndpoint.agents, { + select: (data: TPlugin[]) => { + const mcpToolsMap = new Map(); data.forEach((tool) => { const isMCP = tool.pluginKey.includes(Constants.mcp_delimiter); if (isMCP && tool.chatMenu !== false) { const parts = tool.pluginKey.split(Constants.mcp_delimiter); - serverNames.add(parts[parts.length - 1]); + const serverName = parts[parts.length - 1]; + if (!mcpToolsMap.has(serverName)) { + mcpToolsMap.set(serverName, { + name: serverName, + pluginKey: tool.pluginKey, + authConfig: tool.authConfig, + authenticated: tool.authenticated, + }); + } } }); - return serverNames; + return Array.from(mcpToolsMap.values()); + }, + }); + + const updateUserPluginsMutation = useUpdateUserPluginsMutation({ + onSuccess: () => { + setIsConfigModalOpen(false); + showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' }); + }, + onError: (error: unknown) => { + console.error('Error updating MCP auth:', error); + showToast({ + message: localize('com_nav_mcp_vars_update_error'), + status: 'error', + }); }, }); @@ -76,12 +119,12 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) { return; } hasSetFetched.current = key; - if ((mcpServerSet?.size ?? 0) > 0) { - setMCPValues(mcpValues.filter((mcp) => mcpServerSet?.has(mcp))); + if ((mcpToolDetails?.length ?? 0) > 0) { + setMCPValues(mcpValues.filter((mcp) => mcpToolDetails?.some((tool) => tool.name === mcp))); return; } setMCPValues([]); - }, [isFetched, setMCPValues, mcpServerSet, key, mcpValues]); + }, [isFetched, setMCPValues, mcpToolDetails, key, mcpValues]); const renderSelectedValues = useCallback( (values: string[], placeholder?: string) => { @@ -96,28 +139,140 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) { [localize], ); - const mcpServers = useMemo(() => { - return Array.from(mcpServerSet ?? []); - }, [mcpServerSet]); + const mcpServerNames = useMemo(() => { + return (mcpToolDetails ?? []).map((tool) => tool.name); + }, [mcpToolDetails]); - if (!mcpServerSet || mcpServerSet.size === 0) { + const handleConfigSave = useCallback( + (targetName: string, authData: Record) => { + if (selectedToolForConfig && selectedToolForConfig.name === targetName) { + const basePluginKey = getBaseMCPPluginKey(selectedToolForConfig.pluginKey); + + const payload: TUpdateUserPlugins = { + pluginKey: basePluginKey, + action: 'install', + auth: authData, + }; + updateUserPluginsMutation.mutate(payload); + } + }, + [selectedToolForConfig, updateUserPluginsMutation], + ); + + const handleConfigRevoke = useCallback( + (targetName: string) => { + if (selectedToolForConfig && selectedToolForConfig.name === targetName) { + const basePluginKey = getBaseMCPPluginKey(selectedToolForConfig.pluginKey); + + const payload: TUpdateUserPlugins = { + pluginKey: basePluginKey, + action: 'uninstall', + auth: {}, + }; + updateUserPluginsMutation.mutate(payload); + } + }, + [selectedToolForConfig, updateUserPluginsMutation], + ); + + const renderItemContent = useCallback( + (serverName: string, defaultContent: React.ReactNode) => { + const tool = mcpToolDetails?.find((t) => t.name === serverName); + const hasAuthConfig = tool?.authConfig && tool.authConfig.length > 0; + + // Common wrapper for the main content (check mark + text) + // Ensures Check & Text are adjacent and the group takes available space. + const mainContentWrapper = ( +
{defaultContent}
+ ); + + if (tool && hasAuthConfig) { + return ( +
+ {mainContentWrapper} + +
+ ); + } + // For items without a settings icon, return the consistently wrapped main content. + return mainContentWrapper; + }, + [mcpToolDetails, setSelectedToolForConfig, setIsConfigModalOpen], + ); + + if (!mcpToolDetails || mcpToolDetails.length === 0) { return null; } return ( - } - selectItemsClassName="border border-blue-600/50 bg-blue-500/10 hover:bg-blue-700/10" - selectClassName="group relative inline-flex items-center justify-center md:justify-start gap-1.5 rounded-full border border-border-medium text-sm font-medium transition-all md:w-full size-9 p-2 md:p-3 bg-transparent shadow-sm hover:bg-surface-hover hover:shadow-md active:shadow-inner" - /> + <> + } + selectItemsClassName="border border-blue-600/50 bg-blue-500/10 hover:bg-blue-700/10" + selectClassName="group relative inline-flex items-center justify-center md:justify-start gap-1.5 rounded-full border border-border-medium text-sm font-medium transition-all md:w-full size-9 p-2 md:p-3 bg-transparent shadow-sm hover:bg-surface-hover hover:shadow-md active:shadow-inner" + /> + {selectedToolForConfig && ( + { + const schema: Record = {}; + if (selectedToolForConfig?.authConfig) { + selectedToolForConfig.authConfig.forEach((field) => { + schema[field.authField] = { + title: field.label, + description: field.description, + }; + }); + } + return schema; + })()} + initialValues={(() => { + const initial: Record = {}; + // Note: Actual initial values might need to be fetched if they are stored user-specifically + if (selectedToolForConfig?.authConfig) { + selectedToolForConfig.authConfig.forEach((field) => { + initial[field.authField] = ''; // Or fetched value + }); + } + return initial; + })()} + onSave={(authData) => { + if (selectedToolForConfig) { + handleConfigSave(selectedToolForConfig.name, authData); + } + }} + onRevoke={() => { + if (selectedToolForConfig) { + handleConfigRevoke(selectedToolForConfig.name); + } + }} + isSubmitting={updateUserPluginsMutation.isLoading} + /> + )} + ); } diff --git a/client/src/components/Chat/Menus/BookmarkMenu.tsx b/client/src/components/Chat/Menus/BookmarkMenu.tsx index 58fcbfdd8f..1f31249bda 100644 --- a/client/src/components/Chat/Menus/BookmarkMenu.tsx +++ b/client/src/components/Chat/Menus/BookmarkMenu.tsx @@ -157,8 +157,9 @@ const BookmarkMenu: FC = () => { return ( + {hasReasoningParts && (
diff --git a/client/src/components/Chat/Messages/Content/DialogImage.tsx b/client/src/components/Chat/Messages/Content/DialogImage.tsx index d41f482eb0..907902f4ed 100644 --- a/client/src/components/Chat/Messages/Content/DialogImage.tsx +++ b/client/src/components/Chat/Messages/Content/DialogImage.tsx @@ -68,7 +68,7 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm />
downloadImage()} variant="ghost" className="h-10 w-10 p-0"> @@ -108,7 +108,7 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm alt="Image" className="max-h-full max-w-full object-contain" style={{ - maxHeight: 'calc(100vh - 6rem)', // Account for header and padding + maxHeight: 'calc(100vh - 6rem)', maxWidth: '100%', }} /> @@ -117,7 +117,7 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm {/* Side Panel */}
@@ -132,7 +132,7 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm
{/* Prompt Section */}
-

+

{localize('com_ui_prompt')}

@@ -144,20 +144,18 @@ export default function DialogImage({ isOpen, onOpenChange, src = '', downloadIm {/* Generation Settings */}
-

+

{localize('com_ui_generation_settings')}

- {localize('com_ui_size')}: + {localize('com_ui_size')}: {args?.size || 'Unknown'}
- - {localize('com_ui_quality')}: - + {localize('com_ui_quality')}:
- + {localize('com_ui_file_size')}: diff --git a/client/src/components/Chat/Messages/Content/Image.tsx b/client/src/components/Chat/Messages/Content/Image.tsx index 85c3fdb3f2..ba4f65671a 100644 --- a/client/src/components/Chat/Messages/Content/Image.tsx +++ b/client/src/components/Chat/Messages/Content/Image.tsx @@ -46,13 +46,33 @@ const Image = ({ [placeholderDimensions, height, width], ); - const downloadImage = () => { - const link = document.createElement('a'); - link.href = imagePath; - link.download = altText; - document.body.appendChild(link); - link.click(); - document.body.removeChild(link); + const downloadImage = async () => { + try { + const response = await fetch(imagePath); + if (!response.ok) { + throw new Error(`Failed to fetch image: ${response.status}`); + } + + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + + const link = document.createElement('a'); + link.href = url; + link.download = altText || 'image.png'; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + + window.URL.revokeObjectURL(url); + } catch (error) { + console.error('Download failed:', error); + const link = document.createElement('a'); + link.href = imagePath; + link.download = altText || 'image.png'; + document.body.appendChild(link); + link.click(); + document.body.removeChild(link); + } }; return ( diff --git a/client/src/components/Chat/Messages/Content/Markdown.tsx b/client/src/components/Chat/Messages/Content/Markdown.tsx index 740bf66670..7bd6511cfa 100644 --- a/client/src/components/Chat/Messages/Content/Markdown.tsx +++ b/client/src/components/Chat/Messages/Content/Markdown.tsx @@ -204,7 +204,7 @@ const Markdown = memo(({ content = '', isLatestMessage }: TContentProps) => { remarkGfm, remarkDirective, artifactPlugin, - [remarkMath, { singleDollarTextMath: true }], + [remarkMath, { singleDollarTextMath: false }], unicodeCitation, ]; diff --git a/client/src/components/Chat/Messages/Content/MarkdownLite.tsx b/client/src/components/Chat/Messages/Content/MarkdownLite.tsx index 972395c425..019783607c 100644 --- a/client/src/components/Chat/Messages/Content/MarkdownLite.tsx +++ b/client/src/components/Chat/Messages/Content/MarkdownLite.tsx @@ -32,7 +32,7 @@ const MarkdownLite = memo( /** @ts-ignore */ supersub, remarkGfm, - [remarkMath, { singleDollarTextMath: true }], + [remarkMath, { singleDollarTextMath: false }], ]} /** @ts-ignore */ rehypePlugins={rehypePlugins} diff --git a/client/src/components/Chat/Messages/Content/MemoryArtifacts.tsx b/client/src/components/Chat/Messages/Content/MemoryArtifacts.tsx new file mode 100644 index 0000000000..7af4e9fcdd --- /dev/null +++ b/client/src/components/Chat/Messages/Content/MemoryArtifacts.tsx @@ -0,0 +1,143 @@ +import { Tools } from 'librechat-data-provider'; +import { useState, useRef, useMemo, useLayoutEffect, useEffect } from 'react'; +import type { MemoryArtifact, TAttachment } from 'librechat-data-provider'; +import MemoryInfo from './MemoryInfo'; +import { useLocalize } from '~/hooks'; +import { cn } from '~/utils'; + +export default function MemoryArtifacts({ attachments }: { attachments?: TAttachment[] }) { + const localize = useLocalize(); + const [showInfo, setShowInfo] = useState(false); + const contentRef = useRef(null); + const [contentHeight, setContentHeight] = useState(0); + const [isAnimating, setIsAnimating] = useState(false); + const prevShowInfoRef = useRef(showInfo); + + const memoryArtifacts = useMemo(() => { + const result: MemoryArtifact[] = []; + for (const attachment of attachments ?? []) { + if (attachment?.[Tools.memory] != null) { + result.push(attachment[Tools.memory]); + } + } + return result; + }, [attachments]); + + useLayoutEffect(() => { + if (showInfo !== prevShowInfoRef.current) { + prevShowInfoRef.current = showInfo; + setIsAnimating(true); + + if (showInfo && contentRef.current) { + requestAnimationFrame(() => { + if (contentRef.current) { + const height = contentRef.current.scrollHeight; + setContentHeight(height + 4); + } + }); + } else { + setContentHeight(0); + } + + const timer = setTimeout(() => { + setIsAnimating(false); + }, 400); + + return () => clearTimeout(timer); + } + }, [showInfo]); + + useEffect(() => { + if (!contentRef.current) { + return; + } + const resizeObserver = new ResizeObserver((entries) => { + if (showInfo && !isAnimating) { + for (const entry of entries) { + if (entry.target === contentRef.current) { + setContentHeight(entry.contentRect.height + 4); + } + } + } + }); + resizeObserver.observe(contentRef.current); + return () => { + resizeObserver.disconnect(); + }; + }, [showInfo, isAnimating]); + + if (!memoryArtifacts || memoryArtifacts.length === 0) { + return null; + } + + return ( + <> +
+
+ +
+
+
+
+
+ {showInfo && } +
+
+
+ + ); +} diff --git a/client/src/components/Chat/Messages/Content/MemoryInfo.tsx b/client/src/components/Chat/Messages/Content/MemoryInfo.tsx new file mode 100644 index 0000000000..574c2e8f5f --- /dev/null +++ b/client/src/components/Chat/Messages/Content/MemoryInfo.tsx @@ -0,0 +1,61 @@ +import type { MemoryArtifact } from 'librechat-data-provider'; +import { useLocalize } from '~/hooks'; + +export default function MemoryInfo({ memoryArtifacts }: { memoryArtifacts: MemoryArtifact[] }) { + const localize = useLocalize(); + if (memoryArtifacts.length === 0) { + return null; + } + + // Group artifacts by type + const updatedMemories = memoryArtifacts.filter((artifact) => artifact.type === 'update'); + const deletedMemories = memoryArtifacts.filter((artifact) => artifact.type === 'delete'); + + if (updatedMemories.length === 0 && deletedMemories.length === 0) { + return null; + } + + return ( +
+ {updatedMemories.length > 0 && ( +
+

+ {localize('com_ui_memory_updated_items')} +

+
+ {updatedMemories.map((artifact, index) => ( +
+
+ {artifact.key} +
+
+ {artifact.value} +
+
+ ))} +
+
+ )} + + {deletedMemories.length > 0 && ( +
+

+ {localize('com_ui_memory_deleted_items')} +

+
+ {deletedMemories.map((artifact, index) => ( +
+
+ {artifact.key} +
+
+ {localize('com_ui_memory_deleted')} +
+
+ ))} +
+
+ )} +
+ ); +} diff --git a/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx b/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx index e6736b192e..1ce207fe1c 100644 --- a/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx @@ -117,9 +117,9 @@ const EditTextPart = ({ messages.map((msg) => msg.messageId === messageId ? { - ...msg, - content: updatedContent, - } + ...msg, + content: updatedContent, + } : msg, ), ); diff --git a/client/src/components/Chat/Messages/Content/Parts/OpenAIImageGen/OpenAIImageGen.tsx b/client/src/components/Chat/Messages/Content/Parts/OpenAIImageGen/OpenAIImageGen.tsx index e13e68312d..ef24c3553e 100644 --- a/client/src/components/Chat/Messages/Content/Parts/OpenAIImageGen/OpenAIImageGen.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/OpenAIImageGen/OpenAIImageGen.tsx @@ -178,17 +178,6 @@ export default function OpenAIImageGen({
- - {/* {showInfo && hasInfo && ( - 0 && !cancelled && initialProgress < 1} - /> - )} */} -
{dimensions.width !== 'auto' && progress < 1 && ( diff --git a/client/src/components/Chat/Messages/Feedback.tsx b/client/src/components/Chat/Messages/Feedback.tsx index cf7ccadbab..4879808d90 100644 --- a/client/src/components/Chat/Messages/Feedback.tsx +++ b/client/src/components/Chat/Messages/Feedback.tsx @@ -216,18 +216,12 @@ function FeedbackButtons({ function buttonClasses(isActive: boolean, isLast: boolean) { return cn( - 'hover-button rounded-lg p-1.5', - 'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-black dark:focus-visible:ring-white', - 'hover:bg-gray-100 hover:text-gray-500', - 'data-[state=open]:active data-[state=open]:bg-gray-100 data-[state=open]:text-gray-500', - isActive ? 'text-gray-500 dark:text-gray-200 font-bold' : 'dark:text-gray-400/70', - 'dark:hover:bg-gray-700 dark:hover:text-gray-200', - 'data-[state=open]:dark:bg-gray-700 data-[state=open]:dark:text-gray-200', - 'disabled:dark:hover:text-gray-400', - isLast - ? '' - : 'data-[state=open]:opacity-100 md:opacity-0 md:group-focus-within:opacity-100 md:group-hover:opacity-100', - 'md:group-focus-within:visible md:group-hover:visible md:group-[.final-completion]:visible', + 'hover-button rounded-lg p-1.5 text-text-secondary-alt transition-colors duration-200', + 'hover:text-text-primary hover:bg-surface-hover', + 'md:group-hover:visible md:group-focus-within:visible md:group-[.final-completion]:visible', + !isLast && 'md:opacity-0 md:group-hover:opacity-100 md:group-focus-within:opacity-100', + 'focus-visible:ring-2 focus-visible:ring-black dark:focus-visible:ring-white focus-visible:outline-none', + isActive && 'active text-text-primary bg-surface-hover', ); } diff --git a/client/src/components/Chat/Messages/Fork.tsx b/client/src/components/Chat/Messages/Fork.tsx index 5bc0bd8839..1cc319c3dd 100644 --- a/client/src/components/Chat/Messages/Fork.tsx +++ b/client/src/components/Chat/Messages/Fork.tsx @@ -211,14 +211,12 @@ export default function Fork({ }); const buttonStyle = cn( - 'hover-button rounded-lg p-1.5', - 'hover:bg-gray-100 hover:text-gray-500', - 'dark:text-gray-400/70 dark:hover:bg-gray-700 dark:hover:text-gray-200', - 'disabled:dark:hover:text-gray-400', + 'hover-button rounded-lg p-1.5 text-text-secondary-alt transition-colors duration-200', + 'hover:text-text-primary hover:bg-surface-hover', 'md:group-hover:visible md:group-focus-within:visible md:group-[.final-completion]:visible', !isLast && 'md:opacity-0 md:group-hover:opacity-100 md:group-focus-within:opacity-100', 'focus-visible:ring-2 focus-visible:ring-black dark:focus-visible:ring-white focus-visible:outline-none', - isActive && 'active text-gray-700 dark:text-gray-200 bg-gray-100 bg-gray-700', + isActive && 'active text-text-primary bg-surface-hover', ); const forkConvo = useForkConvoMutation({ diff --git a/client/src/components/Chat/Messages/HoverButtons.tsx b/client/src/components/Chat/Messages/HoverButtons.tsx index 5783540bb8..a13266f04c 100644 --- a/client/src/components/Chat/Messages/HoverButtons.tsx +++ b/client/src/components/Chat/Messages/HoverButtons.tsx @@ -25,6 +25,7 @@ type THoverButtons = { }; type HoverButtonProps = { + id?: string; onClick: (e?: React.MouseEvent) => void; title: string; icon: React.ReactNode; @@ -67,6 +68,7 @@ const extractMessageContent = (message: TMessage): string => { const HoverButton = memo( ({ + id, onClick, title, icon, @@ -77,26 +79,19 @@ const HoverButton = memo( className = '', }: HoverButtonProps) => { const buttonStyle = cn( - 'hover-button rounded-lg p-1.5', - - 'hover:bg-gray-100 hover:text-gray-500', - - 'dark:text-gray-400/70 dark:hover:bg-gray-700 dark:hover:text-gray-200', - 'disabled:dark:hover:text-gray-400', - + 'hover-button rounded-lg p-1.5 text-text-secondary-alt transition-colors duration-200', + 'hover:text-text-primary hover:bg-surface-hover', 'md:group-hover:visible md:group-focus-within:visible md:group-[.final-completion]:visible', !isLast && 'md:opacity-0 md:group-hover:opacity-100 md:group-focus-within:opacity-100', !isVisible && 'opacity-0', - 'focus-visible:ring-2 focus-visible:ring-black dark:focus-visible:ring-white focus-visible:outline-none', - - isActive && isVisible && 'active text-gray-700 dark:text-gray-200 bg-gray-100 bg-gray-700', - + isActive && isVisible && 'active text-text-primary bg-surface-hover', className, ); return ( - + {siblingIdx + 1} / {siblingCount} -
+ ) : null; } diff --git a/client/src/components/Nav/Nav.tsx b/client/src/components/Nav/Nav.tsx index d425468ca6..a692274719 100644 --- a/client/src/components/Nav/Nav.tsx +++ b/client/src/components/Nav/Nav.tsx @@ -30,7 +30,7 @@ const NavMask = memo( id="mobile-nav-mask-toggle" role="button" tabIndex={0} - className={`nav-mask transition-opacity duration-500 ease-in-out ${navVisible ? 'active opacity-100' : 'opacity-0'}`} + className={`nav-mask transition-opacity duration-200 ease-in-out ${navVisible ? 'active opacity-100' : 'opacity-0'}`} onClick={toggleNavVisible} onKeyDown={(e) => { if (e.key === 'Enter' || e.key === ' ') { @@ -186,7 +186,7 @@ const Nav = memo(
); } diff --git a/client/src/components/Prompts/Groups/PanelNavigation.tsx b/client/src/components/Prompts/Groups/PanelNavigation.tsx index 797f0221b1..867df7a784 100644 --- a/client/src/components/Prompts/Groups/PanelNavigation.tsx +++ b/client/src/components/Prompts/Groups/PanelNavigation.tsx @@ -43,7 +43,7 @@ function PanelNavigation({ {localize('com_ui_next')}
-
+ ); } diff --git a/client/src/components/Prompts/Groups/VariableForm.tsx b/client/src/components/Prompts/Groups/VariableForm.tsx index 57458fb444..09bdcd40da 100644 --- a/client/src/components/Prompts/Groups/VariableForm.tsx +++ b/client/src/components/Prompts/Groups/VariableForm.tsx @@ -143,7 +143,7 @@ export default function VariableForm({
{ /** @ts-ignore */ supersub, remarkGfm, - [remarkMath, { singleDollarTextMath: true }], + [remarkMath, { singleDollarTextMath: false }], ]} rehypePlugins={[ /** @ts-ignore */ diff --git a/client/src/components/Prompts/PromptEditor.tsx b/client/src/components/Prompts/PromptEditor.tsx index 7db89c7078..f10a94c11e 100644 --- a/client/src/components/Prompts/PromptEditor.tsx +++ b/client/src/components/Prompts/PromptEditor.tsx @@ -130,7 +130,7 @@ const PromptEditor: React.FC = ({ name, isEditing, setIsEditing }) => { /** @ts-ignore */ supersub, remarkGfm, - [remarkMath, { singleDollarTextMath: true }], + [remarkMath, { singleDollarTextMath: false }], ]} /** @ts-ignore */ rehypePlugins={rehypePlugins} diff --git a/client/src/components/SidePanel/Agents/ActionsPanel.tsx b/client/src/components/SidePanel/Agents/ActionsPanel.tsx index 514e1b61eb..89b9a87f92 100644 --- a/client/src/components/SidePanel/Agents/ActionsPanel.tsx +++ b/client/src/components/SidePanel/Agents/ActionsPanel.tsx @@ -1,31 +1,27 @@ import { useEffect } from 'react'; +import { ChevronLeft } from 'lucide-react'; import { useForm, FormProvider } from 'react-hook-form'; import { AuthTypeEnum, AuthorizationTypeEnum, TokenExchangeMethodEnum, } from 'librechat-data-provider'; -import { ChevronLeft } from 'lucide-react'; -import type { AgentPanelProps, ActionAuthForm } from '~/common'; import ActionsAuth from '~/components/SidePanel/Builder/ActionsAuth'; +import { useAgentPanelContext } from '~/Providers/AgentPanelContext'; import { OGDialog, OGDialogTrigger, Label } from '~/components/ui'; import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; import { useDeleteAgentAction } from '~/data-provider'; +import type { ActionAuthForm } from '~/common'; import useLocalize from '~/hooks/useLocalize'; import { useToastContext } from '~/Providers'; import { TrashIcon } from '~/components/svg'; import ActionsInput from './ActionsInput'; import { Panel } from '~/common'; -export default function ActionsPanel({ - // activePanel, - action, - setAction, - agent_id, - setActivePanel, -}: AgentPanelProps) { +export default function ActionsPanel() { const localize = useLocalize(); const { showToast } = useToastContext(); + const { setActivePanel, action, setAction, agent_id } = useAgentPanelContext(); const deleteAgentAction = useDeleteAgentAction({ onSuccess: () => { showToast({ @@ -62,7 +58,7 @@ export default function ActionsPanel({ }, }); - const { reset, watch } = methods; + const { reset } = methods; useEffect(() => { if (action?.metadata.auth) { @@ -128,7 +124,7 @@ export default function ActionsPanel({ selectHandler: () => { if (!agent_id) { return showToast({ - message: 'No agent_id found, is the agent created?', + message: localize('com_agents_no_agent_id_error'), status: 'error', }); } diff --git a/client/src/components/SidePanel/Agents/AgentConfig.tsx b/client/src/components/SidePanel/Agents/AgentConfig.tsx index c04018e69a..2afa56601c 100644 --- a/client/src/components/SidePanel/Agents/AgentConfig.tsx +++ b/client/src/components/SidePanel/Agents/AgentConfig.tsx @@ -1,11 +1,9 @@ import React, { useState, useMemo, useCallback } from 'react'; -import { useQueryClient } from '@tanstack/react-query'; import { Controller, useWatch, useFormContext } from 'react-hook-form'; -import { QueryKeys, EModelEndpoint, AgentCapabilities } from 'librechat-data-provider'; -import type { TPlugin } from 'librechat-data-provider'; +import { EModelEndpoint, AgentCapabilities } from 'librechat-data-provider'; import type { AgentForm, AgentPanelProps, IconComponentTypes } from '~/common'; import { cn, defaultTextProps, removeFocusOutlines, getEndpointField, getIconKey } from '~/utils'; -import { useToastContext, useFileMapContext } from '~/Providers'; +import { useToastContext, useFileMapContext, useAgentPanelContext } from '~/Providers'; import Action from '~/components/SidePanel/Builder/Action'; import { ToolSelectDialog } from '~/components/Tools'; import { icons } from '~/hooks/Endpoint/Icons'; @@ -29,23 +27,16 @@ const inputClass = cn( ); export default function AgentConfig({ - setAction, - actions = [], agentsConfig, createMutation, - setActivePanel, endpointsConfig, -}: AgentPanelProps) { - const fileMap = useFileMapContext(); - const queryClient = useQueryClient(); - - const allTools = queryClient.getQueryData([QueryKeys.tools]) ?? []; - const { showToast } = useToastContext(); +}: Pick) { const localize = useLocalize(); - - const [showToolDialog, setShowToolDialog] = useState(false); - + const fileMap = useFileMapContext(); + const { showToast } = useToastContext(); const methods = useFormContext(); + const [showToolDialog, setShowToolDialog] = useState(false); + const { actions, setAction, groupedTools: allTools, setActivePanel } = useAgentPanelContext(); const { control } = methods; const provider = useWatch({ control, name: 'provider' }); @@ -172,6 +163,20 @@ export default function AgentConfig({ Icon = icons[iconKey]; } + // Determine what to show + const selectedToolIds = tools ?? []; + const visibleToolIds = new Set(selectedToolIds); + + // Check what group parent tools should be shown if any subtool is present + Object.entries(allTools).forEach(([toolId, toolObj]) => { + if (toolObj.tools?.length) { + // if any subtool of this group is selected, ensure group parent tool rendered + if (toolObj.tools.some((st) => selectedToolIds.includes(st.tool_id))) { + visibleToolIds.add(toolId); + } + } + }); + return ( <>
@@ -290,28 +295,37 @@ export default function AgentConfig({ ${toolsEnabled === true && actionsEnabled === true ? ' + ' : ''} ${actionsEnabled === true ? localize('com_assistants_actions') : ''}`} -
- {tools?.map((func, i) => ( - - ))} - {actions - .filter((action) => action.agent_id === agent_id) - .map((action, i) => ( - { - setAction(action); - setActivePanel(Panel.actions); - }} - /> - ))} -
+
+
+ {/* // Render all visible IDs (including groups with subtools selected) */} + {[...visibleToolIds].map((toolId, i) => { + const tool = allTools[toolId]; + if (!tool) return null; + return ( + + ); + })} +
+
+ {(actions ?? []) + .filter((action) => action.agent_id === agent_id) + .map((action, i) => ( + { + setAction(action); + setActivePanel(Panel.actions); + }} + /> + ))} +
+
{(toolsEnabled ?? false) && (
+ {/* MCP Section */} + {/* */}
diff --git a/client/src/components/SidePanel/Agents/AgentPanel.tsx b/client/src/components/SidePanel/Agents/AgentPanel.tsx index 34f83b9257..78874e41c5 100644 --- a/client/src/components/SidePanel/Agents/AgentPanel.tsx +++ b/client/src/components/SidePanel/Agents/AgentPanel.tsx @@ -7,20 +7,22 @@ import { Constants, SystemRoles, EModelEndpoint, + TAgentsEndpoint, + TEndpointsConfig, isAssistantsEndpoint, - defaultAgentFormValues, } from 'librechat-data-provider'; -import type { AgentForm, AgentPanelProps, StringOption } from '~/common'; +import type { AgentForm, StringOption } from '~/common'; import { useCreateAgentMutation, useUpdateAgentMutation, useGetAgentByIdQuery, } from '~/data-provider'; +import { createProviderOption, getDefaultAgentFormValues } from '~/utils'; import { useSelectAgent, useLocalize, useAuthContext } from '~/hooks'; +import { useAgentPanelContext } from '~/Providers/AgentPanelContext'; import AgentPanelSkeleton from './AgentPanelSkeleton'; -import { createProviderOption } from '~/utils'; -import { useToastContext } from '~/Providers'; import AdvancedPanel from './Advanced/AdvancedPanel'; +import { useToastContext } from '~/Providers'; import AgentConfig from './AgentConfig'; import AgentSelect from './AgentSelect'; import AgentFooter from './AgentFooter'; @@ -29,18 +31,21 @@ import ModelPanel from './ModelPanel'; import { Panel } from '~/common'; export default function AgentPanel({ - setAction, - activePanel, - actions = [], - setActivePanel, - agent_id: current_agent_id, - setCurrentAgentId, agentsConfig, endpointsConfig, -}: AgentPanelProps) { +}: { + agentsConfig: TAgentsEndpoint | null; + endpointsConfig: TEndpointsConfig; +}) { const localize = useLocalize(); const { user } = useAuthContext(); const { showToast } = useToastContext(); + const { + activePanel, + setActivePanel, + setCurrentAgentId, + agent_id: current_agent_id, + } = useAgentPanelContext(); const { onSelect: onSelectAgent } = useSelectAgent(); @@ -51,7 +56,7 @@ export default function AgentPanel({ const models = useMemo(() => modelsQuery.data ?? {}, [modelsQuery.data]); const methods = useForm({ - defaultValues: defaultAgentFormValues, + defaultValues: getDefaultAgentFormValues(), }); const { control, handleSubmit, reset } = methods; @@ -277,7 +282,7 @@ export default function AgentPanel({ variant="outline" className="w-full justify-center" onClick={() => { - reset(defaultAgentFormValues); + reset(getDefaultAgentFormValues()); setCurrentAgentId(undefined); }} disabled={agentQuery.isInitialLoading} @@ -315,22 +320,13 @@ export default function AgentPanel({
)} {canEditAgent && !agentQuery.isInitialLoading && activePanel === Panel.model && ( - + )} {canEditAgent && !agentQuery.isInitialLoading && activePanel === Panel.builder && ( )} {canEditAgent && !agentQuery.isInitialLoading && activePanel === Panel.advanced && ( diff --git a/client/src/components/SidePanel/Agents/AgentPanelSwitch.tsx b/client/src/components/SidePanel/Agents/AgentPanelSwitch.tsx index 4dc54c9b60..495b047b51 100644 --- a/client/src/components/SidePanel/Agents/AgentPanelSwitch.tsx +++ b/client/src/components/SidePanel/Agents/AgentPanelSwitch.tsx @@ -1,22 +1,29 @@ -import { useState, useEffect, useMemo } from 'react'; +import { useEffect, useMemo } from 'react'; import { EModelEndpoint, AgentCapabilities } from 'librechat-data-provider'; -import type { ActionsEndpoint } from '~/common'; -import type { Action, TConfig, TEndpointsConfig, TAgentsEndpoint } from 'librechat-data-provider'; -import { useGetActionsQuery, useGetEndpointsQuery, useCreateAgentMutation } from '~/data-provider'; +import type { TConfig, TEndpointsConfig, TAgentsEndpoint } from 'librechat-data-provider'; +import { AgentPanelProvider, useAgentPanelContext } from '~/Providers/AgentPanelContext'; +import { useGetEndpointsQuery } from '~/data-provider'; +import VersionPanel from './Version/VersionPanel'; import { useChatContext } from '~/Providers'; import ActionsPanel from './ActionsPanel'; import AgentPanel from './AgentPanel'; -import VersionPanel from './Version/VersionPanel'; +import MCPPanel from './MCPPanel'; import { Panel } from '~/common'; export default function AgentPanelSwitch() { - const { conversation, index } = useChatContext(); - const [activePanel, setActivePanel] = useState(Panel.builder); - const [action, setAction] = useState(undefined); - const [currentAgentId, setCurrentAgentId] = useState(conversation?.agent_id); - const { data: actions = [] } = useGetActionsQuery(conversation?.endpoint as ActionsEndpoint); + return ( + + + + ); +} + +function AgentPanelSwitchWithContext() { + const { conversation } = useChatContext(); + const { activePanel, setCurrentAgentId } = useAgentPanelContext(); + + // TODO: Implement MCP endpoint const { data: endpointsConfig = {} as TEndpointsConfig } = useGetEndpointsQuery(); - const createMutation = useCreateAgentMutation(); const agentsConfig = useMemo(() => { const config = endpointsConfig?.[EModelEndpoint.agents] ?? null; @@ -35,39 +42,20 @@ export default function AgentPanelSwitch() { if (agent_id) { setCurrentAgentId(agent_id); } - }, [conversation?.agent_id]); + }, [setCurrentAgentId, conversation?.agent_id]); if (!conversation?.endpoint) { return null; } - const commonProps = { - index, - action, - actions, - setAction, - activePanel, - setActivePanel, - setCurrentAgentId, - agent_id: currentAgentId, - createMutation, - }; - if (activePanel === Panel.actions) { - return ; + return ; } - if (activePanel === Panel.version) { - return ( - - ); + return ; } - - return ( - - ); + if (activePanel === Panel.mcp) { + return ; + } + return ; } diff --git a/client/src/components/SidePanel/Agents/AgentSelect.tsx b/client/src/components/SidePanel/Agents/AgentSelect.tsx index d265ba201f..496dd4ee6d 100644 --- a/client/src/components/SidePanel/Agents/AgentSelect.tsx +++ b/client/src/components/SidePanel/Agents/AgentSelect.tsx @@ -5,8 +5,8 @@ import { AgentCapabilities, defaultAgentFormValues } from 'librechat-data-provid import type { UseMutationResult, QueryObserverResult } from '@tanstack/react-query'; import type { Agent, AgentCreateParams } from 'librechat-data-provider'; import type { TAgentCapabilities, AgentForm } from '~/common'; +import { cn, createProviderOption, processAgentOption, getDefaultAgentFormValues } from '~/utils'; import { useListAgentsQuery, useGetStartupConfig } from '~/data-provider'; -import { cn, createProviderOption, processAgentOption } from '~/utils'; import ControlCombobox from '~/components/ui/ControlCombobox'; import { useLocalize } from '~/hooks'; @@ -32,7 +32,10 @@ export default function AgentSelect({ select: (res) => res.data.map((agent) => processAgentOption({ - agent, + agent: { + ...agent, + name: agent.name || agent.id, + }, instanceProjectId: startupConfig?.instanceProjectId, }), ), @@ -124,9 +127,7 @@ export default function AgentSelect({ createMutation.reset(); if (!agentExists) { setCurrentAgentId(undefined); - return reset({ - ...defaultAgentFormValues, - }); + return reset(getDefaultAgentFormValues()); } setCurrentAgentId(selectedId); @@ -179,7 +180,7 @@ export default function AgentSelect({ containerClassName="px-0" selectedValue={(field?.value?.value ?? '') + ''} displayValue={field?.value?.label ?? ''} - selectPlaceholder={createAgent} + selectPlaceholder={field?.value?.value ?? createAgent} iconSide="right" searchPlaceholder={localize('com_agents_search_name')} SelectIcon={field?.value?.icon} diff --git a/client/src/components/SidePanel/Agents/AgentTool.tsx b/client/src/components/SidePanel/Agents/AgentTool.tsx index 59c2e267f1..4876f447fb 100644 --- a/client/src/components/SidePanel/Agents/AgentTool.tsx +++ b/client/src/components/SidePanel/Agents/AgentTool.tsx @@ -1,41 +1,69 @@ import React, { useState } from 'react'; +import * as Ariakit from '@ariakit/react'; +import { ChevronDown } from 'lucide-react'; import { useFormContext } from 'react-hook-form'; -import type { TPlugin } from 'librechat-data-provider'; +import * as AccordionPrimitive from '@radix-ui/react-accordion'; import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query'; -import { OGDialog, OGDialogTrigger, Label } from '~/components/ui'; +import type { AgentToolType } from 'librechat-data-provider'; +import type { AgentForm } from '~/common'; +import { Accordion, AccordionItem, AccordionContent } from '~/components/ui/Accordion'; +import { OGDialog, OGDialogTrigger, Label, Checkbox } from '~/components/ui'; +import { TrashIcon, CircleHelpIcon } from '~/components/svg'; import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; import { useToastContext } from '~/Providers'; -import { TrashIcon } from '~/components/svg'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; export default function AgentTool({ tool, allTools, - agent_id = '', }: { tool: string; - allTools: TPlugin[]; + allTools: Record; agent_id?: string; }) { const [isHovering, setIsHovering] = useState(false); + const [isFocused, setIsFocused] = useState(false); + const [hoveredToolId, setHoveredToolId] = useState(null); + const [accordionValue, setAccordionValue] = useState(''); const localize = useLocalize(); const { showToast } = useToastContext(); const updateUserPlugins = useUpdateUserPluginsMutation(); - const { getValues, setValue } = useFormContext(); - const currentTool = allTools.find((t) => t.pluginKey === tool); + const { getValues, setValue } = useFormContext(); + const currentTool = allTools[tool]; + + const getSelectedTools = () => { + if (!currentTool?.tools) return []; + const formTools = getValues('tools') || []; + return currentTool.tools.filter((t) => formTools.includes(t.tool_id)).map((t) => t.tool_id); + }; + + const updateFormTools = (newSelectedTools: string[]) => { + const currentTools = getValues('tools') || []; + const otherTools = currentTools.filter( + (t: string) => !currentTool?.tools?.some((st) => st.tool_id === t), + ); + setValue('tools', [...otherTools, ...newSelectedTools]); + }; + + const removeTool = (toolId: string) => { + if (toolId) { + const toolIdsToRemove = + isGroup && currentTool.tools + ? [toolId, ...currentTool.tools.map((t) => t.tool_id)] + : [toolId]; - const removeTool = (tool: string) => { - if (tool) { updateUserPlugins.mutate( - { pluginKey: tool, action: 'uninstall', auth: null, isEntityTool: true }, + { pluginKey: toolId, action: 'uninstall', auth: {}, isEntityTool: true }, { onError: (error: unknown) => { showToast({ message: `Error while deleting the tool: ${error}`, status: 'error' }); }, onSuccess: () => { - const tools = getValues('tools').filter((fn: string) => fn !== tool); - setValue('tools', tools); + const remainingToolIds = getValues('tools')?.filter( + (toolId: string) => !toolIdsToRemove.includes(toolId), + ); + setValue('tools', remainingToolIds); showToast({ message: 'Tool deleted successfully', status: 'success' }); }, }, @@ -47,41 +75,309 @@ export default function AgentTool({ return null; } - return ( - -
setIsHovering(true)} - onMouseLeave={() => setIsHovering(false)} - > -
- {currentTool.icon && ( -
-
-
- )} -
- {currentTool.name} -
-
+ const isGroup = currentTool.tools && currentTool.tools.length > 0; + const selectedTools = getSelectedTools(); + const isExpanded = accordionValue === currentTool.tool_id; + + if (!isGroup) { + return ( + +
setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + onFocus={() => setIsFocused(true)} + onBlur={(e) => { + // Check if focus is moving to a child element + if (!e.currentTarget.contains(e.relatedTarget)) { + setIsFocused(false); + } + }} + > +
+ {currentTool.metadata.icon && ( +
+
+
+ )} +
+ {currentTool.metadata.name} +
+
- {isHovering && ( - )} -
+
+ + {localize('com_ui_delete_tool_confirm')} + + } + selection={{ + selectHandler: () => removeTool(currentTool.tool_id), + selectClasses: + 'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white', + selectText: localize('com_ui_delete'), + }} + /> +
+ ); + } + + // Group tool with accordion + return ( + + + +
setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + onFocus={() => setIsFocused(true)} + onBlur={(e) => { + // Check if focus is moving to a child element + if (!e.currentTarget.contains(e.relatedTarget)) { + setIsFocused(false); + } + }} + > + + + + +
+
+
+ + + +
+ + +
+ {currentTool.tools?.map((subTool) => ( + + ))} +
+
+ + } selection={{ - selectHandler: () => removeTool(currentTool.pluginKey), + selectHandler: () => removeTool(currentTool.tool_id), selectClasses: 'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white', selectText: localize('com_ui_delete'), diff --git a/client/src/components/SidePanel/Agents/DeleteButton.tsx b/client/src/components/SidePanel/Agents/DeleteButton.tsx index 388d8f25e0..bfe08c666f 100644 --- a/client/src/components/SidePanel/Agents/DeleteButton.tsx +++ b/client/src/components/SidePanel/Agents/DeleteButton.tsx @@ -1,12 +1,11 @@ import { useFormContext } from 'react-hook-form'; -import { defaultAgentFormValues } from 'librechat-data-provider'; import type { Agent, AgentCreateParams } from 'librechat-data-provider'; import type { UseMutationResult } from '@tanstack/react-query'; +import { cn, logger, removeFocusOutlines, getDefaultAgentFormValues } from '~/utils'; import { OGDialog, OGDialogTrigger, Label } from '~/components/ui'; -import { useChatContext, useToastContext } from '~/Providers'; import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; +import { useChatContext, useToastContext } from '~/Providers'; import { useLocalize, useSetIndexOptions } from '~/hooks'; -import { cn, removeFocusOutlines, logger } from '~/utils'; import { useDeleteAgentMutation } from '~/data-provider'; import { TrashIcon } from '~/components/svg'; @@ -45,9 +44,7 @@ export default function DeleteButton({ const firstAgent = updatedList[0] as Agent | undefined; if (!firstAgent) { setCurrentAgentId(undefined); - reset({ - ...defaultAgentFormValues, - }); + reset(getDefaultAgentFormValues()); return setOption('agent_id')(''); } diff --git a/client/src/components/SidePanel/Agents/MCPIcon.tsx b/client/src/components/SidePanel/Agents/MCPIcon.tsx new file mode 100644 index 0000000000..101abb4c9e --- /dev/null +++ b/client/src/components/SidePanel/Agents/MCPIcon.tsx @@ -0,0 +1,64 @@ +import { useState, useEffect, useRef } from 'react'; +import SquirclePlusIcon from '~/components/svg/SquirclePlusIcon'; +import { useLocalize } from '~/hooks'; + +interface MCPIconProps { + icon?: string; + onIconChange: (e: React.ChangeEvent) => void; +} + +export default function MCPIcon({ icon, onIconChange }: MCPIconProps) { + const [previewUrl, setPreviewUrl] = useState(''); + const fileInputRef = useRef(null); + const localize = useLocalize(); + + useEffect(() => { + if (icon) { + setPreviewUrl(icon); + } else { + setPreviewUrl(''); + } + }, [icon]); + + const handleClick = () => { + if (fileInputRef.current) { + fileInputRef.current.value = ''; + fileInputRef.current.click(); + } + }; + + return ( +
+
+ {previewUrl ? ( + MCP Icon + ) : ( + + )} +
+
+ + {localize('com_ui_icon')} {localize('com_ui_optional')} + + {localize('com_agents_mcp_icon_size')} +
+ +
+ ); +} diff --git a/client/src/components/SidePanel/Agents/MCPInput.tsx b/client/src/components/SidePanel/Agents/MCPInput.tsx new file mode 100644 index 0000000000..078f4109dc --- /dev/null +++ b/client/src/components/SidePanel/Agents/MCPInput.tsx @@ -0,0 +1,288 @@ +import { useState, useEffect } from 'react'; +import { useFormContext, Controller } from 'react-hook-form'; +import { MCP } from 'librechat-data-provider/dist/types/types/assistants'; +import MCPAuth from '~/components/SidePanel/Builder/MCPAuth'; +import MCPIcon from '~/components/SidePanel/Agents/MCPIcon'; +import { Label, Checkbox } from '~/components/ui'; +import useLocalize from '~/hooks/useLocalize'; +import { useToastContext } from '~/Providers'; +import { Spinner } from '~/components/svg'; +import { MCPForm } from '~/common/types'; + +function useUpdateAgentMCP({ + onSuccess, + onError, +}: { + onSuccess: (data: [string, MCP]) => void; + onError: (error: Error) => void; +}) { + return { + mutate: async ({ + mcp_id, + metadata, + agent_id, + }: { + mcp_id?: string; + metadata: MCP['metadata']; + agent_id: string; + }) => { + try { + // TODO: Implement MCP endpoint + onSuccess(['success', { mcp_id, metadata, agent_id } as MCP]); + } catch (error) { + onError(error as Error); + } + }, + isLoading: false, + }; +} + +interface MCPInputProps { + mcp?: MCP; + agent_id?: string; + setMCP: React.Dispatch>; +} + +export default function MCPInput({ mcp, agent_id, setMCP }: MCPInputProps) { + const localize = useLocalize(); + const { showToast } = useToastContext(); + const { + handleSubmit, + register, + formState: { errors }, + control, + } = useFormContext(); + const [isLoading, setIsLoading] = useState(false); + const [showTools, setShowTools] = useState(false); + const [selectedTools, setSelectedTools] = useState([]); + + // Initialize tools list if editing existing MCP + useEffect(() => { + if (mcp?.mcp_id && mcp.metadata.tools) { + setShowTools(true); + setSelectedTools(mcp.metadata.tools); + } + }, [mcp]); + + const updateAgentMCP = useUpdateAgentMCP({ + onSuccess(data) { + showToast({ + message: localize('com_ui_update_mcp_success'), + status: 'success', + }); + setMCP(data[1]); + setShowTools(true); + setSelectedTools(data[1].metadata.tools ?? []); + setIsLoading(false); + }, + onError(error) { + showToast({ + message: (error as Error).message || localize('com_ui_update_mcp_error'), + status: 'error', + }); + setIsLoading(false); + }, + }); + + const saveMCP = handleSubmit(async (data: MCPForm) => { + setIsLoading(true); + try { + const response = await updateAgentMCP.mutate({ + agent_id: agent_id ?? '', + mcp_id: mcp?.mcp_id, + metadata: { + ...data, + tools: selectedTools, + }, + }); + setMCP(response[1]); + showToast({ + message: localize('com_ui_update_mcp_success'), + status: 'success', + }); + } catch { + showToast({ + message: localize('com_ui_update_mcp_error'), + status: 'error', + }); + } finally { + setIsLoading(false); + } + }); + + const handleSelectAll = () => { + if (mcp?.metadata.tools) { + setSelectedTools(mcp.metadata.tools); + } + }; + + const handleDeselectAll = () => { + setSelectedTools([]); + }; + + const handleToolToggle = (tool: string) => { + setSelectedTools((prev) => + prev.includes(tool) ? prev.filter((t) => t !== tool) : [...prev, tool], + ); + }; + + const handleToggleAll = () => { + if (selectedTools.length === mcp?.metadata.tools?.length) { + handleDeselectAll(); + } else { + handleSelectAll(); + } + }; + + const handleIconChange = (e: React.ChangeEvent) => { + const file = e.target.files?.[0]; + if (file) { + const reader = new FileReader(); + reader.onloadend = () => { + const base64String = reader.result as string; + setMCP({ + mcp_id: mcp?.mcp_id ?? '', + agent_id: agent_id ?? '', + metadata: { + ...mcp?.metadata, + icon: base64String, + }, + }); + }; + reader.readAsDataURL(file); + } + }; + + return ( +
+ {/* Icon Picker */} +
+ +
+ {/* name, description, url */} +
+
+ + + {errors.name && ( + {localize('com_ui_field_required')} + )} +
+
+ + +
+
+ + + {errors.url && ( + + {errors.url.type === 'required' + ? localize('com_ui_field_required') + : errors.url.message} + + )} +
+ +
+ ( + + )} + /> + +
+ {errors.trust && ( + {localize('com_ui_field_required')} + )} +
+ +
+ +
+ + {showTools && mcp?.metadata.tools && ( +
+
+

+ {localize('com_ui_available_tools')} +

+ +
+
+ {mcp.metadata.tools.map((tool) => ( + + ))} +
+
+ )} +
+ ); +} diff --git a/client/src/components/SidePanel/Agents/MCPPanel.tsx b/client/src/components/SidePanel/Agents/MCPPanel.tsx new file mode 100644 index 0000000000..016e445e4f --- /dev/null +++ b/client/src/components/SidePanel/Agents/MCPPanel.tsx @@ -0,0 +1,172 @@ +import { useEffect } from 'react'; +import { ChevronLeft } from 'lucide-react'; +import { useForm, FormProvider } from 'react-hook-form'; +import { useAgentPanelContext } from '~/Providers/AgentPanelContext'; +import { OGDialog, OGDialogTrigger, Label } from '~/components/ui'; +import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; +import { defaultMCPFormValues } from '~/common/mcp'; +import useLocalize from '~/hooks/useLocalize'; +import { useToastContext } from '~/Providers'; +import { TrashIcon } from '~/components/svg'; +import type { MCPForm } from '~/common'; +import MCPInput from './MCPInput'; +import { Panel } from '~/common'; +import { + AuthTypeEnum, + AuthorizationTypeEnum, + TokenExchangeMethodEnum, +} from 'librechat-data-provider'; +// TODO: Add MCP delete (for now mocked for ui) +// import { useDeleteAgentMCP } from '~/data-provider'; + +function useDeleteAgentMCP({ + onSuccess, + onError, +}: { + onSuccess: () => void; + onError: (error: Error) => void; +}) { + return { + mutate: async ({ mcp_id, agent_id }: { mcp_id: string; agent_id: string }) => { + try { + console.log('Mock delete MCP:', { mcp_id, agent_id }); + onSuccess(); + } catch (error) { + onError(error as Error); + } + }, + }; +} + +export default function MCPPanel() { + const localize = useLocalize(); + const { showToast } = useToastContext(); + const { mcp, setMcp, agent_id, setActivePanel } = useAgentPanelContext(); + const deleteAgentMCP = useDeleteAgentMCP({ + onSuccess: () => { + showToast({ + message: localize('com_ui_delete_mcp_success'), + status: 'success', + }); + setActivePanel(Panel.builder); + setMcp(undefined); + }, + onError(error) { + showToast({ + message: (error as Error).message ?? localize('com_ui_delete_mcp_error'), + status: 'error', + }); + }, + }); + + const methods = useForm({ + defaultValues: defaultMCPFormValues, + }); + + const { reset } = methods; + + useEffect(() => { + if (mcp) { + const formData = { + icon: mcp.metadata.icon ?? '', + name: mcp.metadata.name ?? '', + description: mcp.metadata.description ?? '', + url: mcp.metadata.url ?? '', + tools: mcp.metadata.tools ?? [], + trust: mcp.metadata.trust ?? false, + }; + + if (mcp.metadata.auth) { + Object.assign(formData, { + type: mcp.metadata.auth.type || AuthTypeEnum.None, + saved_auth_fields: false, + api_key: mcp.metadata.api_key ?? '', + authorization_type: mcp.metadata.auth.authorization_type || AuthorizationTypeEnum.Basic, + oauth_client_id: mcp.metadata.oauth_client_id ?? '', + oauth_client_secret: mcp.metadata.oauth_client_secret ?? '', + authorization_url: mcp.metadata.auth.authorization_url ?? '', + client_url: mcp.metadata.auth.client_url ?? '', + scope: mcp.metadata.auth.scope ?? '', + token_exchange_method: + mcp.metadata.auth.token_exchange_method ?? TokenExchangeMethodEnum.DefaultPost, + }); + } + + reset(formData); + } + }, [mcp, reset]); + + return ( + +
+
+
+
+ +
+ + {!!mcp && ( + + +
+ +
+
+ + {localize('com_ui_delete_mcp_confirm')} + + } + selection={{ + selectHandler: () => { + if (!agent_id) { + return showToast({ + message: localize('com_agents_no_agent_id_error'), + status: 'error', + }); + } + deleteAgentMCP.mutate({ + mcp_id: mcp.mcp_id, + agent_id, + }); + }, + selectClasses: + 'bg-red-700 dark:bg-red-600 hover:bg-red-800 dark:hover:bg-red-800 transition-color duration-200 text-white', + selectText: localize('com_ui_delete'), + }} + /> +
+ )} + +
+ {mcp ? localize('com_ui_edit_mcp_server') : localize('com_ui_add_mcp_server')} +
+
{localize('com_agents_mcp_info')}
+
+ +
+
+
+ ); +} diff --git a/client/src/components/SidePanel/Agents/MCPSection.tsx b/client/src/components/SidePanel/Agents/MCPSection.tsx new file mode 100644 index 0000000000..4575eb8f25 --- /dev/null +++ b/client/src/components/SidePanel/Agents/MCPSection.tsx @@ -0,0 +1,57 @@ +import { useCallback } from 'react'; +import { useLocalize } from '~/hooks'; +import { useToastContext } from '~/Providers'; +import { useAgentPanelContext } from '~/Providers/AgentPanelContext'; +import MCP from '~/components/SidePanel/Builder/MCP'; +import { Panel } from '~/common'; + +export default function MCPSection() { + const { showToast } = useToastContext(); + const localize = useLocalize(); + const { mcps = [], agent_id, setMcp, setActivePanel } = useAgentPanelContext(); + + const handleAddMCP = useCallback(() => { + if (!agent_id) { + showToast({ + message: localize('com_agents_mcps_disabled'), + status: 'warning', + }); + return; + } + setActivePanel(Panel.mcp); + }, [agent_id, setActivePanel, showToast, localize]); + + return ( +
+ +
+ {mcps + .filter((mcp) => mcp.agent_id === agent_id) + .map((mcp, i) => ( + { + setMcp(mcp); + setActivePanel(Panel.mcp); + }} + /> + ))} +
+ +
+
+
+ ); +} diff --git a/client/src/components/SidePanel/Agents/ModelPanel.tsx b/client/src/components/SidePanel/Agents/ModelPanel.tsx index 9b4b12cf67..3987f24dcf 100644 --- a/client/src/components/SidePanel/Agents/ModelPanel.tsx +++ b/client/src/components/SidePanel/Agents/ModelPanel.tsx @@ -1,27 +1,28 @@ +import keyBy from 'lodash/keyBy'; import React, { useMemo, useEffect } from 'react'; import { ChevronLeft, RotateCcw } from 'lucide-react'; import { useFormContext, useWatch, Controller } from 'react-hook-form'; +import { componentMapping } from '~/components/SidePanel/Parameters/components'; import { alternateName, getSettingsKeys, + LocalStorageKeys, SettingDefinition, agentParamSettings, } from 'librechat-data-provider'; import type * as t from 'librechat-data-provider'; import type { AgentForm, AgentModelPanelProps, StringOption } from '~/common'; -import { componentMapping } from '~/components/SidePanel/Parameters/components'; import ControlCombobox from '~/components/ui/ControlCombobox'; import { useGetEndpointsQuery } from '~/data-provider'; import { getEndpointField, cn } from '~/utils'; import { useLocalize } from '~/hooks'; import { Panel } from '~/common'; -import keyBy from 'lodash/keyBy'; export default function ModelPanel({ - setActivePanel, providers, + setActivePanel, models: modelsData, -}: AgentModelPanelProps) { +}: Pick) { const localize = useLocalize(); const { control, setValue } = useFormContext(); @@ -50,6 +51,8 @@ export default function ModelPanel({ const newModels = modelsData[provider] ?? []; setValue('model', newModels[0] ?? ''); } + localStorage.setItem(LocalStorageKeys.LAST_AGENT_MODEL, _model); + localStorage.setItem(LocalStorageKeys.LAST_AGENT_PROVIDER, provider); } if (provider && !_model) { diff --git a/client/src/components/SidePanel/Agents/Version/VersionPanel.tsx b/client/src/components/SidePanel/Agents/Version/VersionPanel.tsx index 9d76aba75a..0f89199216 100644 --- a/client/src/components/SidePanel/Agents/Version/VersionPanel.tsx +++ b/client/src/components/SidePanel/Agents/Version/VersionPanel.tsx @@ -1,12 +1,12 @@ -import type { Agent, TAgentsEndpoint } from 'librechat-data-provider'; import { ChevronLeft } from 'lucide-react'; import { useCallback, useMemo } from 'react'; -import type { AgentPanelProps } from '~/common'; -import { Panel } from '~/common'; import { useGetAgentByIdQuery, useRevertAgentVersionMutation } from '~/data-provider'; +import type { Agent } from 'librechat-data-provider'; +import { isActiveVersion } from './isActiveVersion'; +import { useAgentPanelContext } from '~/Providers'; import { useLocalize, useToast } from '~/hooks'; import VersionContent from './VersionContent'; -import { isActiveVersion } from './isActiveVersion'; +import { Panel } from '~/common'; export type VersionRecord = Record; @@ -39,15 +39,13 @@ export interface AgentWithVersions extends Agent { versions?: Array; } -export type VersionPanelProps = { - agentsConfig: TAgentsEndpoint | null; - setActivePanel: AgentPanelProps['setActivePanel']; - selectedAgentId?: string; -}; - -export default function VersionPanel({ setActivePanel, selectedAgentId = '' }: VersionPanelProps) { +export default function VersionPanel() { const localize = useLocalize(); const { showToast } = useToast(); + const { agent_id, setActivePanel } = useAgentPanelContext(); + + const selectedAgentId = agent_id ?? ''; + const { data: agent, isLoading, diff --git a/client/src/components/SidePanel/Agents/Version/__tests__/VersionPanel.spec.tsx b/client/src/components/SidePanel/Agents/Version/__tests__/VersionPanel.spec.tsx index d9cf31eff3..3258de3d66 100644 --- a/client/src/components/SidePanel/Agents/Version/__tests__/VersionPanel.spec.tsx +++ b/client/src/components/SidePanel/Agents/Version/__tests__/VersionPanel.spec.tsx @@ -55,13 +55,18 @@ jest.mock('~/hooks', () => ({ useToast: jest.fn(() => ({ showToast: jest.fn() })), })); +// Mock the AgentPanelContext +jest.mock('~/Providers/AgentPanelContext', () => ({ + ...jest.requireActual('~/Providers/AgentPanelContext'), + useAgentPanelContext: jest.fn(), +})); + describe('VersionPanel', () => { const mockSetActivePanel = jest.fn(); - const defaultProps = { - agentsConfig: null, - setActivePanel: mockSetActivePanel, - selectedAgentId: 'agent-123', - }; + const mockUseAgentPanelContext = jest.requireMock( + '~/Providers/AgentPanelContext', + ).useAgentPanelContext; + const mockUseGetAgentByIdQuery = jest.requireMock('~/data-provider').useGetAgentByIdQuery; beforeEach(() => { @@ -72,10 +77,17 @@ describe('VersionPanel', () => { error: null, refetch: jest.fn(), }); + + // Set up the default context mock + mockUseAgentPanelContext.mockReturnValue({ + setActivePanel: mockSetActivePanel, + agent_id: 'agent-123', + activePanel: Panel.version, + }); }); test('renders panel UI and handles navigation', () => { - render(); + render(); expect(screen.getByText('com_ui_agent_version_history')).toBeInTheDocument(); expect(screen.getByTestId('version-content')).toBeInTheDocument(); @@ -84,7 +96,7 @@ describe('VersionPanel', () => { }); test('VersionContent receives correct props', () => { - render(); + render(); expect(VersionContent).toHaveBeenCalledWith( expect.objectContaining({ selectedAgentId: 'agent-123', @@ -101,19 +113,31 @@ describe('VersionPanel', () => { }); test('handles data state variations', () => { - render(); + // Test with empty agent_id + mockUseAgentPanelContext.mockReturnValueOnce({ + setActivePanel: mockSetActivePanel, + agent_id: '', + activePanel: Panel.version, + }); + render(); expect(VersionContent).toHaveBeenCalledWith( expect.objectContaining({ selectedAgentId: '' }), expect.anything(), ); + // Test with null data mockUseGetAgentByIdQuery.mockReturnValueOnce({ data: null, isLoading: false, error: null, refetch: jest.fn(), }); - render(); + mockUseAgentPanelContext.mockReturnValueOnce({ + setActivePanel: mockSetActivePanel, + agent_id: 'agent-123', + activePanel: Panel.version, + }); + render(); expect(VersionContent).toHaveBeenCalledWith( expect.objectContaining({ versionContext: expect.objectContaining({ @@ -125,13 +149,14 @@ describe('VersionPanel', () => { expect.anything(), ); + // 3. versions is undefined mockUseGetAgentByIdQuery.mockReturnValueOnce({ data: { ...mockAgentData, versions: undefined }, isLoading: false, error: null, refetch: jest.fn(), }); - render(); + render(); expect(VersionContent).toHaveBeenCalledWith( expect.objectContaining({ versionContext: expect.objectContaining({ versions: [] }), @@ -139,18 +164,20 @@ describe('VersionPanel', () => { expect.anything(), ); + // 4. loading state mockUseGetAgentByIdQuery.mockReturnValueOnce({ data: null, isLoading: true, error: null, refetch: jest.fn(), }); - render(); + render(); expect(VersionContent).toHaveBeenCalledWith( expect.objectContaining({ isLoading: true }), expect.anything(), ); + // 5. error state const testError = new Error('Test error'); mockUseGetAgentByIdQuery.mockReturnValueOnce({ data: null, @@ -158,7 +185,7 @@ describe('VersionPanel', () => { error: testError, refetch: jest.fn(), }); - render(); + render(); expect(VersionContent).toHaveBeenCalledWith( expect.objectContaining({ error: testError }), expect.anything(), @@ -173,7 +200,7 @@ describe('VersionPanel', () => { refetch: jest.fn(), }); - render(); + render(); expect(VersionContent).toHaveBeenCalledWith( expect.objectContaining({ versionContext: expect.objectContaining({ diff --git a/client/src/components/SidePanel/Bookmarks/BookmarkTable.tsx b/client/src/components/SidePanel/Bookmarks/BookmarkTable.tsx index 12de8b452d..0506b3b0dd 100644 --- a/client/src/components/SidePanel/Bookmarks/BookmarkTable.tsx +++ b/client/src/components/SidePanel/Bookmarks/BookmarkTable.tsx @@ -80,13 +80,13 @@ const BookmarkTable = () => { -
{localize('com_ui_bookmarks_title')}
+
{localize('com_ui_bookmarks_title')}
-
{localize('com_ui_bookmarks_count')}
+
{localize('com_ui_bookmarks_count')}
-
{localize('com_assistants_actions')}
+
{localize('com_assistants_actions')}
diff --git a/client/src/components/SidePanel/Bookmarks/BookmarkTableRow.tsx b/client/src/components/SidePanel/Bookmarks/BookmarkTableRow.tsx index ea7a039692..42fea897b3 100644 --- a/client/src/components/SidePanel/Bookmarks/BookmarkTableRow.tsx +++ b/client/src/components/SidePanel/Bookmarks/BookmarkTableRow.tsx @@ -30,6 +30,12 @@ const BookmarkTableRow: React.FC = ({ row, moveRow, posit mutation.mutate( { ...row, position: item.index }, { + onSuccess: () => { + showToast({ + message: localize('com_ui_bookmarks_update_success'), + severity: NotificationSeverity.SUCCESS, + }); + }, onError: () => { showToast({ message: localize('com_ui_bookmarks_update_error'), @@ -44,7 +50,9 @@ const BookmarkTableRow: React.FC = ({ row, moveRow, posit accept: 'bookmark', drop: handleDrop, hover(item: DragItem) { - if (!ref.current || item.index === position) {return;} + if (!ref.current || item.index === position) { + return; + } moveRow(item.index, position); item.index = position; }, diff --git a/client/src/components/SidePanel/Builder/MCP.tsx b/client/src/components/SidePanel/Builder/MCP.tsx new file mode 100644 index 0000000000..f632b6c00d --- /dev/null +++ b/client/src/components/SidePanel/Builder/MCP.tsx @@ -0,0 +1,60 @@ +import { useState } from 'react'; +import type { MCP } from 'librechat-data-provider'; +import GearIcon from '~/components/svg/GearIcon'; +import MCPIcon from '~/components/svg/MCPIcon'; +import { cn } from '~/utils'; + +type MCPProps = { + mcp: MCP; + onClick: () => void; +}; + +export default function MCP({ mcp, onClick }: MCPProps) { + const [isHovering, setIsHovering] = useState(false); + + return ( +
{ + if (e.key === 'Enter' || e.key === ' ') { + onClick(); + } + }} + className="group flex w-full rounded-lg border border-border-medium text-sm hover:cursor-pointer focus:outline-none focus:ring-2 focus:ring-text-primary" + onMouseEnter={() => setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + aria-label={`MCP for ${mcp.metadata.name}`} + > +
+ {mcp.metadata.icon ? ( + {`${mcp.metadata.name} + ) : ( +
+ +
+ )} +
+ {mcp.metadata.name} +
+
+
+
+
+ ); +} diff --git a/client/src/components/SidePanel/Builder/MCPAuth.tsx b/client/src/components/SidePanel/Builder/MCPAuth.tsx new file mode 100644 index 0000000000..4ea3faae61 --- /dev/null +++ b/client/src/components/SidePanel/Builder/MCPAuth.tsx @@ -0,0 +1,55 @@ +import { useEffect } from 'react'; +import { FormProvider, useForm } from 'react-hook-form'; +import ActionsAuth from '~/components/SidePanel/Builder/ActionsAuth'; +import { + AuthorizationTypeEnum, + TokenExchangeMethodEnum, + AuthTypeEnum, +} from 'librechat-data-provider'; + +export default function MCPAuth() { + // Create a separate form for auth + const authMethods = useForm({ + defaultValues: { + /* General */ + type: AuthTypeEnum.None, + saved_auth_fields: false, + /* API key */ + api_key: '', + authorization_type: AuthorizationTypeEnum.Basic, + custom_auth_header: '', + /* OAuth */ + oauth_client_id: '', + oauth_client_secret: '', + authorization_url: '', + client_url: '', + scope: '', + token_exchange_method: TokenExchangeMethodEnum.DefaultPost, + }, + }); + + const { watch, setValue } = authMethods; + const type = watch('type'); + + // Sync form state when auth type changes + useEffect(() => { + if (type === 'none') { + // Reset auth fields when type is none + setValue('api_key', ''); + setValue('authorization_type', AuthorizationTypeEnum.Basic); + setValue('custom_auth_header', ''); + setValue('oauth_client_id', ''); + setValue('oauth_client_secret', ''); + setValue('authorization_url', ''); + setValue('client_url', ''); + setValue('scope', ''); + setValue('token_exchange_method', TokenExchangeMethodEnum.DefaultPost); + } + }, [type, setValue]); + + return ( + + + + ); +} diff --git a/client/src/components/SidePanel/MCP/MCPPanel.tsx b/client/src/components/SidePanel/MCP/MCPPanel.tsx new file mode 100644 index 0000000000..aa2bf72112 --- /dev/null +++ b/client/src/components/SidePanel/MCP/MCPPanel.tsx @@ -0,0 +1,253 @@ +import React, { useState, useCallback, useMemo, useEffect } from 'react'; +import { ChevronLeft } from 'lucide-react'; +import { Constants } from 'librechat-data-provider'; +import { useForm, Controller } from 'react-hook-form'; +import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query'; +import type { TUpdateUserPlugins } from 'librechat-data-provider'; +import { Button, Input, Label } from '~/components/ui'; +import { useGetStartupConfig } from '~/data-provider'; +import MCPPanelSkeleton from './MCPPanelSkeleton'; +import { useToastContext } from '~/Providers'; +import { useLocalize } from '~/hooks'; + +interface ServerConfigWithVars { + serverName: string; + config: { + customUserVars: Record; + }; +} + +export default function MCPPanel() { + const localize = useLocalize(); + const { showToast } = useToastContext(); + const { data: startupConfig, isLoading: startupConfigLoading } = useGetStartupConfig(); + const [selectedServerNameForEditing, setSelectedServerNameForEditing] = useState( + null, + ); + + const mcpServerDefinitions = useMemo(() => { + if (!startupConfig?.mcpServers) { + return []; + } + return Object.entries(startupConfig.mcpServers) + .filter( + ([, serverConfig]) => + serverConfig.customUserVars && Object.keys(serverConfig.customUserVars).length > 0, + ) + .map(([serverName, config]) => ({ + serverName, + iconPath: null, + config: { + ...config, + customUserVars: config.customUserVars ?? {}, + }, + })); + }, [startupConfig?.mcpServers]); + + const updateUserPluginsMutation = useUpdateUserPluginsMutation({ + onSuccess: () => { + showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' }); + }, + onError: (error) => { + console.error('Error updating MCP custom user variables:', error); + showToast({ + message: localize('com_nav_mcp_vars_update_error'), + status: 'error', + }); + }, + }); + + const handleSaveServerVars = useCallback( + (serverName: string, updatedValues: Record) => { + const payload: TUpdateUserPlugins = { + pluginKey: `${Constants.mcp_prefix}${serverName}`, + action: 'install', // 'install' action is used to set/update credentials/variables + auth: updatedValues, + }; + updateUserPluginsMutation.mutate(payload); + }, + [updateUserPluginsMutation], + ); + + const handleRevokeServerVars = useCallback( + (serverName: string) => { + const payload: TUpdateUserPlugins = { + pluginKey: `${Constants.mcp_prefix}${serverName}`, + action: 'uninstall', // 'uninstall' action clears the variables + auth: {}, // Empty auth for uninstall + }; + updateUserPluginsMutation.mutate(payload); + }, + [updateUserPluginsMutation], + ); + + const handleServerClickToEdit = (serverName: string) => { + setSelectedServerNameForEditing(serverName); + }; + + const handleGoBackToList = () => { + setSelectedServerNameForEditing(null); + }; + + if (startupConfigLoading) { + return ; + } + + if (mcpServerDefinitions.length === 0) { + return ( +
+ {localize('com_sidepanel_mcp_no_servers_with_vars')} +
+ ); + } + + if (selectedServerNameForEditing) { + // Editing View + const serverBeingEdited = mcpServerDefinitions.find( + (s) => s.serverName === selectedServerNameForEditing, + ); + + if (!serverBeingEdited) { + // Fallback to list view if server not found + setSelectedServerNameForEditing(null); + return ( +
+ {localize('com_ui_error')}: {localize('com_ui_mcp_server_not_found')} +
+ ); + } + + return ( +
+ +

+ {localize('com_sidepanel_mcp_variables_for', { '0': serverBeingEdited.serverName })} +

+ +
+ ); + } else { + // Server List View + return ( +
+
+ {mcpServerDefinitions.map((server) => ( + + ))} +
+
+ ); + } +} + +// Inner component for the form - remains the same +interface MCPVariableEditorProps { + server: ServerConfigWithVars; + onSave: (serverName: string, updatedValues: Record) => void; + onRevoke: (serverName: string) => void; + isSubmitting: boolean; +} + +function MCPVariableEditor({ server, onSave, onRevoke, isSubmitting }: MCPVariableEditorProps) { + const localize = useLocalize(); + + const { + control, + handleSubmit, + reset, + formState: { errors, isDirty }, + } = useForm>({ + defaultValues: {}, // Initialize empty, will be reset by useEffect + }); + + useEffect(() => { + // Always initialize with empty strings based on the schema + const initialFormValues = Object.keys(server.config.customUserVars).reduce( + (acc, key) => { + acc[key] = ''; + return acc; + }, + {} as Record, + ); + reset(initialFormValues); + }, [reset, server.config.customUserVars]); + + const onFormSubmit = (data: Record) => { + onSave(server.serverName, data); + }; + + const handleRevokeClick = () => { + onRevoke(server.serverName); + }; + + return ( +
+ {Object.entries(server.config.customUserVars).map(([key, details]) => ( +
+ + ( + + )} + /> + {details.description && ( +

+ )} + {errors[key] &&

{errors[key]?.message}

} +
+ ))} +
+ {Object.keys(server.config.customUserVars).length > 0 && ( + + )} + +
+
+ ); +} diff --git a/client/src/components/SidePanel/MCP/MCPPanelSkeleton.tsx b/client/src/components/SidePanel/MCP/MCPPanelSkeleton.tsx new file mode 100644 index 0000000000..61afbfcc2f --- /dev/null +++ b/client/src/components/SidePanel/MCP/MCPPanelSkeleton.tsx @@ -0,0 +1,21 @@ +import React from 'react'; +import { Skeleton } from '~/components/ui'; + +export default function MCPPanelSkeleton() { + return ( +
+ {[1, 2].map((serverIdx) => ( +
+ {/* Server Name */} + {[1, 2].map((varIdx) => ( +
+ {/* Variable Title */} + {/* Input Field */} + {/* Description */} +
+ ))} +
+ ))} +
+ ); +} diff --git a/client/src/components/SidePanel/Memories/AdminSettings.tsx b/client/src/components/SidePanel/Memories/AdminSettings.tsx new file mode 100644 index 0000000000..fcb347228d --- /dev/null +++ b/client/src/components/SidePanel/Memories/AdminSettings.tsx @@ -0,0 +1,212 @@ +import * as Ariakit from '@ariakit/react'; +import { useMemo, useEffect, useState } from 'react'; +import { ShieldEllipsis } from 'lucide-react'; +import { useForm, Controller } from 'react-hook-form'; +import { Permissions, SystemRoles, roleDefaults, PermissionTypes } from 'librechat-data-provider'; +import type { Control, UseFormSetValue, UseFormGetValues } from 'react-hook-form'; +import { OGDialog, OGDialogTitle, OGDialogContent, OGDialogTrigger } from '~/components/ui'; +import { useUpdateMemoryPermissionsMutation } from '~/data-provider'; +import { Button, Switch, DropdownPopup } from '~/components/ui'; +import { useLocalize, useAuthContext } from '~/hooks'; +import { useToastContext } from '~/Providers'; + +type FormValues = Record; + +type LabelControllerProps = { + label: string; + memoryPerm: Permissions; + control: Control; + setValue: UseFormSetValue; + getValues: UseFormGetValues; +}; + +const LabelController: React.FC = ({ control, memoryPerm, label }) => ( +
+ {label} + ( + + )} + /> +
+); + +const AdminSettings = () => { + const localize = useLocalize(); + const { user, roles } = useAuthContext(); + const { showToast } = useToastContext(); + const { mutate, isLoading } = useUpdateMemoryPermissionsMutation({ + onSuccess: () => { + showToast({ status: 'success', message: localize('com_ui_saved') }); + }, + onError: () => { + showToast({ status: 'error', message: localize('com_ui_error_save_admin_settings') }); + }, + }); + + const [isRoleMenuOpen, setIsRoleMenuOpen] = useState(false); + const [selectedRole, setSelectedRole] = useState(SystemRoles.USER); + + const defaultValues = useMemo(() => { + if (roles?.[selectedRole]?.permissions) { + return roles?.[selectedRole]?.permissions?.[PermissionTypes.MEMORIES]; + } + return roleDefaults[selectedRole].permissions[PermissionTypes.MEMORIES]; + }, [roles, selectedRole]); + + const { + reset, + control, + setValue, + getValues, + handleSubmit, + formState: { isSubmitting }, + } = useForm({ + mode: 'onChange', + defaultValues, + }); + + useEffect(() => { + if (roles?.[selectedRole]?.permissions?.[PermissionTypes.MEMORIES]) { + reset(roles?.[selectedRole]?.permissions?.[PermissionTypes.MEMORIES]); + } else { + reset(roleDefaults[selectedRole].permissions[PermissionTypes.MEMORIES]); + } + }, [roles, selectedRole, reset]); + + if (user?.role !== SystemRoles.ADMIN) { + return null; + } + + const labelControllerData = [ + { + memoryPerm: Permissions.USE, + label: localize('com_ui_memories_allow_use'), + }, + { + memoryPerm: Permissions.CREATE, + label: localize('com_ui_memories_allow_create'), + }, + { + memoryPerm: Permissions.UPDATE, + label: localize('com_ui_memories_allow_update'), + }, + { + memoryPerm: Permissions.READ, + label: localize('com_ui_memories_allow_read'), + }, + { + memoryPerm: Permissions.OPT_OUT, + label: localize('com_ui_memories_allow_opt_out'), + }, + ]; + + const onSubmit = (data: FormValues) => { + mutate({ roleName: selectedRole, updates: data }); + }; + + const roleDropdownItems = [ + { + label: SystemRoles.USER, + onClick: () => { + setSelectedRole(SystemRoles.USER); + }, + }, + { + label: SystemRoles.ADMIN, + onClick: () => { + setSelectedRole(SystemRoles.ADMIN); + }, + }, + ]; + + return ( + + + + + + {`${localize('com_ui_admin_settings')} - ${localize( + 'com_ui_memories', + )}`} +
+ {/* Role selection dropdown */} +
+ {localize('com_ui_role_select')}: + + {selectedRole} + + } + items={roleDropdownItems} + itemClassName="items-center justify-center" + sameWidth={true} + /> +
+ {/* Permissions form */} +
+
+ {labelControllerData.map(({ memoryPerm, label }) => ( +
+ + {selectedRole === SystemRoles.ADMIN && memoryPerm === Permissions.USE && ( + <> +
+ {localize('com_ui_admin_access_warning')} + {'\n'} + + {localize('com_ui_more_info')} + +
+ + )} +
+ ))} +
+
+ +
+
+
+
+
+ ); +}; + +export default AdminSettings; diff --git a/client/src/components/SidePanel/Memories/MemoryCreateDialog.tsx b/client/src/components/SidePanel/Memories/MemoryCreateDialog.tsx new file mode 100644 index 0000000000..1670ba6f60 --- /dev/null +++ b/client/src/components/SidePanel/Memories/MemoryCreateDialog.tsx @@ -0,0 +1,147 @@ +import React, { useState } from 'react'; +import { PermissionTypes, Permissions } from 'librechat-data-provider'; +import { OGDialog, OGDialogTemplate, Button, Label, Input } from '~/components/ui'; +import { useCreateMemoryMutation } from '~/data-provider'; +import { useLocalize, useHasAccess } from '~/hooks'; +import { useToastContext } from '~/Providers'; +import { Spinner } from '~/components/svg'; + +interface MemoryCreateDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + children: React.ReactNode; + triggerRef?: React.MutableRefObject; +} + +export default function MemoryCreateDialog({ + open, + onOpenChange, + children, + triggerRef, +}: MemoryCreateDialogProps) { + const localize = useLocalize(); + const { showToast } = useToastContext(); + + const hasCreateAccess = useHasAccess({ + permissionType: PermissionTypes.MEMORIES, + permission: Permissions.CREATE, + }); + + const { mutate: createMemory, isLoading } = useCreateMemoryMutation({ + onSuccess: () => { + showToast({ + message: localize('com_ui_memory_created'), + status: 'success', + }); + onOpenChange(false); + setKey(''); + setValue(''); + setTimeout(() => { + triggerRef?.current?.focus(); + }, 0); + }, + onError: (error: Error) => { + let errorMessage = localize('com_ui_error'); + + if (error && typeof error === 'object' && 'response' in error) { + const axiosError = error as any; + if (axiosError.response?.data?.error) { + errorMessage = axiosError.response.data.error; + + // Check for duplicate key error + if (axiosError.response?.status === 409 || errorMessage.includes('already exists')) { + errorMessage = localize('com_ui_memory_key_exists'); + } + } + } else if (error.message) { + errorMessage = error.message; + } + + showToast({ + message: errorMessage, + status: 'error', + }); + }, + }); + + const [key, setKey] = useState(''); + const [value, setValue] = useState(''); + + const handleSave = () => { + if (!hasCreateAccess) { + return; + } + + if (!key.trim() || !value.trim()) { + showToast({ + message: localize('com_ui_field_required'), + status: 'error', + }); + return; + } + + createMemory({ + key: key.trim(), + value: value.trim(), + }); + }; + + const handleKeyPress = (e: React.KeyboardEvent) => { + if (e.key === 'Enter' && e.ctrlKey && hasCreateAccess) { + handleSave(); + } + }; + + return ( + + {children} + +
+ + setKey(e.target.value)} + onKeyDown={handleKeyPress} + placeholder={localize('com_ui_enter_key')} + className="w-full" + /> +
+
+ +