diff --git a/.env.example b/.env.example index 9864a41482..a6ff6157ce 100644 --- a/.env.example +++ b/.env.example @@ -47,6 +47,10 @@ TRUST_PROXY=1 # password policies. # MIN_PASSWORD_LENGTH=8 +# When enabled, the app will continue running after encountering uncaught exceptions +# instead of exiting the process. Not recommended for production unless necessary. +# CONTINUE_ON_UNCAUGHT_EXCEPTION=false + #===============# # JSON Logging # #===============# @@ -87,6 +91,16 @@ NODE_MAX_OLD_SPACE_SIZE=6144 # CONFIG_PATH="/alternative/path/to/librechat.yaml" +#==================# +# Langfuse Tracing # +#==================# + +# Get Langfuse API keys for your project from the project settings page: https://cloud.langfuse.com + +# LANGFUSE_PUBLIC_KEY= +# LANGFUSE_SECRET_KEY= +# LANGFUSE_BASE_URL= + #===================================================# # Endpoints # #===================================================# @@ -121,7 +135,7 @@ PROXY= #============# ANTHROPIC_API_KEY=user_provided -# ANTHROPIC_MODELS=claude-opus-4-20250514,claude-sonnet-4-20250514,claude-3-7-sonnet-20250219,claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307 +# ANTHROPIC_MODELS=claude-sonnet-4-6,claude-opus-4-6,claude-opus-4-20250514,claude-sonnet-4-20250514,claude-3-7-sonnet-20250219,claude-3-5-sonnet-20241022,claude-3-5-haiku-20241022,claude-3-opus-20240229,claude-3-sonnet-20240229,claude-3-haiku-20240307 # ANTHROPIC_REVERSE_PROXY= # Set to true to use Anthropic models through Google Vertex AI instead of direct API @@ -156,7 +170,8 @@ ANTHROPIC_API_KEY=user_provided # BEDROCK_AWS_SESSION_TOKEN=someSessionToken # Note: This example list is not meant to be exhaustive. If omitted, all known, supported model IDs will be included for you. 
-# BEDROCK_AWS_MODELS=anthropic.claude-3-5-sonnet-20240620-v1:0,meta.llama3-1-8b-instruct-v1:0 +# BEDROCK_AWS_MODELS=anthropic.claude-sonnet-4-6,anthropic.claude-opus-4-6-v1,anthropic.claude-3-5-sonnet-20240620-v1:0,meta.llama3-1-8b-instruct-v1:0 +# Cross-region inference model IDs: us.anthropic.claude-sonnet-4-6,us.anthropic.claude-opus-4-6-v1,global.anthropic.claude-opus-4-6-v1 # See all Bedrock model IDs here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns @@ -737,8 +752,10 @@ HELP_AND_FAQ_URL=https://librechat.ai # REDIS_PING_INTERVAL=300 # Force specific cache namespaces to use in-memory storage even when Redis is enabled -# Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES) -# FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES +# Comma-separated list of CacheKeys +# Defaults to CONFIG_STORE,APP_CONFIG so YAML-derived config stays per-container (safe for blue/green deployments) +# Set to empty string to force all namespaces through Redis: FORCED_IN_MEMORY_CACHE_NAMESPACES= +# FORCED_IN_MEMORY_CACHE_NAMESPACES=CONFIG_STORE,APP_CONFIG # Leader Election Configuration (for multi-instance deployments with Redis) # Duration in seconds that the leader lease is valid before it expires (default: 25) diff --git a/.gitignore b/.gitignore index d173d26b60..86d4a3ddae 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ pids # CI/CD data test-image* +dump.rdb # Directory for instrumented libs generated by jscoverage/JSCover lib-cov @@ -29,6 +30,9 @@ coverage config/translations/stores/* client/src/localization/languages/*_missing_keys.json +# Turborepo +.turbo + # Compiled Dirs (http://nodejs.org/api/addons.html) build/ dist/ diff --git a/Dockerfile b/Dockerfile index 5872440a33..38273bc5eb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -# v0.8.2-rc2 +# v0.8.2 # Base node image FROM node:20-alpine AS node diff --git a/Dockerfile.multi b/Dockerfile.multi index ca66459a44..47e00d0fa8 100644 --- a/Dockerfile.multi +++ 
b/Dockerfile.multi @@ -1,5 +1,5 @@ # Dockerfile.multi -# v0.8.2-rc2 +# v0.8.2 # Set configurable max-old-space-size with default ARG NODE_MAX_OLD_SPACE_SIZE=6144 diff --git a/README.md b/README.md index a96e47f70f..6e04396637 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,11 @@ - 🎨 **Customizable Interface**: - Customizable Dropdown & Interface that adapts to both power users and newcomers +- 🌊 **[Resumable Streams](https://www.librechat.ai/docs/features/resumable_streams)**: + - Never lose a response: AI responses automatically reconnect and resume if your connection drops + - Multi-Tab & Multi-Device Sync: Open the same chat in multiple tabs or pick up on another device + - Production-Ready: Works from single-server setups to horizontally scaled deployments with Redis + - 🗣️ **Speech & Audio**: - Chat hands-free with Speech-to-Text and Text-to-Speech - Automatically send and play Audio @@ -137,13 +142,11 @@ ## 🪶 All-In-One AI Conversations with LibreChat -LibreChat brings together the future of assistant AIs with the revolutionary technology of OpenAI's ChatGPT. Celebrating the original styling, LibreChat gives you the ability to integrate multiple AI models. It also integrates and enhances original client features such as conversation and message search, prompt templates and plugins. +LibreChat is a self-hosted AI chat platform that unifies all major AI providers in a single, privacy-focused interface. -With LibreChat, you no longer need to opt for ChatGPT Plus and can instead use free or pay-per-call APIs. We welcome contributions, cloning, and forking to enhance the capabilities of this advanced chatbot platform. +Beyond chat, LibreChat provides AI Agents, Model Context Protocol (MCP) support, Artifacts, Code Interpreter, custom actions, conversation search, and enterprise-ready multi-user authentication. 
-[![Watch the video](https://raw.githubusercontent.com/LibreChat-AI/librechat.ai/main/public/images/changelog/v0.7.6.gif)](https://www.youtube.com/watch?v=ilfwGQtJNlI) - -Click on the thumbnail to open the video☝️ +Open source, actively developed, and built for anyone who values control over their AI infrastructure. --- diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index 3cc082ab66..fed80de28c 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -41,9 +41,9 @@ jest.mock('~/models', () => ({ const { getConvo, saveConvo } = require('~/models'); jest.mock('@librechat/agents', () => { - const { Providers } = jest.requireActual('@librechat/agents'); + const actual = jest.requireActual('@librechat/agents'); return { - Providers, + ...actual, ChatOpenAI: jest.fn().mockImplementation(() => { return {}; }), diff --git a/api/app/clients/tools/manifest.json b/api/app/clients/tools/manifest.json index 9262113501..7930e67ac9 100644 --- a/api/app/clients/tools/manifest.json +++ b/api/app/clients/tools/manifest.json @@ -57,19 +57,6 @@ } ] }, - { - "name": "Browser", - "pluginKey": "web-browser", - "description": "Scrape and summarize webpage data", - "icon": "assets/web-browser.svg", - "authConfig": [ - { - "authField": "OPENAI_API_KEY", - "label": "OpenAI API Key", - "description": "Browser makes use of OpenAI embeddings" - } - ] - }, { "name": "DALL-E-3", "pluginKey": "dalle", diff --git a/api/app/clients/tools/structured/AzureAISearch.js b/api/app/clients/tools/structured/AzureAISearch.js index 55af3cdff5..1815c45e04 100644 --- a/api/app/clients/tools/structured/AzureAISearch.js +++ b/api/app/clients/tools/structured/AzureAISearch.js @@ -1,14 +1,28 @@ -const { z } = require('zod'); const { Tool } = require('@langchain/core/tools'); const { logger } = require('@librechat/data-schemas'); const { SearchClient, AzureKeyCredential } = require('@azure/search-documents'); +const 
azureAISearchJsonSchema = { + type: 'object', + properties: { + query: { + type: 'string', + description: 'Search word or phrase to Azure AI Search', + }, + }, + required: ['query'], +}; + class AzureAISearch extends Tool { // Constants for default values static DEFAULT_API_VERSION = '2023-11-01'; static DEFAULT_QUERY_TYPE = 'simple'; static DEFAULT_TOP = 5; + static get jsonSchema() { + return azureAISearchJsonSchema; + } + // Helper function for initializing properties _initializeField(field, envVar, defaultValue) { return field || process.env[envVar] || defaultValue; @@ -22,10 +36,7 @@ class AzureAISearch extends Tool { /* Used to initialize the Tool without necessary variables. */ this.override = fields.override ?? false; - // Define schema - this.schema = z.object({ - query: z.string().describe('Search word or phrase to Azure AI Search'), - }); + this.schema = azureAISearchJsonSchema; // Initialize properties using helper function this.serviceEndpoint = this._initializeField( diff --git a/api/app/clients/tools/structured/DALLE3.js b/api/app/clients/tools/structured/DALLE3.js index c44b56f83d..26610f73ba 100644 --- a/api/app/clients/tools/structured/DALLE3.js +++ b/api/app/clients/tools/structured/DALLE3.js @@ -1,4 +1,3 @@ -const { z } = require('zod'); const path = require('path'); const OpenAI = require('openai'); const { v4: uuidv4 } = require('uuid'); @@ -8,6 +7,36 @@ const { logger } = require('@librechat/data-schemas'); const { getImageBasename, extractBaseURL } = require('@librechat/api'); const { FileContext, ContentTypes } = require('librechat-data-provider'); +const dalle3JsonSchema = { + type: 'object', + properties: { + prompt: { + type: 'string', + maxLength: 4000, + description: + 'A text description of the desired image, following the rules, up to 4000 characters.', + }, + style: { + type: 'string', + enum: ['vivid', 'natural'], + description: + 'Must be one of `vivid` or `natural`. 
`vivid` generates hyper-real and dramatic images, `natural` produces more natural, less hyper-real looking images', + }, + quality: { + type: 'string', + enum: ['hd', 'standard'], + description: 'The quality of the generated image. Only `hd` and `standard` are supported.', + }, + size: { + type: 'string', + enum: ['1024x1024', '1792x1024', '1024x1792'], + description: + 'The size of the requested image. Use 1024x1024 (square) as the default, 1792x1024 if the user requests a wide image, and 1024x1792 for full-body portraits. Always include this parameter in the request.', + }, + }, + required: ['prompt', 'style', 'quality', 'size'], +}; + const displayMessage = "DALL-E displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user."; class DALLE3 extends Tool { @@ -72,27 +101,11 @@ class DALLE3 extends Tool { // The prompt must intricately describe every part of the image in concrete, objective detail. THINK about what the end goal of the description is, and extrapolate that to what would make satisfying images. // All descriptions sent to dalle should be a paragraph of text that is extremely descriptive and detailed. Each should be more than 3 sentences long. // - The "vivid" style is HIGHLY preferred, but "natural" is also supported.`; - this.schema = z.object({ - prompt: z - .string() - .max(4000) - .describe( - 'A text description of the desired image, following the rules, up to 4000 characters.', - ), - style: z - .enum(['vivid', 'natural']) - .describe( - 'Must be one of `vivid` or `natural`. `vivid` generates hyper-real and dramatic images, `natural` produces more natural, less hyper-real looking images', - ), - quality: z - .enum(['hd', 'standard']) - .describe('The quality of the generated image. 
Only `hd` and `standard` are supported.'), - size: z - .enum(['1024x1024', '1792x1024', '1024x1792']) - .describe( - 'The size of the requested image. Use 1024x1024 (square) as the default, 1792x1024 if the user requests a wide image, and 1024x1792 for full-body portraits. Always include this parameter in the request.', - ), - }); + this.schema = dalle3JsonSchema; + } + + static get jsonSchema() { + return dalle3JsonSchema; } getApiKey() { diff --git a/api/app/clients/tools/structured/FluxAPI.js b/api/app/clients/tools/structured/FluxAPI.js index 9fa08a0343..56f86a707d 100644 --- a/api/app/clients/tools/structured/FluxAPI.js +++ b/api/app/clients/tools/structured/FluxAPI.js @@ -1,4 +1,3 @@ -const { z } = require('zod'); const axios = require('axios'); const fetch = require('node-fetch'); const { v4: uuidv4 } = require('uuid'); @@ -7,6 +6,84 @@ const { logger } = require('@librechat/data-schemas'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { FileContext, ContentTypes } = require('librechat-data-provider'); +const fluxApiJsonSchema = { + type: 'object', + properties: { + action: { + type: 'string', + enum: ['generate', 'list_finetunes', 'generate_finetuned'], + description: + 'Action to perform: "generate" for image generation, "generate_finetuned" for finetuned model generation, "list_finetunes" to get available custom models', + }, + prompt: { + type: 'string', + description: + 'Text prompt for image generation. Required when action is "generate". Not used for list_finetunes.', + }, + width: { + type: 'number', + description: + 'Width of the generated image in pixels. Must be a multiple of 32. Default is 1024.', + }, + height: { + type: 'number', + description: + 'Height of the generated image in pixels. Must be a multiple of 32. 
Default is 768.', + }, + prompt_upsampling: { + type: 'boolean', + description: 'Whether to perform upsampling on the prompt.', + }, + steps: { + type: 'integer', + description: 'Number of steps to run the model for, a number from 1 to 50. Default is 40.', + }, + seed: { + type: 'number', + description: 'Optional seed for reproducibility.', + }, + safety_tolerance: { + type: 'number', + description: + 'Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict.', + }, + endpoint: { + type: 'string', + enum: [ + '/v1/flux-pro-1.1', + '/v1/flux-pro', + '/v1/flux-dev', + '/v1/flux-pro-1.1-ultra', + '/v1/flux-pro-finetuned', + '/v1/flux-pro-1.1-ultra-finetuned', + ], + description: 'Endpoint to use for image generation.', + }, + raw: { + type: 'boolean', + description: + 'Generate less processed, more natural-looking images. Only works for /v1/flux-pro-1.1-ultra.', + }, + finetune_id: { + type: 'string', + description: 'ID of the finetuned model to use', + }, + finetune_strength: { + type: 'number', + description: 'Strength of the finetuning effect (typically between 0.1 and 1.2)', + }, + guidance: { + type: 'number', + description: 'Guidance scale for finetuned models', + }, + aspect_ratio: { + type: 'string', + description: 'Aspect ratio for ultra models (e.g., "16:9")', + }, + }, + required: [], +}; + const displayMessage = "Flux displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. 
The user may download the images by clicking on them, but do not mention anything about downloading to the user."; @@ -57,82 +134,11 @@ class FluxAPI extends Tool { // Add base URL from environment variable with fallback this.baseUrl = process.env.FLUX_API_BASE_URL || 'https://api.us1.bfl.ai'; - // Define the schema for structured input - this.schema = z.object({ - action: z - .enum(['generate', 'list_finetunes', 'generate_finetuned']) - .default('generate') - .describe( - 'Action to perform: "generate" for image generation, "generate_finetuned" for finetuned model generation, "list_finetunes" to get available custom models', - ), - prompt: z - .string() - .optional() - .describe( - 'Text prompt for image generation. Required when action is "generate". Not used for list_finetunes.', - ), - width: z - .number() - .optional() - .describe( - 'Width of the generated image in pixels. Must be a multiple of 32. Default is 1024.', - ), - height: z - .number() - .optional() - .describe( - 'Height of the generated image in pixels. Must be a multiple of 32. Default is 768.', - ), - prompt_upsampling: z - .boolean() - .optional() - .default(false) - .describe('Whether to perform upsampling on the prompt.'), - steps: z - .number() - .int() - .optional() - .describe('Number of steps to run the model for, a number from 1 to 50. Default is 40.'), - seed: z.number().optional().describe('Optional seed for reproducibility.'), - safety_tolerance: z - .number() - .optional() - .default(6) - .describe( - 'Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict.', - ), - endpoint: z - .enum([ - '/v1/flux-pro-1.1', - '/v1/flux-pro', - '/v1/flux-dev', - '/v1/flux-pro-1.1-ultra', - '/v1/flux-pro-finetuned', - '/v1/flux-pro-1.1-ultra-finetuned', - ]) - .optional() - .default('/v1/flux-pro-1.1') - .describe('Endpoint to use for image generation.'), - raw: z - .boolean() - .optional() - .default(false) - .describe( - 'Generate less processed, more natural-looking images. Only works for /v1/flux-pro-1.1-ultra.', - ), - finetune_id: z.string().optional().describe('ID of the finetuned model to use'), - finetune_strength: z - .number() - .optional() - .default(1.1) - .describe('Strength of the finetuning effect (typically between 0.1 and 1.2)'), - guidance: z.number().optional().default(2.5).describe('Guidance scale for finetuned models'), - aspect_ratio: z - .string() - .optional() - .default('16:9') - .describe('Aspect ratio for ultra models (e.g., "16:9")'), - }); + this.schema = fluxApiJsonSchema; + } + + static get jsonSchema() { + return fluxApiJsonSchema; } getAxiosConfig() { diff --git a/api/app/clients/tools/structured/GoogleSearch.js b/api/app/clients/tools/structured/GoogleSearch.js index d703d56f83..38f483edf1 100644 --- a/api/app/clients/tools/structured/GoogleSearch.js +++ b/api/app/clients/tools/structured/GoogleSearch.js @@ -1,12 +1,33 @@ -const { z } = require('zod'); const { Tool } = require('@langchain/core/tools'); const { getEnvironmentVariable } = require('@langchain/core/utils/env'); +const googleSearchJsonSchema = { + type: 'object', + properties: { + query: { + type: 'string', + minLength: 1, + description: 'The search query string.', + }, + max_results: { + type: 'integer', + minimum: 1, + maximum: 10, + description: 'The maximum number of search results to return. 
Defaults to 5.', + }, + }, + required: ['query'], +}; + class GoogleSearchResults extends Tool { static lc_name() { return 'google'; } + static get jsonSchema() { + return googleSearchJsonSchema; + } + constructor(fields = {}) { super(fields); this.name = 'google'; @@ -28,25 +49,11 @@ class GoogleSearchResults extends Tool { this.description = 'A search engine optimized for comprehensive, accurate, and trusted results. Useful for when you need to answer questions about current events.'; - this.schema = z.object({ - query: z.string().min(1).describe('The search query string.'), - max_results: z - .number() - .min(1) - .max(10) - .optional() - .describe('The maximum number of search results to return. Defaults to 10.'), - // Note: Google API has its own parameters for search customization, adjust as needed. - }); + this.schema = googleSearchJsonSchema; } async _call(input) { - const validationResult = this.schema.safeParse(input); - if (!validationResult.success) { - throw new Error(`Validation failed: ${JSON.stringify(validationResult.error.issues)}`); - } - - const { query, max_results = 5 } = validationResult.data; + const { query, max_results = 5 } = input; const response = await fetch( `https://www.googleapis.com/customsearch/v1?key=${this.apiKey}&cx=${ diff --git a/api/app/clients/tools/structured/OpenWeather.js b/api/app/clients/tools/structured/OpenWeather.js index f92fe522ce..38e2b9133c 100644 --- a/api/app/clients/tools/structured/OpenWeather.js +++ b/api/app/clients/tools/structured/OpenWeather.js @@ -1,8 +1,52 @@ const { Tool } = require('@langchain/core/tools'); -const { z } = require('zod'); const { getEnvironmentVariable } = require('@langchain/core/utils/env'); const fetch = require('node-fetch'); +const openWeatherJsonSchema = { + type: 'object', + properties: { + action: { + type: 'string', + enum: ['help', 'current_forecast', 'timestamp', 'daily_aggregation', 'overview'], + description: 'The action to perform', + }, + city: { + type: 'string', + 
description: 'City name for geocoding if lat/lon not provided', + }, + lat: { + type: 'number', + description: 'Latitude coordinate', + }, + lon: { + type: 'number', + description: 'Longitude coordinate', + }, + exclude: { + type: 'string', + description: 'Parts to exclude from the response', + }, + units: { + type: 'string', + enum: ['Celsius', 'Kelvin', 'Fahrenheit'], + description: 'Temperature units', + }, + lang: { + type: 'string', + description: 'Language code', + }, + date: { + type: 'string', + description: 'Date in YYYY-MM-DD format for timestamp and daily_aggregation', + }, + tz: { + type: 'string', + description: 'Timezone', + }, + }, + required: ['action'], +}; + /** * Map user-friendly units to OpenWeather units. * Defaults to Celsius if not specified. @@ -66,17 +110,11 @@ class OpenWeather extends Tool { 'Units: "Celsius", "Kelvin", or "Fahrenheit" (default: Celsius). ' + 'For timestamp action, use "date" in YYYY-MM-DD format.'; - schema = z.object({ - action: z.enum(['help', 'current_forecast', 'timestamp', 'daily_aggregation', 'overview']), - city: z.string().optional(), - lat: z.number().optional(), - lon: z.number().optional(), - exclude: z.string().optional(), - units: z.enum(['Celsius', 'Kelvin', 'Fahrenheit']).optional(), - lang: z.string().optional(), - date: z.string().optional(), // For timestamp and daily_aggregation - tz: z.string().optional(), - }); + schema = openWeatherJsonSchema; + + static get jsonSchema() { + return openWeatherJsonSchema; + } constructor(fields = {}) { super(); diff --git a/api/app/clients/tools/structured/StableDiffusion.js b/api/app/clients/tools/structured/StableDiffusion.js index 3a1ea831d3..d7a7a4d96b 100644 --- a/api/app/clients/tools/structured/StableDiffusion.js +++ b/api/app/clients/tools/structured/StableDiffusion.js @@ -1,6 +1,5 @@ // Generates image using stable diffusion webui's api (automatic1111) const fs = require('fs'); -const { z } = require('zod'); const path = require('path'); const axios = 
require('axios'); const sharp = require('sharp'); @@ -11,6 +10,23 @@ const { FileContext, ContentTypes } = require('librechat-data-provider'); const { getBasePath } = require('@librechat/api'); const paths = require('~/config/paths'); +const stableDiffusionJsonSchema = { + type: 'object', + properties: { + prompt: { + type: 'string', + description: + 'Detailed keywords to describe the subject, using at least 7 keywords to accurately describe the image, separated by comma', + }, + negative_prompt: { + type: 'string', + description: + 'Keywords we want to exclude from the final image, using at least 7 keywords to accurately describe the image, separated by comma', + }, + }, + required: ['prompt', 'negative_prompt'], +}; + const displayMessage = "Stable Diffusion displayed an image. All generated images are already plainly visible, so don't repeat the descriptions in detail. Do not list download links as they are available in the UI already. The user may download the images by clicking on them, but do not mention anything about downloading to the user."; @@ -46,18 +62,11 @@ class StableDiffusionAPI extends Tool { // - Generate images only once per human query unless explicitly requested by the user`; this.description = "You can generate images using text with 'stable-diffusion'. 
This tool is exclusively for visual content."; - this.schema = z.object({ - prompt: z - .string() - .describe( - 'Detailed keywords to describe the subject, using at least 7 keywords to accurately describe the image, separated by comma', - ), - negative_prompt: z - .string() - .describe( - 'Keywords we want to exclude from the final image, using at least 7 keywords to accurately describe the image, separated by comma', - ), - }); + this.schema = stableDiffusionJsonSchema; + } + + static get jsonSchema() { + return stableDiffusionJsonSchema; } replaceNewLinesWithSpaces(inputString) { diff --git a/api/app/clients/tools/structured/TavilySearchResults.js b/api/app/clients/tools/structured/TavilySearchResults.js index 796f31dcca..0faddfb666 100644 --- a/api/app/clients/tools/structured/TavilySearchResults.js +++ b/api/app/clients/tools/structured/TavilySearchResults.js @@ -1,8 +1,75 @@ -const { z } = require('zod'); const { ProxyAgent, fetch } = require('undici'); const { Tool } = require('@langchain/core/tools'); const { getEnvironmentVariable } = require('@langchain/core/utils/env'); +const tavilySearchJsonSchema = { + type: 'object', + properties: { + query: { + type: 'string', + minLength: 1, + description: 'The search query string.', + }, + max_results: { + type: 'number', + minimum: 1, + maximum: 10, + description: 'The maximum number of search results to return. Defaults to 5.', + }, + search_depth: { + type: 'string', + enum: ['basic', 'advanced'], + description: + 'The depth of the search, affecting result quality and response time (`basic` or `advanced`). Default is basic for quick results and advanced for indepth high quality results but longer response time. Advanced calls equals 2 requests.', + }, + include_images: { + type: 'boolean', + description: + 'Whether to include a list of query-related images in the response. Default is False.', + }, + include_answer: { + type: 'boolean', + description: 'Whether to include answers in the search results. 
Default is False.', + }, + include_raw_content: { + type: 'boolean', + description: 'Whether to include raw content in the search results. Default is False.', + }, + include_domains: { + type: 'array', + items: { type: 'string' }, + description: 'A list of domains to specifically include in the search results.', + }, + exclude_domains: { + type: 'array', + items: { type: 'string' }, + description: 'A list of domains to specifically exclude from the search results.', + }, + topic: { + type: 'string', + enum: ['general', 'news', 'finance'], + description: + 'The category of the search. Use news ONLY if query SPECIFCALLY mentions the word "news".', + }, + time_range: { + type: 'string', + enum: ['day', 'week', 'month', 'year', 'd', 'w', 'm', 'y'], + description: 'The time range back from the current date to filter results.', + }, + days: { + type: 'number', + minimum: 1, + description: 'Number of days back from the current date to include. Only if topic is news.', + }, + include_image_descriptions: { + type: 'boolean', + description: + 'When include_images is true, also add a descriptive text for each image. Default is false.', + }, + }, + required: ['query'], +}; + class TavilySearchResults extends Tool { static lc_name() { return 'TavilySearchResults'; @@ -20,64 +87,11 @@ class TavilySearchResults extends Tool { this.description = 'A search engine optimized for comprehensive, accurate, and trusted results. Useful for when you need to answer questions about current events.'; - this.schema = z.object({ - query: z.string().min(1).describe('The search query string.'), - max_results: z - .number() - .min(1) - .max(10) - .optional() - .describe('The maximum number of search results to return. Defaults to 5.'), - search_depth: z - .enum(['basic', 'advanced']) - .optional() - .describe( - 'The depth of the search, affecting result quality and response time (`basic` or `advanced`). 
Default is basic for quick results and advanced for indepth high quality results but longer response time. Advanced calls equals 2 requests.', - ), - include_images: z - .boolean() - .optional() - .describe( - 'Whether to include a list of query-related images in the response. Default is False.', - ), - include_answer: z - .boolean() - .optional() - .describe('Whether to include answers in the search results. Default is False.'), - include_raw_content: z - .boolean() - .optional() - .describe('Whether to include raw content in the search results. Default is False.'), - include_domains: z - .array(z.string()) - .optional() - .describe('A list of domains to specifically include in the search results.'), - exclude_domains: z - .array(z.string()) - .optional() - .describe('A list of domains to specifically exclude from the search results.'), - topic: z - .enum(['general', 'news', 'finance']) - .optional() - .describe( - 'The category of the search. Use news ONLY if query SPECIFCALLY mentions the word "news".', - ), - time_range: z - .enum(['day', 'week', 'month', 'year', 'd', 'w', 'm', 'y']) - .optional() - .describe('The time range back from the current date to filter results.'), - days: z - .number() - .min(1) - .optional() - .describe('Number of days back from the current date to include. Only if topic is news.'), - include_image_descriptions: z - .boolean() - .optional() - .describe( - 'When include_images is true, also add a descriptive text for each image. 
Default is false.', - ), - }); + this.schema = tavilySearchJsonSchema; + } + + static get jsonSchema() { + return tavilySearchJsonSchema; } getApiKey() { @@ -89,12 +103,7 @@ class TavilySearchResults extends Tool { } async _call(input) { - const validationResult = this.schema.safeParse(input); - if (!validationResult.success) { - throw new Error(`Validation failed: ${JSON.stringify(validationResult.error.issues)}`); - } - - const { query, ...rest } = validationResult.data; + const { query, ...rest } = input; const requestBody = { api_key: this.apiKey, diff --git a/api/app/clients/tools/structured/TraversaalSearch.js b/api/app/clients/tools/structured/TraversaalSearch.js index d2ccc35c75..9bc5e399f0 100644 --- a/api/app/clients/tools/structured/TraversaalSearch.js +++ b/api/app/clients/tools/structured/TraversaalSearch.js @@ -1,8 +1,19 @@ -const { z } = require('zod'); const { Tool } = require('@langchain/core/tools'); const { logger } = require('@librechat/data-schemas'); const { getEnvironmentVariable } = require('@langchain/core/utils/env'); +const traversaalSearchJsonSchema = { + type: 'object', + properties: { + query: { + type: 'string', + description: + "A properly written sentence to be interpreted by an AI to search the web according to the user's request.", + }, + }, + required: ['query'], +}; + /** * Tool for the Traversaal AI search API, Ares. */ @@ -17,17 +28,15 @@ class TraversaalSearch extends Tool { Useful for when you need to answer questions about current events. Input should be a search query.`; this.description_for_model = '\'Please create a specific sentence for the AI to understand and use as a query to search the web based on the user\'s request. For example, "Find information about the highest mountains in the world." 
or "Show me the latest news articles about climate change and its impact on polar ice caps."\''; - this.schema = z.object({ - query: z - .string() - .describe( - "A properly written sentence to be interpreted by an AI to search the web according to the user's request.", - ), - }); + this.schema = traversaalSearchJsonSchema; this.apiKey = fields?.TRAVERSAAL_API_KEY ?? this.getApiKey(); } + static get jsonSchema() { + return traversaalSearchJsonSchema; + } + getApiKey() { const apiKey = getEnvironmentVariable('TRAVERSAAL_API_KEY'); if (!apiKey && this.override) { diff --git a/api/app/clients/tools/structured/Wolfram.js b/api/app/clients/tools/structured/Wolfram.js index 1f7fe6b1b7..196626e39c 100644 --- a/api/app/clients/tools/structured/Wolfram.js +++ b/api/app/clients/tools/structured/Wolfram.js @@ -1,9 +1,19 @@ /* eslint-disable no-useless-escape */ -const { z } = require('zod'); const axios = require('axios'); const { Tool } = require('@langchain/core/tools'); const { logger } = require('@librechat/data-schemas'); +const wolframJsonSchema = { + type: 'object', + properties: { + input: { + type: 'string', + description: 'Natural language query to WolframAlpha following the guidelines', + }, + }, + required: ['input'], +}; + class WolframAlphaAPI extends Tool { constructor(fields) { super(); @@ -41,9 +51,11 @@ class WolframAlphaAPI extends Tool { // -- Do not explain each step unless user input is needed. Proceed directly to making a better API call based on the available assumptions.`; this.description = `WolframAlpha offers computation, math, curated knowledge, and real-time data. It handles natural language queries and performs complex calculations. 
Follow the guidelines to get the best results.`; - this.schema = z.object({ - input: z.string().describe('Natural language query to WolframAlpha following the guidelines'), - }); + this.schema = wolframJsonSchema; + } + + static get jsonSchema() { + return wolframJsonSchema; } async fetchRawText(url) { diff --git a/api/app/clients/tools/util/fileSearch.js b/api/app/clients/tools/util/fileSearch.js index d48b9b986d..2654722be4 100644 --- a/api/app/clients/tools/util/fileSearch.js +++ b/api/app/clients/tools/util/fileSearch.js @@ -1,4 +1,3 @@ -const { z } = require('zod'); const axios = require('axios'); const { tool } = require('@langchain/core/tools'); const { logger } = require('@librechat/data-schemas'); @@ -7,6 +6,18 @@ const { Tools, EToolResources } = require('librechat-data-provider'); const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); const { getFiles } = require('~/models'); +const fileSearchJsonSchema = { + type: 'object', + properties: { + query: { + type: 'string', + description: + "A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. The query will be used for semantic similarity matching against the file contents.", + }, + }, + required: ['query'], +}; + /** * * @param {Object} options @@ -182,15 +193,9 @@ Use the EXACT anchor markers shown below (copy them verbatim) immediately after **ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**` : '' }`, - schema: z.object({ - query: z - .string() - .describe( - "A natural language query to search for relevant information in the files. Be specific and use keywords related to the information you're looking for. 
The query will be used for semantic similarity matching against the file contents.", - ), - }), + schema: fileSearchJsonSchema, }, ); }; -module.exports = { createFileSearchTool, primeFiles }; +module.exports = { createFileSearchTool, primeFiles, fileSearchJsonSchema }; diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index da4c687b4d..65c88ce83f 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -11,6 +11,7 @@ const { mcpToolPattern, loadWebSearchAuth, buildImageToolContext, + buildWebSearchContext, } = require('@librechat/api'); const { getMCPServersRegistry } = require('~/config'); const { @@ -19,7 +20,6 @@ const { Permissions, EToolResources, PermissionTypes, - replaceSpecialVars, } = require('librechat-data-provider'); const { availableTools, @@ -325,24 +325,7 @@ const loadTools = async ({ }); const { onSearchResults, onGetHighlights } = options?.[Tools.web_search] ?? {}; requestedTools[tool] = async () => { - toolContextMap[tool] = `# \`${tool}\`: -Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} - -**Execute immediately without preface.** After search, provide a brief summary addressing the query directly, then structure your response with clear Markdown formatting (## headers, lists, tables). Cite sources properly, tailor tone to query type, and provide comprehensive details. - -**CITATION FORMAT - UNICODE ESCAPE SEQUENCES ONLY:** -Use these EXACT escape sequences (copy verbatim): \\ue202 (before each anchor), \\ue200 (group start), \\ue201 (group end), \\ue203 (highlight start), \\ue204 (highlight end) - -Anchor pattern: \\ue202turn{N}{type}{index} where N=turn number, type=search|news|image|ref, index=0,1,2... - -**Examples (copy these exactly):** -- Single: "Statement.\\ue202turn0search0" -- Multiple: "Statement.\\ue202turn0search0\\ue202turn0news1" -- Group: "Statement. 
\\ue200\\ue202turn0search0\\ue202turn0news1\\ue201" -- Highlight: "\\ue203Cited text.\\ue204\\ue202turn0search0" -- Image: "See photo\\ue202turn0image0." - -**CRITICAL:** Output escape sequences EXACTLY as shown. Do NOT substitute with † or other symbols. Place anchors AFTER punctuation. Cite every non-obvious fact/quote. NEVER use markdown links, [1], footnotes, or HTML tags.`.trim(); + toolContextMap[tool] = buildWebSearchContext(); return createSearchTool({ ...result.authResult, onSearchResults, diff --git a/api/cache/banViolation.js b/api/cache/banViolation.js index 122355edb1..4d321889c1 100644 --- a/api/cache/banViolation.js +++ b/api/cache/banViolation.js @@ -55,6 +55,7 @@ const banViolation = async (req, res, errorMessage) => { res.clearCookie('refreshToken'); res.clearCookie('openid_access_token'); + res.clearCookie('openid_id_token'); res.clearCookie('openid_user_id'); res.clearCookie('token_provider'); diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 40aac08ee6..3089192196 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -37,6 +37,7 @@ const namespaces = { [CacheKeys.ROLES]: standardCache(CacheKeys.ROLES), [CacheKeys.APP_CONFIG]: standardCache(CacheKeys.APP_CONFIG), [CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE), + [CacheKeys.TOOL_CACHE]: standardCache(CacheKeys.TOOL_CACHE), [CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ), [CacheKeys.ENCODED_DOMAINS]: new Keyv({ store: keyvMongo, namespace: CacheKeys.ENCODED_DOMAINS }), [CacheKeys.ABORT_KEYS]: standardCache(CacheKeys.ABORT_KEYS, Time.TEN_MINUTES), @@ -51,6 +52,10 @@ const namespaces = { CacheKeys.OPENID_EXCHANGED_TOKENS, Time.TEN_MINUTES, ), + [CacheKeys.ADMIN_OAUTH_EXCHANGE]: standardCache( + CacheKeys.ADMIN_OAUTH_EXCHANGE, + Time.THIRTY_SECONDS, + ), }; /** diff --git a/api/db/connect.js b/api/db/connect.js index 26166ccff8..3534884b57 100644 --- a/api/db/connect.js +++ b/api/db/connect.js @@ -40,6 +40,10 @@ if (!cached) { 
cached = global.mongoose = { conn: null, promise: null }; } +mongoose.connection.on('error', (err) => { + logger.error('[connectDb] MongoDB connection error:', err); +}); + async function connectDb() { if (cached.conn && cached.conn?._readyState === 1) { return cached.conn; diff --git a/api/db/indexSync.js b/api/db/indexSync.js index b39f018b3a..8e8e999d92 100644 --- a/api/db/indexSync.js +++ b/api/db/indexSync.js @@ -13,6 +13,11 @@ const searchEnabled = isEnabled(process.env.SEARCH); const indexingDisabled = isEnabled(process.env.MEILI_NO_SYNC); let currentTimeout = null; +const defaultSyncThreshold = 1000; +const syncThreshold = process.env.MEILI_SYNC_THRESHOLD + ? parseInt(process.env.MEILI_SYNC_THRESHOLD, 10) + : defaultSyncThreshold; + class MeiliSearchClient { static instance = null; @@ -221,25 +226,25 @@ async function performSync(flowManager, flowId, flowType) { } // Check if we need to sync messages + logger.info('[indexSync] Requesting message sync progress...'); const messageProgress = await Message.getSyncProgress(); if (!messageProgress.isComplete || settingsUpdated) { logger.info( `[indexSync] Messages need syncing: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`, ); - // Check if we should do a full sync or incremental - const messageCount = await Message.countDocuments(); + const messageCount = messageProgress.totalDocuments; const messagesIndexed = messageProgress.totalProcessed; - const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10); + const unindexedMessages = messageCount - messagesIndexed; - if (messageCount - messagesIndexed > syncThreshold) { - logger.info('[indexSync] Starting full message sync due to large difference'); - await Message.syncWithMeili(); - messagesSync = true; - } else if (messageCount !== messagesIndexed) { - logger.warn('[indexSync] Messages out of sync, performing incremental sync'); + if (settingsUpdated || unindexedMessages > syncThreshold) { + 
logger.info(`[indexSync] Starting message sync (${unindexedMessages} unindexed)`); await Message.syncWithMeili(); messagesSync = true; + } else if (unindexedMessages > 0) { + logger.info( + `[indexSync] ${unindexedMessages} messages unindexed (below threshold: ${syncThreshold}, skipping)`, + ); } } else { logger.info( @@ -254,18 +259,18 @@ async function performSync(flowManager, flowId, flowType) { `[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`, ); - const convoCount = await Conversation.countDocuments(); + const convoCount = convoProgress.totalDocuments; const convosIndexed = convoProgress.totalProcessed; - const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10); - if (convoCount - convosIndexed > syncThreshold) { - logger.info('[indexSync] Starting full conversation sync due to large difference'); - await Conversation.syncWithMeili(); - convosSync = true; - } else if (convoCount !== convosIndexed) { - logger.warn('[indexSync] Convos out of sync, performing incremental sync'); + const unindexedConvos = convoCount - convosIndexed; + if (settingsUpdated || unindexedConvos > syncThreshold) { + logger.info(`[indexSync] Starting convos sync (${unindexedConvos} unindexed)`); await Conversation.syncWithMeili(); convosSync = true; + } else if (unindexedConvos > 0) { + logger.info( + `[indexSync] ${unindexedConvos} convos unindexed (below threshold: ${syncThreshold}, skipping)`, + ); } } else { logger.info( diff --git a/api/db/indexSync.spec.js b/api/db/indexSync.spec.js new file mode 100644 index 0000000000..c2e5901d6a --- /dev/null +++ b/api/db/indexSync.spec.js @@ -0,0 +1,465 @@ +/** + * Unit tests for performSync() function in indexSync.js + * + * Tests use real mongoose with mocked model methods, only mocking external calls. 
+ */ + +const mongoose = require('mongoose'); + +// Mock only external dependencies (not internal classes/models) +const mockLogger = { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), +}; + +const mockMeiliHealth = jest.fn(); +const mockMeiliIndex = jest.fn(); +const mockBatchResetMeiliFlags = jest.fn(); +const mockIsEnabled = jest.fn(); +const mockGetLogStores = jest.fn(); + +// Create mock models that will be reused +const createMockModel = (collectionName) => ({ + collection: { name: collectionName }, + getSyncProgress: jest.fn(), + syncWithMeili: jest.fn(), + countDocuments: jest.fn(), +}); + +const originalMessageModel = mongoose.models.Message; +const originalConversationModel = mongoose.models.Conversation; + +// Mock external modules +jest.mock('@librechat/data-schemas', () => ({ + logger: mockLogger, +})); + +jest.mock('meilisearch', () => ({ + MeiliSearch: jest.fn(() => ({ + health: mockMeiliHealth, + index: mockMeiliIndex, + })), +})); + +jest.mock('./utils', () => ({ + batchResetMeiliFlags: mockBatchResetMeiliFlags, +})); + +jest.mock('@librechat/api', () => ({ + isEnabled: mockIsEnabled, + FlowStateManager: jest.fn(), +})); + +jest.mock('~/cache', () => ({ + getLogStores: mockGetLogStores, +})); + +// Set environment before module load +process.env.MEILI_HOST = 'http://localhost:7700'; +process.env.MEILI_MASTER_KEY = 'test-key'; +process.env.SEARCH = 'true'; +process.env.MEILI_SYNC_THRESHOLD = '1000'; // Set threshold before module loads + +describe('performSync() - syncThreshold logic', () => { + const ORIGINAL_ENV = process.env; + let Message; + let Conversation; + + beforeAll(() => { + Message = createMockModel('messages'); + Conversation = createMockModel('conversations'); + + mongoose.models.Message = Message; + mongoose.models.Conversation = Conversation; + }); + + beforeEach(() => { + // Reset all mocks + jest.clearAllMocks(); + // Reset modules to ensure fresh load of indexSync.js and its top-level consts (like 
syncThreshold) + jest.resetModules(); + + // Set up environment + process.env = { ...ORIGINAL_ENV }; + process.env.MEILI_HOST = 'http://localhost:7700'; + process.env.MEILI_MASTER_KEY = 'test-key'; + process.env.SEARCH = 'true'; + delete process.env.MEILI_NO_SYNC; + + // Re-ensure models are available in mongoose after resetModules + // We must require mongoose again to get the fresh instance that indexSync will use + const mongoose = require('mongoose'); + mongoose.models.Message = Message; + mongoose.models.Conversation = Conversation; + + // Mock isEnabled + mockIsEnabled.mockImplementation((val) => val === 'true' || val === true); + + // Mock MeiliSearch client responses + mockMeiliHealth.mockResolvedValue({ status: 'available' }); + mockMeiliIndex.mockReturnValue({ + getSettings: jest.fn().mockResolvedValue({ filterableAttributes: ['user'] }), + updateSettings: jest.fn().mockResolvedValue({}), + search: jest.fn().mockResolvedValue({ hits: [] }), + }); + + mockBatchResetMeiliFlags.mockResolvedValue(undefined); + }); + + afterEach(() => { + process.env = ORIGINAL_ENV; + }); + + afterAll(() => { + mongoose.models.Message = originalMessageModel; + mongoose.models.Conversation = originalConversationModel; + }); + + test('triggers sync when unindexed messages exceed syncThreshold', async () => { + // Arrange: Set threshold before module load + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Arrange: 1050 unindexed messages > 1000 threshold + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 1150, // 1050 unindexed + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 50, + totalDocuments: 50, + isComplete: true, + }); + + Message.syncWithMeili.mockResolvedValue(undefined); + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: No countDocuments calls + expect(Message.countDocuments).not.toHaveBeenCalled(); + 
expect(Conversation.countDocuments).not.toHaveBeenCalled(); + + // Assert: Message sync triggered because 1050 > 1000 + expect(Message.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Messages need syncing: 100/1150 indexed', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Starting message sync (1050 unindexed)', + ); + + // Assert: Conversation sync NOT triggered (already complete) + expect(Conversation.syncWithMeili).not.toHaveBeenCalled(); + }); + + test('skips sync when unindexed messages are below syncThreshold', async () => { + // Arrange: 50 unindexed messages < 1000 threshold + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 150, // 50 unindexed + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 50, + totalDocuments: 50, + isComplete: true, + }); + + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: No countDocuments calls + expect(Message.countDocuments).not.toHaveBeenCalled(); + expect(Conversation.countDocuments).not.toHaveBeenCalled(); + + // Assert: Message sync NOT triggered because 50 < 1000 + expect(Message.syncWithMeili).not.toHaveBeenCalled(); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Messages need syncing: 100/150 indexed', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] 50 messages unindexed (below threshold: 1000, skipping)', + ); + + // Assert: Conversation sync NOT triggered (already complete) + expect(Conversation.syncWithMeili).not.toHaveBeenCalled(); + }); + + test('respects syncThreshold at boundary (exactly at threshold)', async () => { + // Arrange: 1000 unindexed messages = 1000 threshold (NOT greater than) + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 1100, // 1000 unindexed + isComplete: false, + }); + + 
Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 0, + totalDocuments: 0, + isComplete: true, + }); + + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: No countDocuments calls + expect(Message.countDocuments).not.toHaveBeenCalled(); + + // Assert: Message sync NOT triggered because 1000 is NOT > 1000 + expect(Message.syncWithMeili).not.toHaveBeenCalled(); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Messages need syncing: 100/1100 indexed', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] 1000 messages unindexed (below threshold: 1000, skipping)', + ); + }); + + test('triggers sync when unindexed is threshold + 1', async () => { + // Arrange: 1001 unindexed messages > 1000 threshold + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 1101, // 1001 unindexed + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 0, + totalDocuments: 0, + isComplete: true, + }); + + Message.syncWithMeili.mockResolvedValue(undefined); + + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: No countDocuments calls + expect(Message.countDocuments).not.toHaveBeenCalled(); + + // Assert: Message sync triggered because 1001 > 1000 + expect(Message.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Messages need syncing: 100/1101 indexed', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Starting message sync (1001 unindexed)', + ); + }); + + test('uses totalDocuments from convoProgress for conversation sync decisions', async () => { + // Arrange: Messages complete, conversations need sync + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 100, + isComplete: true, + }); + + 
Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 50, + totalDocuments: 1100, // 1050 unindexed > 1000 threshold + isComplete: false, + }); + + Conversation.syncWithMeili.mockResolvedValue(undefined); + + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: No countDocuments calls (the optimization) + expect(Message.countDocuments).not.toHaveBeenCalled(); + expect(Conversation.countDocuments).not.toHaveBeenCalled(); + + // Assert: Only conversation sync triggered + expect(Message.syncWithMeili).not.toHaveBeenCalled(); + expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Conversations need syncing: 50/1100 indexed', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Starting convos sync (1050 unindexed)', + ); + }); + + test('skips sync when collections are fully synced', async () => { + // Arrange: Everything already synced + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 100, + isComplete: true, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 50, + totalDocuments: 50, + isComplete: true, + }); + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: No countDocuments calls + expect(Message.countDocuments).not.toHaveBeenCalled(); + expect(Conversation.countDocuments).not.toHaveBeenCalled(); + + // Assert: No sync triggered + expect(Message.syncWithMeili).not.toHaveBeenCalled(); + expect(Conversation.syncWithMeili).not.toHaveBeenCalled(); + + // Assert: Correct logs + expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Messages are fully synced: 100/100'); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Conversations are fully synced: 50/50', + ); + }); + + test('triggers message sync when settingsUpdated even if below syncThreshold', async () => { + // Arrange: 
Only 50 unindexed messages (< 1000 threshold), but settings were updated + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 150, // 50 unindexed + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 50, + totalDocuments: 50, + isComplete: true, + }); + + Message.syncWithMeili.mockResolvedValue(undefined); + + // Mock settings update scenario + mockMeiliIndex.mockReturnValue({ + getSettings: jest.fn().mockResolvedValue({ filterableAttributes: [] }), // No user field + updateSettings: jest.fn().mockResolvedValue({}), + search: jest.fn().mockResolvedValue({ hits: [] }), + }); + + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: Flags were reset due to settings update + expect(mockBatchResetMeiliFlags).toHaveBeenCalledWith(Message.collection); + expect(mockBatchResetMeiliFlags).toHaveBeenCalledWith(Conversation.collection); + + // Assert: Message sync triggered despite being below threshold (50 < 1000) + expect(Message.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Settings updated. 
Forcing full re-sync to reindex with new configuration...', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Starting message sync (50 unindexed)', + ); + }); + + test('triggers conversation sync when settingsUpdated even if below syncThreshold', async () => { + // Arrange: Messages complete, conversations have 50 unindexed (< 1000 threshold), but settings were updated + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 100, + isComplete: true, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 50, + totalDocuments: 100, // 50 unindexed + isComplete: false, + }); + + Conversation.syncWithMeili.mockResolvedValue(undefined); + + // Mock settings update scenario + mockMeiliIndex.mockReturnValue({ + getSettings: jest.fn().mockResolvedValue({ filterableAttributes: [] }), // No user field + updateSettings: jest.fn().mockResolvedValue({}), + search: jest.fn().mockResolvedValue({ hits: [] }), + }); + + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: Flags were reset due to settings update + expect(mockBatchResetMeiliFlags).toHaveBeenCalledWith(Message.collection); + expect(mockBatchResetMeiliFlags).toHaveBeenCalledWith(Conversation.collection); + + // Assert: Conversation sync triggered despite being below threshold (50 < 1000) + expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Settings updated. 
Forcing full re-sync to reindex with new configuration...', + ); + expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)'); + }); + + test('triggers both message and conversation sync when settingsUpdated even if both below syncThreshold', async () => { + // Arrange: Set threshold before module load + process.env.MEILI_SYNC_THRESHOLD = '1000'; + + // Arrange: Both have documents below threshold (50 each), but settings were updated + Message.getSyncProgress.mockResolvedValue({ + totalProcessed: 100, + totalDocuments: 150, // 50 unindexed + isComplete: false, + }); + + Conversation.getSyncProgress.mockResolvedValue({ + totalProcessed: 50, + totalDocuments: 100, // 50 unindexed + isComplete: false, + }); + + Message.syncWithMeili.mockResolvedValue(undefined); + Conversation.syncWithMeili.mockResolvedValue(undefined); + + // Mock settings update scenario + mockMeiliIndex.mockReturnValue({ + getSettings: jest.fn().mockResolvedValue({ filterableAttributes: [] }), // No user field + updateSettings: jest.fn().mockResolvedValue({}), + search: jest.fn().mockResolvedValue({ hits: [] }), + }); + + // Act + const indexSync = require('./indexSync'); + await indexSync(); + + // Assert: Flags were reset due to settings update + expect(mockBatchResetMeiliFlags).toHaveBeenCalledWith(Message.collection); + expect(mockBatchResetMeiliFlags).toHaveBeenCalledWith(Conversation.collection); + + // Assert: Both syncs triggered despite both being below threshold + expect(Message.syncWithMeili).toHaveBeenCalledTimes(1); + expect(Conversation.syncWithMeili).toHaveBeenCalledTimes(1); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Settings updated. 
Forcing full re-sync to reindex with new configuration...', + ); + expect(mockLogger.info).toHaveBeenCalledWith( + '[indexSync] Starting message sync (50 unindexed)', + ); + expect(mockLogger.info).toHaveBeenCalledWith('[indexSync] Starting convos sync (50 unindexed)'); + }); +}); diff --git a/api/db/utils.js b/api/db/utils.js index 4a311d9832..32051be78d 100644 --- a/api/db/utils.js +++ b/api/db/utils.js @@ -26,7 +26,7 @@ async function batchResetMeiliFlags(collection) { try { while (hasMore) { const docs = await collection - .find({ expiredAt: null, _meiliIndex: true }, { projection: { _id: 1 } }) + .find({ expiredAt: null, _meiliIndex: { $ne: false } }, { projection: { _id: 1 } }) .limit(BATCH_SIZE) .toArray(); diff --git a/api/db/utils.spec.js b/api/db/utils.spec.js index 8b32b4aea8..adf4f6cd86 100644 --- a/api/db/utils.spec.js +++ b/api/db/utils.spec.js @@ -265,8 +265,8 @@ describe('batchResetMeiliFlags', () => { const result = await batchResetMeiliFlags(testCollection); - // Only one document has _meiliIndex: true - expect(result).toBe(1); + // both documents should be updated + expect(result).toBe(2); }); it('should handle mixed document states correctly', async () => { @@ -275,16 +275,18 @@ describe('batchResetMeiliFlags', () => { { _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: false }, { _id: new mongoose.Types.ObjectId(), expiredAt: new Date(), _meiliIndex: true }, { _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: true }, + { _id: new mongoose.Types.ObjectId(), expiredAt: null, _meiliIndex: null }, + { _id: new mongoose.Types.ObjectId(), expiredAt: null }, ]); const result = await batchResetMeiliFlags(testCollection); - expect(result).toBe(2); + expect(result).toBe(4); const flaggedDocs = await testCollection .find({ expiredAt: null, _meiliIndex: false }) .toArray(); - expect(flaggedDocs).toHaveLength(3); // 2 were updated, 1 was already false + expect(flaggedDocs).toHaveLength(5); // 4 were updated, 1 was already 
false }); }); diff --git a/api/models/Agent.js b/api/models/Agent.js index 3df2bcbec2..663285183a 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -11,17 +11,15 @@ const { isEphemeralAgentId, encodeEphemeralAgentId, } = require('librechat-data-provider'); -const { GLOBAL_PROJECT_NAME, mcp_all, mcp_delimiter } = - require('librechat-data-provider').Constants; +const { mcp_all, mcp_delimiter } = require('librechat-data-provider').Constants; const { removeAgentFromAllProjects, removeAgentIdsFromProject, addAgentIdsToProject, - getProjectByName, } = require('./Project'); const { removeAllPermissions } = require('~/server/services/PermissionService'); const { getMCPServerTools } = require('~/server/services/Config'); -const { Agent, AclEntry } = require('~/db/models'); +const { Agent, AclEntry, User } = require('~/db/models'); const { getActions } = require('./Action'); /** @@ -591,15 +589,29 @@ const deleteAgent = async (searchParameter) => { const agent = await Agent.findOneAndDelete(searchParameter); if (agent) { await removeAgentFromAllProjects(agent.id); - await removeAllPermissions({ - resourceType: ResourceType.AGENT, - resourceId: agent._id, - }); + await Promise.all([ + removeAllPermissions({ + resourceType: ResourceType.AGENT, + resourceId: agent._id, + }), + removeAllPermissions({ + resourceType: ResourceType.REMOTE_AGENT, + resourceId: agent._id, + }), + ]); try { await Agent.updateMany({ 'edges.to': agent.id }, { $pull: { edges: { to: agent.id } } }); } catch (error) { logger.error('[deleteAgent] Error removing agent from handoff edges', error); } + try { + await User.updateMany( + { 'favorites.agentId': agent.id }, + { $pull: { favorites: { agentId: agent.id } } }, + ); + } catch (error) { + logger.error('[deleteAgent] Error removing agent from user favorites', error); + } } return agent; }; @@ -625,10 +637,19 @@ const deleteUserAgents = async (userId) => { } await AclEntry.deleteMany({ - resourceType: ResourceType.AGENT, + resourceType: { 
$in: [ResourceType.AGENT, ResourceType.REMOTE_AGENT] }, resourceId: { $in: agentObjectIds }, }); + try { + await User.updateMany( + { 'favorites.agentId': { $in: agentIds } }, + { $pull: { favorites: { agentId: { $in: agentIds } } } }, + ); + } catch (error) { + logger.error('[deleteUserAgents] Error removing agents from user favorites', error); + } + await Agent.deleteMany({ author: userId }); } catch (error) { logger.error('[deleteUserAgents] General error:', error); @@ -735,59 +756,6 @@ const getListAgentsByAccess = async ({ }; }; -/** - * Get all agents. - * @deprecated Use getListAgentsByAccess for ACL-aware agent listing - * @param {Object} searchParameter - The search parameters to find matching agents. - * @param {string} searchParameter.author - The user ID of the agent's author. - * @returns {Promise} A promise that resolves to an object containing the agents data and pagination info. - */ -const getListAgents = async (searchParameter) => { - const { author, ...otherParams } = searchParameter; - - let query = Object.assign({ author }, otherParams); - - const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, ['agentIds']); - if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) { - const globalQuery = { id: { $in: globalProject.agentIds }, ...otherParams }; - delete globalQuery.author; - query = { $or: [globalQuery, query] }; - } - const agents = ( - await Agent.find(query, { - id: 1, - _id: 1, - name: 1, - avatar: 1, - author: 1, - projectIds: 1, - description: 1, - // @deprecated - isCollaborative replaced by ACL permissions - isCollaborative: 1, - category: 1, - }).lean() - ).map((agent) => { - if (agent.author?.toString() !== author) { - delete agent.author; - } - if (agent.author) { - agent.author = agent.author.toString(); - } - return agent; - }); - - const hasMore = agents.length > 0; - const firstId = agents.length > 0 ? agents[0].id : null; - const lastId = agents.length > 0 ? 
agents[agents.length - 1].id : null; - - return { - data: agents, - has_more: hasMore, - first_id: firstId, - last_id: lastId, - }; -}; - /** * Updates the projects associated with an agent, adding and removing project IDs as specified. * This function also updates the corresponding projects to include or exclude the agent ID. @@ -953,12 +921,11 @@ module.exports = { updateAgent, deleteAgent, deleteUserAgents, - getListAgents, revertAgentVersion, updateAgentProjects, + countPromotedAgents, addAgentResourceFile, getListAgentsByAccess, removeAgentResourceFiles, generateActionMetadataHash, - countPromotedAgents, }; diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js index 51256f8cf1..baceb3e8f3 100644 --- a/api/models/Agent.spec.js +++ b/api/models/Agent.spec.js @@ -22,17 +22,17 @@ const { createAgent, updateAgent, deleteAgent, - getListAgents, - getListAgentsByAccess, + deleteUserAgents, revertAgentVersion, updateAgentProjects, addAgentResourceFile, + getListAgentsByAccess, removeAgentResourceFiles, generateActionMetadataHash, } = require('./Agent'); const permissionService = require('~/server/services/PermissionService'); const { getCachedTools, getMCPServerTools } = require('~/server/services/Config'); -const { AclEntry } = require('~/db/models'); +const { AclEntry, User } = require('~/db/models'); /** * @type {import('mongoose').Model} @@ -59,6 +59,7 @@ describe('models/Agent', () => { beforeEach(async () => { await Agent.deleteMany({}); + await User.deleteMany({}); }); test('should add tool_resource to tools if missing', async () => { @@ -575,43 +576,488 @@ describe('models/Agent', () => { expect(sourceAgentAfter.edges).toHaveLength(0); }); - test('should list agents by author', async () => { + test('should remove agent from user favorites when agent is deleted', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + + // Create agent + await 
createAgent({ + id: agentId, + name: 'Agent To Delete', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Create user with the agent in favorites + await User.create({ + _id: userId, + name: 'Test User', + email: `test-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: agentId }, { model: 'gpt-4', endpoint: 'openAI' }], + }); + + // Verify user has agent in favorites + const userBefore = await User.findById(userId); + expect(userBefore.favorites).toHaveLength(2); + expect(userBefore.favorites.some((f) => f.agentId === agentId)).toBe(true); + + // Delete the agent + await deleteAgent({ id: agentId }); + + // Verify agent is deleted + const agentAfterDelete = await getAgent({ id: agentId }); + expect(agentAfterDelete).toBeNull(); + + // Verify agent is removed from user favorites + const userAfter = await User.findById(userId); + expect(userAfter.favorites).toHaveLength(1); + expect(userAfter.favorites.some((f) => f.agentId === agentId)).toBe(false); + expect(userAfter.favorites.some((f) => f.model === 'gpt-4')).toBe(true); + }); + + test('should remove agent from multiple users favorites when agent is deleted', async () => { + const agentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const user1Id = new mongoose.Types.ObjectId(); + const user2Id = new mongoose.Types.ObjectId(); + + // Create agent + await createAgent({ + id: agentId, + name: 'Agent To Delete', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Create two users with the agent in favorites + await User.create({ + _id: user1Id, + name: 'Test User 1', + email: `test1-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: agentId }], + }); + + await User.create({ + _id: user2Id, + name: 'Test User 2', + email: `test2-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: agentId }, { agentId: `agent_${uuidv4()}` }], + }); + + // Delete the agent + await deleteAgent({ 
id: agentId }); + + // Verify agent is removed from both users' favorites + const user1After = await User.findById(user1Id); + const user2After = await User.findById(user2Id); + + expect(user1After.favorites).toHaveLength(0); + expect(user2After.favorites).toHaveLength(1); + expect(user2After.favorites.some((f) => f.agentId === agentId)).toBe(false); + }); + + test('should preserve other agents in database when one agent is deleted', async () => { + const agentToDeleteId = `agent_${uuidv4()}`; + const agentToKeep1Id = `agent_${uuidv4()}`; + const agentToKeep2Id = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + + // Create multiple agents + await createAgent({ + id: agentToDeleteId, + name: 'Agent To Delete', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agentToKeep1Id, + name: 'Agent To Keep 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agentToKeep2Id, + name: 'Agent To Keep 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Verify all agents exist + expect(await getAgent({ id: agentToDeleteId })).not.toBeNull(); + expect(await getAgent({ id: agentToKeep1Id })).not.toBeNull(); + expect(await getAgent({ id: agentToKeep2Id })).not.toBeNull(); + + // Delete one agent + await deleteAgent({ id: agentToDeleteId }); + + // Verify only the deleted agent is removed, others remain intact + expect(await getAgent({ id: agentToDeleteId })).toBeNull(); + const keptAgent1 = await getAgent({ id: agentToKeep1Id }); + const keptAgent2 = await getAgent({ id: agentToKeep2Id }); + expect(keptAgent1).not.toBeNull(); + expect(keptAgent1.name).toBe('Agent To Keep 1'); + expect(keptAgent2).not.toBeNull(); + expect(keptAgent2.name).toBe('Agent To Keep 2'); + }); + + test('should preserve other agents in user favorites when one agent is deleted', async () => { + const agentToDeleteId = `agent_${uuidv4()}`; + const agentToKeep1Id = 
`agent_${uuidv4()}`; + const agentToKeep2Id = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + + // Create multiple agents + await createAgent({ + id: agentToDeleteId, + name: 'Agent To Delete', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agentToKeep1Id, + name: 'Agent To Keep 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agentToKeep2Id, + name: 'Agent To Keep 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Create user with all three agents in favorites + await User.create({ + _id: userId, + name: 'Test User', + email: `test-${uuidv4()}@example.com`, + provider: 'local', + favorites: [ + { agentId: agentToDeleteId }, + { agentId: agentToKeep1Id }, + { agentId: agentToKeep2Id }, + ], + }); + + // Verify user has all three agents in favorites + const userBefore = await User.findById(userId); + expect(userBefore.favorites).toHaveLength(3); + + // Delete one agent + await deleteAgent({ id: agentToDeleteId }); + + // Verify only the deleted agent is removed from favorites + const userAfter = await User.findById(userId); + expect(userAfter.favorites).toHaveLength(2); + expect(userAfter.favorites.some((f) => f.agentId === agentToDeleteId)).toBe(false); + expect(userAfter.favorites.some((f) => f.agentId === agentToKeep1Id)).toBe(true); + expect(userAfter.favorites.some((f) => f.agentId === agentToKeep2Id)).toBe(true); + }); + + test('should not affect users who do not have deleted agent in favorites', async () => { + const agentToDeleteId = `agent_${uuidv4()}`; + const otherAgentId = `agent_${uuidv4()}`; + const authorId = new mongoose.Types.ObjectId(); + const userWithDeletedAgentId = new mongoose.Types.ObjectId(); + const userWithoutDeletedAgentId = new mongoose.Types.ObjectId(); + + // Create agents + await createAgent({ + id: agentToDeleteId, + name: 'Agent 
To Delete', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: otherAgentId, + name: 'Other Agent', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Create user with the agent to be deleted + await User.create({ + _id: userWithDeletedAgentId, + name: 'User With Deleted Agent', + email: `user1-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: agentToDeleteId }, { model: 'gpt-4', endpoint: 'openAI' }], + }); + + // Create user without the agent to be deleted + await User.create({ + _id: userWithoutDeletedAgentId, + name: 'User Without Deleted Agent', + email: `user2-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: otherAgentId }, { model: 'claude-3', endpoint: 'anthropic' }], + }); + + // Delete the agent + await deleteAgent({ id: agentToDeleteId }); + + // Verify user with deleted agent has it removed + const userWithDeleted = await User.findById(userWithDeletedAgentId); + expect(userWithDeleted.favorites).toHaveLength(1); + expect(userWithDeleted.favorites.some((f) => f.agentId === agentToDeleteId)).toBe(false); + expect(userWithDeleted.favorites.some((f) => f.model === 'gpt-4')).toBe(true); + + // Verify user without deleted agent is completely unaffected + const userWithoutDeleted = await User.findById(userWithoutDeletedAgentId); + expect(userWithoutDeleted.favorites).toHaveLength(2); + expect(userWithoutDeleted.favorites.some((f) => f.agentId === otherAgentId)).toBe(true); + expect(userWithoutDeleted.favorites.some((f) => f.model === 'claude-3')).toBe(true); + }); + + test('should remove all user agents from favorites when deleteUserAgents is called', async () => { const authorId = new mongoose.Types.ObjectId(); const otherAuthorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); - const agentIds = []; - for (let i = 0; i < 5; i++) { - const id = `agent_${uuidv4()}`; - agentIds.push(id); - await createAgent({ - 
id, - name: `Agent ${i}`, - provider: 'test', - model: 'test-model', - author: authorId, - }); - } + const agent1Id = `agent_${uuidv4()}`; + const agent2Id = `agent_${uuidv4()}`; + const otherAuthorAgentId = `agent_${uuidv4()}`; - for (let i = 0; i < 3; i++) { - await createAgent({ - id: `other_agent_${uuidv4()}`, - name: `Other Agent ${i}`, - provider: 'test', - model: 'test-model', - author: otherAuthorId, - }); - } + // Create agents by the author to be deleted + await createAgent({ + id: agent1Id, + name: 'Author Agent 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); - const result = await getListAgents({ author: authorId.toString() }); + await createAgent({ + id: agent2Id, + name: 'Author Agent 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); - expect(result).toBeDefined(); - expect(result.data).toBeDefined(); - expect(result.data).toHaveLength(5); - expect(result.has_more).toBe(true); + // Create agent by different author (should not be deleted) + await createAgent({ + id: otherAuthorAgentId, + name: 'Other Author Agent', + provider: 'test', + model: 'test-model', + author: otherAuthorId, + }); - for (const agent of result.data) { - expect(agent.author).toBe(authorId.toString()); - } + // Create user with all agents in favorites + await User.create({ + _id: userId, + name: 'Test User', + email: `test-${uuidv4()}@example.com`, + provider: 'local', + favorites: [ + { agentId: agent1Id }, + { agentId: agent2Id }, + { agentId: otherAuthorAgentId }, + { model: 'gpt-4', endpoint: 'openAI' }, + ], + }); + + // Verify user has all favorites + const userBefore = await User.findById(userId); + expect(userBefore.favorites).toHaveLength(4); + + // Delete all agents by the author + await deleteUserAgents(authorId.toString()); + + // Verify author's agents are deleted from database + expect(await getAgent({ id: agent1Id })).toBeNull(); + expect(await getAgent({ id: agent2Id })).toBeNull(); + + // Verify other author's agent 
still exists + expect(await getAgent({ id: otherAuthorAgentId })).not.toBeNull(); + + // Verify user favorites: author's agents removed, others remain + const userAfter = await User.findById(userId); + expect(userAfter.favorites).toHaveLength(2); + expect(userAfter.favorites.some((f) => f.agentId === agent1Id)).toBe(false); + expect(userAfter.favorites.some((f) => f.agentId === agent2Id)).toBe(false); + expect(userAfter.favorites.some((f) => f.agentId === otherAuthorAgentId)).toBe(true); + expect(userAfter.favorites.some((f) => f.model === 'gpt-4')).toBe(true); + }); + + test('should handle deleteUserAgents when agents are in multiple users favorites', async () => { + const authorId = new mongoose.Types.ObjectId(); + const user1Id = new mongoose.Types.ObjectId(); + const user2Id = new mongoose.Types.ObjectId(); + const user3Id = new mongoose.Types.ObjectId(); + + const agent1Id = `agent_${uuidv4()}`; + const agent2Id = `agent_${uuidv4()}`; + const unrelatedAgentId = `agent_${uuidv4()}`; + + // Create agents by the author + await createAgent({ + id: agent1Id, + name: 'Author Agent 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agent2Id, + name: 'Author Agent 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Create users with various favorites configurations + await User.create({ + _id: user1Id, + name: 'User 1', + email: `user1-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: agent1Id }, { agentId: agent2Id }], + }); + + await User.create({ + _id: user2Id, + name: 'User 2', + email: `user2-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: agent1Id }, { model: 'claude-3', endpoint: 'anthropic' }], + }); + + await User.create({ + _id: user3Id, + name: 'User 3', + email: `user3-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: unrelatedAgentId }, { model: 'gpt-4', endpoint: 'openAI' }], + }); + + // Delete all agents 
by the author + await deleteUserAgents(authorId.toString()); + + // Verify all users' favorites are correctly updated + const user1After = await User.findById(user1Id); + expect(user1After.favorites).toHaveLength(0); + + const user2After = await User.findById(user2Id); + expect(user2After.favorites).toHaveLength(1); + expect(user2After.favorites.some((f) => f.agentId === agent1Id)).toBe(false); + expect(user2After.favorites.some((f) => f.model === 'claude-3')).toBe(true); + + // User 3 should be completely unaffected + const user3After = await User.findById(user3Id); + expect(user3After.favorites).toHaveLength(2); + expect(user3After.favorites.some((f) => f.agentId === unrelatedAgentId)).toBe(true); + expect(user3After.favorites.some((f) => f.model === 'gpt-4')).toBe(true); + }); + + test('should handle deleteUserAgents when user has no agents', async () => { + const authorWithNoAgentsId = new mongoose.Types.ObjectId(); + const otherAuthorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + + const existingAgentId = `agent_${uuidv4()}`; + + // Create agent by different author + await createAgent({ + id: existingAgentId, + name: 'Existing Agent', + provider: 'test', + model: 'test-model', + author: otherAuthorId, + }); + + // Create user with favorites + await User.create({ + _id: userId, + name: 'Test User', + email: `test-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ agentId: existingAgentId }, { model: 'gpt-4', endpoint: 'openAI' }], + }); + + // Delete agents for user with no agents (should be a no-op) + await deleteUserAgents(authorWithNoAgentsId.toString()); + + // Verify existing agent still exists + expect(await getAgent({ id: existingAgentId })).not.toBeNull(); + + // Verify user favorites are unchanged + const userAfter = await User.findById(userId); + expect(userAfter.favorites).toHaveLength(2); + expect(userAfter.favorites.some((f) => f.agentId === existingAgentId)).toBe(true); + 
expect(userAfter.favorites.some((f) => f.model === 'gpt-4')).toBe(true); + }); + + test('should handle deleteUserAgents when agents are not in any favorites', async () => { + const authorId = new mongoose.Types.ObjectId(); + const userId = new mongoose.Types.ObjectId(); + + const agent1Id = `agent_${uuidv4()}`; + const agent2Id = `agent_${uuidv4()}`; + + // Create agents by the author + await createAgent({ + id: agent1Id, + name: 'Agent 1', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + await createAgent({ + id: agent2Id, + name: 'Agent 2', + provider: 'test', + model: 'test-model', + author: authorId, + }); + + // Create user with favorites that don't include these agents + await User.create({ + _id: userId, + name: 'Test User', + email: `test-${uuidv4()}@example.com`, + provider: 'local', + favorites: [{ model: 'gpt-4', endpoint: 'openAI' }], + }); + + // Verify agents exist + expect(await getAgent({ id: agent1Id })).not.toBeNull(); + expect(await getAgent({ id: agent2Id })).not.toBeNull(); + + // Delete all agents by the author + await deleteUserAgents(authorId.toString()); + + // Verify agents are deleted + expect(await getAgent({ id: agent1Id })).toBeNull(); + expect(await getAgent({ id: agent2Id })).toBeNull(); + + // Verify user favorites are unchanged + const userAfter = await User.findById(userId); + expect(userAfter.favorites).toHaveLength(1); + expect(userAfter.favorites.some((f) => f.model === 'gpt-4')).toBe(true); }); test('should update agent projects', async () => { @@ -733,26 +1179,6 @@ describe('models/Agent', () => { expect(result).toBe(expected); }); - test('should handle getListAgents with invalid author format', async () => { - try { - const result = await getListAgents({ author: 'invalid-object-id' }); - expect(result.data).toEqual([]); - } catch (error) { - expect(error).toBeDefined(); - } - }); - - test('should handle getListAgents with no agents', async () => { - const authorId = new mongoose.Types.ObjectId(); - 
const result = await getListAgents({ author: authorId.toString() }); - - expect(result).toBeDefined(); - expect(result.data).toEqual([]); - expect(result.has_more).toBe(false); - expect(result.first_id).toBeNull(); - expect(result.last_id).toBeNull(); - }); - test('should handle updateAgentProjects with non-existent agent', async () => { const nonExistentId = `agent_${uuidv4()}`; const userId = new mongoose.Types.ObjectId(); @@ -2366,17 +2792,6 @@ describe('models/Agent', () => { expect(result).toBeNull(); }); - test('should handle getListAgents with no agents', async () => { - const authorId = new mongoose.Types.ObjectId(); - const result = await getListAgents({ author: authorId.toString() }); - - expect(result).toBeDefined(); - expect(result.data).toEqual([]); - expect(result.has_more).toBe(false); - expect(result.first_id).toBeNull(); - expect(result.last_id).toBeNull(); - }); - test('should handle updateAgent with MongoDB operators mixed with direct updates', async () => { const agentId = `agent_${uuidv4()}`; const authorId = new mongoose.Types.ObjectId(); diff --git a/api/models/Conversation.js b/api/models/Conversation.js index a8f5f9a36c..32eac1a764 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -124,10 +124,15 @@ module.exports = { updateOperation, { new: true, - upsert: true, + upsert: metadata?.noUpsert !== true, }, ); + if (!conversation) { + logger.debug('[saveConvo] Conversation not found, skipping update'); + return null; + } + return conversation.toObject(); } catch (error) { logger.error('[saveConvo] Error saving conversation', error); diff --git a/api/models/Conversation.spec.js b/api/models/Conversation.spec.js index b6237d5f15..bd415b4165 100644 --- a/api/models/Conversation.spec.js +++ b/api/models/Conversation.spec.js @@ -106,6 +106,47 @@ describe('Conversation Operations', () => { expect(result.conversationId).toBe(newConversationId); }); + it('should not create a conversation when noUpsert is true and conversation 
does not exist', async () => { + const nonExistentId = uuidv4(); + const result = await saveConvo( + mockReq, + { conversationId: nonExistentId, title: 'Ghost Title' }, + { noUpsert: true }, + ); + + expect(result).toBeNull(); + + const dbConvo = await Conversation.findOne({ conversationId: nonExistentId }); + expect(dbConvo).toBeNull(); + }); + + it('should update an existing conversation when noUpsert is true', async () => { + await saveConvo(mockReq, mockConversationData); + + const result = await saveConvo( + mockReq, + { conversationId: mockConversationData.conversationId, title: 'Updated Title' }, + { noUpsert: true }, + ); + + expect(result).not.toBeNull(); + expect(result.title).toBe('Updated Title'); + expect(result.conversationId).toBe(mockConversationData.conversationId); + }); + + it('should still upsert by default when noUpsert is not provided', async () => { + const newId = uuidv4(); + const result = await saveConvo(mockReq, { + conversationId: newId, + title: 'New Conversation', + endpoint: EModelEndpoint.openAI, + }); + + expect(result).not.toBeNull(); + expect(result.conversationId).toBe(newId); + expect(result.title).toBe('New Conversation'); + }); + it('should handle unsetFields metadata', async () => { const metadata = { unsetFields: { someField: 1 }, @@ -122,7 +163,6 @@ describe('Conversation Operations', () => { describe('isTemporary conversation handling', () => { it('should save a conversation with expiredAt when isTemporary is true', async () => { - // Mock app config with 24 hour retention mockReq.config.interfaceConfig.temporaryChatRetention = 24; mockReq.body = { isTemporary: true }; @@ -135,7 +175,6 @@ describe('Conversation Operations', () => { expect(result.expiredAt).toBeDefined(); expect(result.expiredAt).toBeInstanceOf(Date); - // Verify expiredAt is approximately 24 hours in the future const expectedExpirationTime = new Date(beforeSave.getTime() + 24 * 60 * 60 * 1000); const actualExpirationTime = new Date(result.expiredAt); @@ 
-157,7 +196,6 @@ describe('Conversation Operations', () => { }); it('should save a conversation without expiredAt when isTemporary is not provided', async () => { - // No isTemporary in body mockReq.body = {}; const result = await saveConvo(mockReq, mockConversationData); @@ -167,7 +205,6 @@ describe('Conversation Operations', () => { }); it('should use custom retention period from config', async () => { - // Mock app config with 48 hour retention mockReq.config.interfaceConfig.temporaryChatRetention = 48; mockReq.body = { isTemporary: true }; diff --git a/api/models/File.js b/api/models/File.js index 5e90c86fe4..1a01ef12f9 100644 --- a/api/models/File.js +++ b/api/models/File.js @@ -26,7 +26,8 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => { }; /** - * Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs + * Retrieves tool files (files that are embedded or have a fileIdentifier) from an array of file IDs. + * Note: execute_code files are handled separately by getCodeGeneratedFiles. 
* @param {string[]} fileIds - Array of file_id strings to search for * @param {Set} toolResourceSet - Optional filter for tool resources * @returns {Promise>} Files that match the criteria @@ -37,21 +38,25 @@ const getToolFilesByIds = async (fileIds, toolResourceSet) => { } try { - const filter = { - file_id: { $in: fileIds }, - $or: [], - }; + const orConditions = []; if (toolResourceSet.has(EToolResources.context)) { - filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents }); + orConditions.push({ text: { $exists: true, $ne: null }, context: FileContext.agents }); } if (toolResourceSet.has(EToolResources.file_search)) { - filter.$or.push({ embedded: true }); + orConditions.push({ embedded: true }); } - if (toolResourceSet.has(EToolResources.execute_code)) { - filter.$or.push({ 'metadata.fileIdentifier': { $exists: true } }); + + if (orConditions.length === 0) { + return []; } + const filter = { + file_id: { $in: fileIds }, + context: { $ne: FileContext.execute_code }, // Exclude code-generated files + $or: orConditions, + }; + const selectFields = { text: 0 }; const sortOptions = { updatedAt: -1 }; @@ -62,6 +67,70 @@ const getToolFilesByIds = async (fileIds, toolResourceSet) => { } }; +/** + * Retrieves files generated by code execution for a given conversation. + * These files are stored locally with fileIdentifier metadata for code env re-upload. 
+ * @param {string} conversationId - The conversation ID to search for + * @param {string[]} [messageIds] - Optional array of messageIds to filter by (for linear thread filtering) + * @returns {Promise>} Files generated by code execution in the conversation + */ +const getCodeGeneratedFiles = async (conversationId, messageIds) => { + if (!conversationId) { + return []; + } + + /** messageIds are required for proper thread filtering of code-generated files */ + if (!messageIds || messageIds.length === 0) { + return []; + } + + try { + const filter = { + conversationId, + context: FileContext.execute_code, + messageId: { $exists: true, $in: messageIds }, + 'metadata.fileIdentifier': { $exists: true }, + }; + + const selectFields = { text: 0 }; + const sortOptions = { createdAt: 1 }; + + return await getFiles(filter, sortOptions, selectFields); + } catch (error) { + logger.error('[getCodeGeneratedFiles] Error retrieving code generated files:', error); + return []; + } +}; + +/** + * Retrieves user-uploaded execute_code files (not code-generated) by their file IDs. + * These are files with fileIdentifier metadata but context is NOT execute_code (e.g., agents or message_attachment). + * File IDs should be collected from message.files arrays in the current thread. 
+ * @param {string[]} fileIds - Array of file IDs to fetch (from message.files in the thread) + * @returns {Promise>} User-uploaded execute_code files + */ +const getUserCodeFiles = async (fileIds) => { + if (!fileIds || fileIds.length === 0) { + return []; + } + + try { + const filter = { + file_id: { $in: fileIds }, + context: { $ne: FileContext.execute_code }, + 'metadata.fileIdentifier': { $exists: true }, + }; + + const selectFields = { text: 0 }; + const sortOptions = { createdAt: 1 }; + + return await getFiles(filter, sortOptions, selectFields); + } catch (error) { + logger.error('[getUserCodeFiles] Error retrieving user code files:', error); + return []; + } +}; + /** * Creates a new file with a TTL of 1 hour. * @param {MongoFile} data - The file data to be created, must contain file_id. @@ -169,6 +238,8 @@ module.exports = { findFileById, getFiles, getToolFilesByIds, + getCodeGeneratedFiles, + getUserCodeFiles, createFile, updateFile, updateFileUsage, diff --git a/api/models/Transaction.js b/api/models/Transaction.js index 5fa20f1ddf..e553e2bb3b 100644 --- a/api/models/Transaction.js +++ b/api/models/Transaction.js @@ -138,11 +138,10 @@ const updateBalance = async ({ user, incrementValue, setValues }) => { /** Method to calculate and set the tokenValue for a transaction */ function calculateTokenValue(txn) { - if (!txn.valueKey || !txn.tokenType) { - txn.tokenValue = txn.rawAmount; - } - const { valueKey, tokenType, model, endpointTokenConfig } = txn; - const multiplier = Math.abs(getMultiplier({ valueKey, tokenType, model, endpointTokenConfig })); + const { valueKey, tokenType, model, endpointTokenConfig, inputTokenCount } = txn; + const multiplier = Math.abs( + getMultiplier({ valueKey, tokenType, model, endpointTokenConfig, inputTokenCount }), + ); txn.rate = multiplier; txn.tokenValue = txn.rawAmount * multiplier; if (txn.context && txn.tokenType === 'completion' && txn.context === 'incomplete') { @@ -166,6 +165,7 @@ async function 
createAutoRefillTransaction(txData) { } const transaction = new Transaction(txData); transaction.endpointTokenConfig = txData.endpointTokenConfig; + transaction.inputTokenCount = txData.inputTokenCount; calculateTokenValue(transaction); await transaction.save(); @@ -200,6 +200,7 @@ async function createTransaction(_txData) { const transaction = new Transaction(txData); transaction.endpointTokenConfig = txData.endpointTokenConfig; + transaction.inputTokenCount = txData.inputTokenCount; calculateTokenValue(transaction); await transaction.save(); @@ -231,10 +232,9 @@ async function createStructuredTransaction(_txData) { return; } - const transaction = new Transaction({ - ...txData, - endpointTokenConfig: txData.endpointTokenConfig, - }); + const transaction = new Transaction(txData); + transaction.endpointTokenConfig = txData.endpointTokenConfig; + transaction.inputTokenCount = txData.inputTokenCount; calculateStructuredTokenValue(transaction); @@ -266,10 +266,15 @@ function calculateStructuredTokenValue(txn) { return; } - const { model, endpointTokenConfig } = txn; + const { model, endpointTokenConfig, inputTokenCount } = txn; if (txn.tokenType === 'prompt') { - const inputMultiplier = getMultiplier({ tokenType: 'prompt', model, endpointTokenConfig }); + const inputMultiplier = getMultiplier({ + tokenType: 'prompt', + model, + endpointTokenConfig, + inputTokenCount, + }); const writeMultiplier = getCacheMultiplier({ cacheType: 'write', model, endpointTokenConfig }) ?? 
inputMultiplier; const readMultiplier = @@ -304,7 +309,12 @@ function calculateStructuredTokenValue(txn) { txn.rawAmount = -totalPromptTokens; } else if (txn.tokenType === 'completion') { - const multiplier = getMultiplier({ tokenType: txn.tokenType, model, endpointTokenConfig }); + const multiplier = getMultiplier({ + tokenType: txn.tokenType, + model, + endpointTokenConfig, + inputTokenCount, + }); txn.rate = Math.abs(multiplier); txn.tokenValue = -Math.abs(txn.rawAmount) * multiplier; txn.rawAmount = -Math.abs(txn.rawAmount); diff --git a/api/models/Transaction.spec.js b/api/models/Transaction.spec.js index 2df9fc67f2..4b478d4dc3 100644 --- a/api/models/Transaction.spec.js +++ b/api/models/Transaction.spec.js @@ -1,7 +1,7 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); const { spendTokens, spendStructuredTokens } = require('./spendTokens'); -const { getMultiplier, getCacheMultiplier } = require('./tx'); +const { getMultiplier, getCacheMultiplier, premiumTokenValues, tokenValues } = require('./tx'); const { createTransaction, createStructuredTransaction } = require('./Transaction'); const { Balance, Transaction } = require('~/db/models'); @@ -564,3 +564,291 @@ describe('Transactions Config Tests', () => { expect(balance.tokenCredits).toBe(initialBalance); }); }); + +describe('calculateTokenValue Edge Cases', () => { + test('should derive multiplier from model when valueKey is not provided', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'gpt-4'; + const promptTokens = 1000; + + const result = await createTransaction({ + user: userId, + conversationId: 'test-no-valuekey', + model, + tokenType: 'prompt', + rawAmount: -promptTokens, + context: 'test', + balance: { enabled: true }, + }); + + const expectedRate = getMultiplier({ model, tokenType: 'prompt' }); + 
expect(result.rate).toBe(expectedRate); + + const tx = await Transaction.findOne({ user: userId }); + expect(tx.tokenValue).toBe(-promptTokens * expectedRate); + expect(tx.rate).toBe(expectedRate); + }); + + test('should derive valueKey and apply correct rate for an unknown model with tokenType', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + await createTransaction({ + user: userId, + conversationId: 'test-unknown-model', + model: 'some-unrecognized-model-xyz', + tokenType: 'prompt', + rawAmount: -500, + context: 'test', + balance: { enabled: true }, + }); + + const tx = await Transaction.findOne({ user: userId }); + expect(tx.rate).toBeDefined(); + expect(tx.rate).toBeGreaterThan(0); + expect(tx.tokenValue).toBe(tx.rawAmount * tx.rate); + }); + + test('should correctly apply model-derived multiplier without valueKey for completion', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const completionTokens = 500; + + const result = await createTransaction({ + user: userId, + conversationId: 'test-completion-no-valuekey', + model, + tokenType: 'completion', + rawAmount: -completionTokens, + context: 'test', + balance: { enabled: true }, + }); + + const expectedRate = getMultiplier({ model, tokenType: 'completion' }); + expect(expectedRate).toBe(tokenValues[model].completion); + expect(result.rate).toBe(expectedRate); + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo( + initialBalance - completionTokens * expectedRate, + 0, + ); + }); +}); + +describe('Premium Token Pricing Integration Tests', () => { + test('spendTokens should apply standard pricing when prompt tokens are below premium threshold', async () => { + const 
userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 100000; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-premium-below', + model, + context: 'test', + endpointTokenConfig: null, + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const standardPromptRate = tokenValues[model].prompt; + const standardCompletionRate = tokenValues[model].completion; + const expectedCost = + promptTokens * standardPromptRate + completionTokens * standardCompletionRate; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('spendTokens should apply premium pricing when prompt tokens exceed premium threshold', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = 250000; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-premium-above', + model, + context: 'test', + endpointTokenConfig: null, + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const premiumPromptRate = premiumTokenValues[model].prompt; + const premiumCompletionRate = premiumTokenValues[model].completion; + const expectedCost = + promptTokens * premiumPromptRate + completionTokens * premiumCompletionRate; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('spendTokens should apply standard pricing at exactly the premium threshold', async () => { + const userId = new 
mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const promptTokens = premiumTokenValues[model].threshold; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-premium-exact', + model, + context: 'test', + endpointTokenConfig: null, + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const standardPromptRate = tokenValues[model].prompt; + const standardCompletionRate = tokenValues[model].completion; + const expectedCost = + promptTokens * standardPromptRate + completionTokens * standardCompletionRate; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + test('spendStructuredTokens should apply premium pricing when total input tokens exceed threshold', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const txData = { + user: userId, + conversationId: 'test-structured-premium', + model, + context: 'message', + endpointTokenConfig: null, + balance: { enabled: true }, + }; + + const tokenUsage = { + promptTokens: { + input: 200000, + write: 10000, + read: 5000, + }, + completionTokens: 1000, + }; + + const totalInput = + tokenUsage.promptTokens.input + tokenUsage.promptTokens.write + tokenUsage.promptTokens.read; + + await spendStructuredTokens(txData, tokenUsage); + + const premiumPromptRate = premiumTokenValues[model].prompt; + const premiumCompletionRate = premiumTokenValues[model].completion; + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + + const expectedPromptCost = + 
tokenUsage.promptTokens.input * premiumPromptRate + + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; + const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(totalInput).toBeGreaterThan(premiumTokenValues[model].threshold); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); + + test('spendStructuredTokens should apply standard pricing when total input tokens are below threshold', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-6'; + const txData = { + user: userId, + conversationId: 'test-structured-standard', + model, + context: 'message', + endpointTokenConfig: null, + balance: { enabled: true }, + }; + + const tokenUsage = { + promptTokens: { + input: 50000, + write: 10000, + read: 5000, + }, + completionTokens: 1000, + }; + + const totalInput = + tokenUsage.promptTokens.input + tokenUsage.promptTokens.write + tokenUsage.promptTokens.read; + + await spendStructuredTokens(txData, tokenUsage); + + const standardPromptRate = tokenValues[model].prompt; + const standardCompletionRate = tokenValues[model].completion; + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + + const expectedPromptCost = + tokenUsage.promptTokens.input * standardPromptRate + + tokenUsage.promptTokens.write * writeMultiplier + + tokenUsage.promptTokens.read * readMultiplier; + const expectedCompletionCost = tokenUsage.completionTokens * standardCompletionRate; + const expectedTotalCost = expectedPromptCost + expectedCompletionCost; + + const updatedBalance = await 
Balance.findOne({ user: userId }); + expect(totalInput).toBeLessThanOrEqual(premiumTokenValues[model].threshold); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedTotalCost, 0); + }); + + test('non-premium models should not be affected by inputTokenCount regardless of prompt size', async () => { + const userId = new mongoose.Types.ObjectId(); + const initialBalance = 100000000; + await Balance.create({ user: userId, tokenCredits: initialBalance }); + + const model = 'claude-opus-4-5'; + const promptTokens = 300000; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-no-premium', + model, + context: 'test', + endpointTokenConfig: null, + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const standardPromptRate = getMultiplier({ model, tokenType: 'prompt' }); + const standardCompletionRate = getMultiplier({ model, tokenType: 'completion' }); + const expectedCost = + promptTokens * standardPromptRate + completionTokens * standardCompletionRate; + + const updatedBalance = await Balance.findOne({ user: userId }); + expect(updatedBalance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); +}); diff --git a/api/models/spendTokens.js b/api/models/spendTokens.js index cfd983f6bb..afe05969d8 100644 --- a/api/models/spendTokens.js +++ b/api/models/spendTokens.js @@ -24,12 +24,14 @@ const spendTokens = async (txData, tokenUsage) => { }, ); let prompt, completion; + const normalizedPromptTokens = Math.max(promptTokens ?? 0, 0); try { if (promptTokens !== undefined) { prompt = await createTransaction({ ...txData, tokenType: 'prompt', - rawAmount: promptTokens === 0 ? 0 : -Math.max(promptTokens, 0), + rawAmount: promptTokens === 0 ? 
0 : -normalizedPromptTokens, + inputTokenCount: normalizedPromptTokens, }); } @@ -38,6 +40,7 @@ const spendTokens = async (txData, tokenUsage) => { ...txData, tokenType: 'completion', rawAmount: completionTokens === 0 ? 0 : -Math.max(completionTokens, 0), + inputTokenCount: normalizedPromptTokens, }); } @@ -87,21 +90,31 @@ const spendStructuredTokens = async (txData, tokenUsage) => { let prompt, completion; try { if (promptTokens) { - const { input = 0, write = 0, read = 0 } = promptTokens; + const input = Math.max(promptTokens.input ?? 0, 0); + const write = Math.max(promptTokens.write ?? 0, 0); + const read = Math.max(promptTokens.read ?? 0, 0); + const totalInputTokens = input + write + read; prompt = await createStructuredTransaction({ ...txData, tokenType: 'prompt', inputTokens: -input, writeTokens: -write, readTokens: -read, + inputTokenCount: totalInputTokens, }); } if (completionTokens) { + const totalInputTokens = promptTokens + ? Math.max(promptTokens.input ?? 0, 0) + + Math.max(promptTokens.write ?? 0, 0) + + Math.max(promptTokens.read ?? 
0, 0) + : undefined; completion = await createTransaction({ ...txData, tokenType: 'completion', - rawAmount: -completionTokens, + rawAmount: -Math.max(completionTokens, 0), + inputTokenCount: totalInputTokens, }); } diff --git a/api/models/spendTokens.spec.js b/api/models/spendTokens.spec.js index eee6572736..c076d29700 100644 --- a/api/models/spendTokens.spec.js +++ b/api/models/spendTokens.spec.js @@ -1,7 +1,8 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { spendTokens, spendStructuredTokens } = require('./spendTokens'); const { createTransaction, createAutoRefillTransaction } = require('./Transaction'); +const { tokenValues, premiumTokenValues, getCacheMultiplier } = require('./tx'); +const { spendTokens, spendStructuredTokens } = require('./spendTokens'); require('~/db/models'); @@ -734,4 +735,328 @@ describe('spendTokens', () => { expect(balance).toBeDefined(); expect(balance.tokenCredits).toBeLessThan(10000); // Balance should be reduced }); + + describe('premium token pricing', () => { + it('should charge standard rates for claude-opus-4-6 when prompt tokens are below threshold', async () => { + const initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-6'; + const promptTokens = 100000; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-standard-pricing', + model, + context: 'test', + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const expectedCost = + promptTokens * tokenValues[model].prompt + completionTokens * tokenValues[model].completion; + + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + it('should charge premium rates for claude-opus-4-6 when prompt tokens exceed threshold', async () => { + const 
initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-6'; + const promptTokens = 250000; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-premium-pricing', + model, + context: 'test', + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const expectedCost = + promptTokens * premiumTokenValues[model].prompt + + completionTokens * premiumTokenValues[model].completion; + + const balance = await Balance.findOne({ user: userId }); + expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + + it('should charge premium rates for both prompt and completion in structured tokens when above threshold', async () => { + const initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-6'; + const txData = { + user: userId, + conversationId: 'test-structured-premium', + model, + context: 'test', + balance: { enabled: true }, + }; + + const tokenUsage = { + promptTokens: { + input: 200000, + write: 10000, + read: 5000, + }, + completionTokens: 1000, + }; + + const result = await spendStructuredTokens(txData, tokenUsage); + + const premiumPromptRate = premiumTokenValues[model].prompt; + const premiumCompletionRate = premiumTokenValues[model].completion; + const writeRate = getCacheMultiplier({ model, cacheType: 'write' }); + const readRate = getCacheMultiplier({ model, cacheType: 'read' }); + + const expectedPromptCost = + tokenUsage.promptTokens.input * premiumPromptRate + + tokenUsage.promptTokens.write * writeRate + + tokenUsage.promptTokens.read * readRate; + const expectedCompletionCost = tokenUsage.completionTokens * premiumCompletionRate; + + expect(result.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + 
}); + + it('should charge standard rates for structured tokens when below threshold', async () => { + const initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-6'; + const txData = { + user: userId, + conversationId: 'test-structured-standard', + model, + context: 'test', + balance: { enabled: true }, + }; + + const tokenUsage = { + promptTokens: { + input: 50000, + write: 10000, + read: 5000, + }, + completionTokens: 1000, + }; + + const result = await spendStructuredTokens(txData, tokenUsage); + + const standardPromptRate = tokenValues[model].prompt; + const standardCompletionRate = tokenValues[model].completion; + const writeRate = getCacheMultiplier({ model, cacheType: 'write' }); + const readRate = getCacheMultiplier({ model, cacheType: 'read' }); + + const expectedPromptCost = + tokenUsage.promptTokens.input * standardPromptRate + + tokenUsage.promptTokens.write * writeRate + + tokenUsage.promptTokens.read * readRate; + const expectedCompletionCost = tokenUsage.completionTokens * standardCompletionRate; + + expect(result.prompt.prompt).toBeCloseTo(-expectedPromptCost, 0); + expect(result.completion.completion).toBeCloseTo(-expectedCompletionCost, 0); + }); + + it('should not apply premium pricing to non-premium models regardless of prompt size', async () => { + const initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-5'; + const promptTokens = 300000; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-no-premium', + model, + context: 'test', + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const expectedCost = + promptTokens * tokenValues[model].prompt + completionTokens * tokenValues[model].completion; + + const balance = await Balance.findOne({ user: userId }); + 
expect(balance.tokenCredits).toBeCloseTo(initialBalance - expectedCost, 0); + }); + }); + + describe('inputTokenCount Normalization', () => { + it('should normalize negative promptTokens to zero for inputTokenCount', async () => { + await Balance.create({ + user: userId, + tokenCredits: 100000000, + }); + + const txData = { + user: userId, + conversationId: 'test-negative-prompt', + model: 'claude-opus-4-6', + context: 'test', + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens: -500, completionTokens: 100 }); + + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + + const completionTx = transactions.find((t) => t.tokenType === 'completion'); + const promptTx = transactions.find((t) => t.tokenType === 'prompt'); + + expect(Math.abs(promptTx.rawAmount)).toBe(0); + expect(completionTx.rawAmount).toBe(-100); + + const standardCompletionRate = tokenValues['claude-opus-4-6'].completion; + expect(completionTx.rate).toBe(standardCompletionRate); + }); + + it('should use normalized inputTokenCount for premium threshold check on completion', async () => { + const initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-6'; + const promptTokens = 250000; + const completionTokens = 500; + + const txData = { + user: userId, + conversationId: 'test-normalized-premium', + model, + context: 'test', + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens, completionTokens }); + + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + const completionTx = transactions.find((t) => t.tokenType === 'completion'); + const promptTx = transactions.find((t) => t.tokenType === 'prompt'); + + const premiumPromptRate = premiumTokenValues[model].prompt; + const premiumCompletionRate = premiumTokenValues[model].completion; + expect(promptTx.rate).toBe(premiumPromptRate); + 
expect(completionTx.rate).toBe(premiumCompletionRate); + }); + + it('should keep inputTokenCount as zero when promptTokens is zero', async () => { + await Balance.create({ + user: userId, + tokenCredits: 100000000, + }); + + const txData = { + user: userId, + conversationId: 'test-zero-prompt', + model: 'claude-opus-4-6', + context: 'test', + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens: 0, completionTokens: 100 }); + + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + const completionTx = transactions.find((t) => t.tokenType === 'completion'); + const promptTx = transactions.find((t) => t.tokenType === 'prompt'); + + expect(Math.abs(promptTx.rawAmount)).toBe(0); + + const standardCompletionRate = tokenValues['claude-opus-4-6'].completion; + expect(completionTx.rate).toBe(standardCompletionRate); + }); + + it('should not trigger premium pricing with negative promptTokens on premium model', async () => { + const initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-6'; + const txData = { + user: userId, + conversationId: 'test-negative-no-premium', + model, + context: 'test', + balance: { enabled: true }, + }; + + await spendTokens(txData, { promptTokens: -300000, completionTokens: 500 }); + + const transactions = await Transaction.find({ user: userId }).sort({ tokenType: 1 }); + const completionTx = transactions.find((t) => t.tokenType === 'completion'); + + const standardCompletionRate = tokenValues[model].completion; + expect(completionTx.rate).toBe(standardCompletionRate); + }); + + it('should normalize negative structured token values to zero in spendStructuredTokens', async () => { + const initialBalance = 100000000; + await Balance.create({ + user: userId, + tokenCredits: initialBalance, + }); + + const model = 'claude-opus-4-6'; + const txData = { + user: userId, + conversationId: 
'test-negative-structured', + model, + context: 'test', + balance: { enabled: true }, + }; + + const tokenUsage = { + promptTokens: { input: -100, write: 50, read: -30 }, + completionTokens: -200, + }; + + await spendStructuredTokens(txData, tokenUsage); + + const transactions = await Transaction.find({ + user: userId, + conversationId: 'test-negative-structured', + }).sort({ tokenType: 1 }); + + const completionTx = transactions.find((t) => t.tokenType === 'completion'); + const promptTx = transactions.find((t) => t.tokenType === 'prompt'); + + expect(Math.abs(promptTx.inputTokens)).toBe(0); + expect(promptTx.writeTokens).toBe(-50); + expect(Math.abs(promptTx.readTokens)).toBe(0); + + expect(Math.abs(completionTx.rawAmount)).toBe(0); + + const standardRate = tokenValues[model].completion; + expect(completionTx.rate).toBe(standardRate); + }); + }); }); diff --git a/api/models/tx.js b/api/models/tx.js index 6ff105a458..9a6305ec5c 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -1,10 +1,40 @@ const { matchModelName, findMatchingPattern } = require('@librechat/api'); const defaultRate = 6; +/** + * Token Pricing Configuration + * + * IMPORTANT: Key Ordering for Pattern Matching + * ============================================ + * The `findMatchingPattern` function iterates through object keys in REVERSE order + * (last-defined keys are checked first) and uses `modelName.includes(key)` for matching. + * + * This means: + * 1. BASE PATTERNS must be defined FIRST (e.g., "kimi", "moonshot") + * 2. 
SPECIFIC PATTERNS must be defined AFTER their base patterns (e.g., "kimi-k2", "kimi-k2.5") + * + * Example ordering for Kimi models: + * kimi: { prompt: 0.6, completion: 2.5 }, // Base pattern - checked last + * 'kimi-k2': { prompt: 0.6, completion: 2.5 }, // More specific - checked before "kimi" + * 'kimi-k2.5': { prompt: 0.6, completion: 3.0 }, // Most specific - checked first + * + * Why this matters: + * - Model name "kimi-k2.5" contains both "kimi" and "kimi-k2" as substrings + * - If "kimi" were checked first, it would incorrectly match and return wrong pricing + * - By defining specific patterns AFTER base patterns, they're checked first in reverse iteration + * + * This applies to BOTH `tokenValues` and `cacheTokenValues` objects. + * + * When adding new model families: + * 1. Define the base/generic pattern first + * 2. Define increasingly specific patterns after + * 3. Ensure no pattern is a substring of another that should match differently + */ + /** * AWS Bedrock pricing * source: https://aws.amazon.com/bedrock/pricing/ - * */ + */ const bedrockValues = { // Basic llama2 patterns (base defaults to smallest variant) llama2: { prompt: 0.75, completion: 1.0 }, @@ -80,6 +110,11 @@ const bedrockValues = { 'nova-pro': { prompt: 0.8, completion: 3.2 }, 'nova-premier': { prompt: 2.5, completion: 12.5 }, 'deepseek.r1': { prompt: 1.35, completion: 5.4 }, + // Moonshot/Kimi models on Bedrock + 'moonshot.kimi': { prompt: 0.6, completion: 2.5 }, + 'moonshot.kimi-k2': { prompt: 0.6, completion: 2.5 }, + 'moonshot.kimi-k2.5': { prompt: 0.6, completion: 3.0 }, + 'moonshot.kimi-k2-thinking': { prompt: 0.6, completion: 2.5 }, }; /** @@ -139,7 +174,9 @@ const tokenValues = Object.assign( 'claude-haiku-4-5': { prompt: 1, completion: 5 }, 'claude-opus-4': { prompt: 15, completion: 75 }, 'claude-opus-4-5': { prompt: 5, completion: 25 }, + 'claude-opus-4-6': { prompt: 5, completion: 25 }, 'claude-sonnet-4': { prompt: 3, completion: 15 }, + 'claude-sonnet-4-6': { prompt: 3, 
completion: 15 }, 'command-r': { prompt: 0.5, completion: 1.5 }, 'command-r-plus': { prompt: 3, completion: 15 }, 'command-text': { prompt: 1.5, completion: 2.0 }, @@ -189,7 +226,31 @@ const tokenValues = Object.assign( 'pixtral-large': { prompt: 2.0, completion: 6.0 }, 'mistral-large': { prompt: 2.0, completion: 6.0 }, 'mixtral-8x22b': { prompt: 0.65, completion: 0.65 }, - kimi: { prompt: 0.14, completion: 2.49 }, // Base pattern (using kimi-k2 pricing) + // Moonshot/Kimi models (base patterns first, specific patterns last for correct matching) + kimi: { prompt: 0.6, completion: 2.5 }, // Base pattern + moonshot: { prompt: 2.0, completion: 5.0 }, // Base pattern (using 128k pricing) + 'kimi-latest': { prompt: 0.2, completion: 2.0 }, // Uses 8k/32k/128k pricing dynamically + 'kimi-k2': { prompt: 0.6, completion: 2.5 }, + 'kimi-k2.5': { prompt: 0.6, completion: 3.0 }, + 'kimi-k2-turbo': { prompt: 1.15, completion: 8.0 }, + 'kimi-k2-turbo-preview': { prompt: 1.15, completion: 8.0 }, + 'kimi-k2-0905': { prompt: 0.6, completion: 2.5 }, + 'kimi-k2-0905-preview': { prompt: 0.6, completion: 2.5 }, + 'kimi-k2-0711': { prompt: 0.6, completion: 2.5 }, + 'kimi-k2-0711-preview': { prompt: 0.6, completion: 2.5 }, + 'kimi-k2-thinking': { prompt: 0.6, completion: 2.5 }, + 'kimi-k2-thinking-turbo': { prompt: 1.15, completion: 8.0 }, + 'moonshot-v1': { prompt: 2.0, completion: 5.0 }, + 'moonshot-v1-auto': { prompt: 2.0, completion: 5.0 }, + 'moonshot-v1-8k': { prompt: 0.2, completion: 2.0 }, + 'moonshot-v1-8k-vision': { prompt: 0.2, completion: 2.0 }, + 'moonshot-v1-8k-vision-preview': { prompt: 0.2, completion: 2.0 }, + 'moonshot-v1-32k': { prompt: 1.0, completion: 3.0 }, + 'moonshot-v1-32k-vision': { prompt: 1.0, completion: 3.0 }, + 'moonshot-v1-32k-vision-preview': { prompt: 1.0, completion: 3.0 }, + 'moonshot-v1-128k': { prompt: 2.0, completion: 5.0 }, + 'moonshot-v1-128k-vision': { prompt: 2.0, completion: 5.0 }, + 'moonshot-v1-128k-vision-preview': { prompt: 2.0, completion: 
5.0 }, // GPT-OSS models (specific sizes) 'gpt-oss:20b': { prompt: 0.05, completion: 0.2 }, 'gpt-oss-20b': { prompt: 0.05, completion: 0.2 }, @@ -249,12 +310,36 @@ const cacheTokenValues = { 'claude-3-haiku': { write: 0.3, read: 0.03 }, 'claude-haiku-4-5': { write: 1.25, read: 0.1 }, 'claude-sonnet-4': { write: 3.75, read: 0.3 }, + 'claude-sonnet-4-6': { write: 3.75, read: 0.3 }, 'claude-opus-4': { write: 18.75, read: 1.5 }, 'claude-opus-4-5': { write: 6.25, read: 0.5 }, + 'claude-opus-4-6': { write: 6.25, read: 0.5 }, // DeepSeek models - cache hit: $0.028/1M, cache miss: $0.28/1M deepseek: { write: 0.28, read: 0.028 }, 'deepseek-chat': { write: 0.28, read: 0.028 }, 'deepseek-reasoner': { write: 0.28, read: 0.028 }, + // Moonshot/Kimi models - cache hit: $0.15/1M (k2) or $0.10/1M (k2.5), cache miss: $0.60/1M + kimi: { write: 0.6, read: 0.15 }, + 'kimi-k2': { write: 0.6, read: 0.15 }, + 'kimi-k2.5': { write: 0.6, read: 0.1 }, + 'kimi-k2-turbo': { write: 1.15, read: 0.15 }, + 'kimi-k2-turbo-preview': { write: 1.15, read: 0.15 }, + 'kimi-k2-0905': { write: 0.6, read: 0.15 }, + 'kimi-k2-0905-preview': { write: 0.6, read: 0.15 }, + 'kimi-k2-0711': { write: 0.6, read: 0.15 }, + 'kimi-k2-0711-preview': { write: 0.6, read: 0.15 }, + 'kimi-k2-thinking': { write: 0.6, read: 0.15 }, + 'kimi-k2-thinking-turbo': { write: 1.15, read: 0.15 }, +}; + +/** + * Premium (tiered) pricing for models whose rates change based on prompt size. + * Each entry specifies the token threshold and the rates that apply above it. + * @type {Object.} + */ +const premiumTokenValues = { + 'claude-opus-4-6': { threshold: 200000, prompt: 10, completion: 37.5 }, + 'claude-sonnet-4-6': { threshold: 200000, prompt: 6, completion: 22.5 }, }; /** @@ -313,15 +398,27 @@ const getValueKey = (model, endpoint) => { * @param {string} [params.model] - The model name to derive the value key from if not provided. * @param {string} [params.endpoint] - The endpoint name to derive the value key from if not provided. 
* @param {EndpointTokenConfig} [params.endpointTokenConfig] - The token configuration for the endpoint. + * @param {number} [params.inputTokenCount] - Total input token count for tiered pricing. * @returns {number} The multiplier for the given parameters, or a default value if not found. */ -const getMultiplier = ({ valueKey, tokenType, model, endpoint, endpointTokenConfig }) => { +const getMultiplier = ({ + model, + valueKey, + endpoint, + tokenType, + inputTokenCount, + endpointTokenConfig, +}) => { if (endpointTokenConfig) { return endpointTokenConfig?.[model]?.[tokenType] ?? defaultRate; } if (valueKey && tokenType) { - return tokenValues[valueKey][tokenType] ?? defaultRate; + const premiumRate = getPremiumRate(valueKey, tokenType, inputTokenCount); + if (premiumRate != null) { + return premiumRate; + } + return tokenValues[valueKey]?.[tokenType] ?? defaultRate; } if (!tokenType || !model) { @@ -333,10 +430,33 @@ const getMultiplier = ({ valueKey, tokenType, model, endpoint, endpointTokenConf return defaultRate; } - // If we got this far, and values[tokenType] is undefined somehow, return a rough average of default multipliers + const premiumRate = getPremiumRate(valueKey, tokenType, inputTokenCount); + if (premiumRate != null) { + return premiumRate; + } + return tokenValues[valueKey]?.[tokenType] ?? defaultRate; }; +/** + * Checks if premium (tiered) pricing applies and returns the premium rate. + * Each model defines its own threshold in `premiumTokenValues`. + * @param {string} valueKey + * @param {string} tokenType + * @param {number} [inputTokenCount] + * @returns {number|null} + */ +const getPremiumRate = (valueKey, tokenType, inputTokenCount) => { + if (inputTokenCount == null) { + return null; + } + const premiumEntry = premiumTokenValues[valueKey]; + if (!premiumEntry || inputTokenCount <= premiumEntry.threshold) { + return null; + } + return premiumEntry[tokenType] ?? 
null; +}; + /** * Retrieves the cache multiplier for a given value key and token type. If no value key is provided, * it attempts to derive it from the model name. @@ -373,8 +493,10 @@ const getCacheMultiplier = ({ valueKey, cacheType, model, endpoint, endpointToke module.exports = { tokenValues, + premiumTokenValues, getValueKey, getMultiplier, + getPremiumRate, getCacheMultiplier, defaultRate, cacheTokenValues, diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index f70a6af47c..df1bec8619 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -1,3 +1,4 @@ +/** Note: No hard-coded values should be used in this file. */ const { maxTokensMap } = require('@librechat/api'); const { EModelEndpoint } = require('librechat-data-provider'); const { @@ -5,8 +6,10 @@ const { tokenValues, getValueKey, getMultiplier, + getPremiumRate, cacheTokenValues, getCacheMultiplier, + premiumTokenValues, } = require('./tx'); describe('getValueKey', () => { @@ -239,6 +242,15 @@ describe('getMultiplier', () => { expect(getMultiplier({ valueKey: '8k', tokenType: 'unknownType' })).toBe(defaultRate); }); + it('should return defaultRate if valueKey does not exist in tokenValues', () => { + expect(getMultiplier({ valueKey: 'non-existent-model', tokenType: 'prompt' })).toBe( + defaultRate, + ); + expect(getMultiplier({ valueKey: 'non-existent-model', tokenType: 'completion' })).toBe( + defaultRate, + ); + }); + it('should derive the valueKey from the model if not provided', () => { expect(getMultiplier({ tokenType: 'prompt', model: 'gpt-4-some-other-info' })).toBe( tokenValues['8k'].prompt, @@ -334,8 +346,6 @@ describe('getMultiplier', () => { expect(getMultiplier({ model: 'openai/gpt-5.1', tokenType: 'prompt' })).toBe( tokenValues['gpt-5.1'].prompt, ); - expect(tokenValues['gpt-5.1'].prompt).toBe(1.25); - expect(tokenValues['gpt-5.1'].completion).toBe(10); }); it('should return the correct multiplier for gpt-5.2', () => { @@ -348,8 +358,6 @@ describe('getMultiplier', () => { 
expect(getMultiplier({ model: 'openai/gpt-5.2', tokenType: 'prompt' })).toBe( tokenValues['gpt-5.2'].prompt, ); - expect(tokenValues['gpt-5.2'].prompt).toBe(1.75); - expect(tokenValues['gpt-5.2'].completion).toBe(14); }); it('should return the correct multiplier for gpt-4o', () => { @@ -815,8 +823,6 @@ describe('Deepseek Model Tests', () => { expect(getMultiplier({ model: 'deepseek-chat', tokenType: 'completion' })).toBe( tokenValues['deepseek-chat'].completion, ); - expect(tokenValues['deepseek-chat'].prompt).toBe(0.28); - expect(tokenValues['deepseek-chat'].completion).toBe(0.42); }); it('should return correct pricing for deepseek-reasoner', () => { @@ -826,8 +832,6 @@ describe('Deepseek Model Tests', () => { expect(getMultiplier({ model: 'deepseek-reasoner', tokenType: 'completion' })).toBe( tokenValues['deepseek-reasoner'].completion, ); - expect(tokenValues['deepseek-reasoner'].prompt).toBe(0.28); - expect(tokenValues['deepseek-reasoner'].completion).toBe(0.42); }); it('should handle DeepSeek model name variations with provider prefixes', () => { @@ -840,8 +844,8 @@ describe('Deepseek Model Tests', () => { modelVariations.forEach((model) => { const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); const completionMultiplier = getMultiplier({ model, tokenType: 'completion' }); - expect(promptMultiplier).toBe(0.28); - expect(completionMultiplier).toBe(0.42); + expect(promptMultiplier).toBe(tokenValues['deepseek-chat'].prompt); + expect(completionMultiplier).toBe(tokenValues['deepseek-chat'].completion); }); }); @@ -860,13 +864,13 @@ describe('Deepseek Model Tests', () => { ); }); - it('should return correct cache pricing values for DeepSeek models', () => { - expect(cacheTokenValues['deepseek-chat'].write).toBe(0.28); - expect(cacheTokenValues['deepseek-chat'].read).toBe(0.028); - expect(cacheTokenValues['deepseek-reasoner'].write).toBe(0.28); - expect(cacheTokenValues['deepseek-reasoner'].read).toBe(0.028); - 
expect(cacheTokenValues['deepseek'].write).toBe(0.28); - expect(cacheTokenValues['deepseek'].read).toBe(0.028); + it('should have consistent cache pricing across DeepSeek model variants', () => { + expect(cacheTokenValues['deepseek'].write).toBe(cacheTokenValues['deepseek-chat'].write); + expect(cacheTokenValues['deepseek'].read).toBe(cacheTokenValues['deepseek-chat'].read); + expect(cacheTokenValues['deepseek-reasoner'].write).toBe( + cacheTokenValues['deepseek-chat'].write, + ); + expect(cacheTokenValues['deepseek-reasoner'].read).toBe(cacheTokenValues['deepseek-chat'].read); }); it('should handle DeepSeek cache multipliers with model variations', () => { @@ -875,8 +879,195 @@ describe('Deepseek Model Tests', () => { modelVariations.forEach((model) => { const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); - expect(writeMultiplier).toBe(0.28); - expect(readMultiplier).toBe(0.028); + expect(writeMultiplier).toBe(cacheTokenValues['deepseek-chat'].write); + expect(readMultiplier).toBe(cacheTokenValues['deepseek-chat'].read); + }); + }); +}); + +describe('Moonshot/Kimi Model Tests - Pricing', () => { + describe('Kimi Models', () => { + it('should return correct pricing for kimi base pattern', () => { + expect(getMultiplier({ model: 'kimi', tokenType: 'prompt' })).toBe( + tokenValues['kimi'].prompt, + ); + expect(getMultiplier({ model: 'kimi', tokenType: 'completion' })).toBe( + tokenValues['kimi'].completion, + ); + }); + + it('should return correct pricing for kimi-k2.5', () => { + expect(getMultiplier({ model: 'kimi-k2.5', tokenType: 'prompt' })).toBe( + tokenValues['kimi-k2.5'].prompt, + ); + expect(getMultiplier({ model: 'kimi-k2.5', tokenType: 'completion' })).toBe( + tokenValues['kimi-k2.5'].completion, + ); + }); + + it('should return correct pricing for kimi-k2 series', () => { + expect(getMultiplier({ model: 'kimi-k2', tokenType: 'prompt' })).toBe( + 
tokenValues['kimi-k2'].prompt, + ); + expect(getMultiplier({ model: 'kimi-k2', tokenType: 'completion' })).toBe( + tokenValues['kimi-k2'].completion, + ); + }); + + it('should return correct pricing for kimi-k2-turbo (higher pricing)', () => { + expect(getMultiplier({ model: 'kimi-k2-turbo', tokenType: 'prompt' })).toBe( + tokenValues['kimi-k2-turbo'].prompt, + ); + expect(getMultiplier({ model: 'kimi-k2-turbo', tokenType: 'completion' })).toBe( + tokenValues['kimi-k2-turbo'].completion, + ); + }); + + it('should return correct pricing for kimi-k2-thinking models', () => { + expect(getMultiplier({ model: 'kimi-k2-thinking', tokenType: 'prompt' })).toBe( + tokenValues['kimi-k2-thinking'].prompt, + ); + expect(getMultiplier({ model: 'kimi-k2-thinking', tokenType: 'completion' })).toBe( + tokenValues['kimi-k2-thinking'].completion, + ); + expect(getMultiplier({ model: 'kimi-k2-thinking-turbo', tokenType: 'prompt' })).toBe( + tokenValues['kimi-k2-thinking-turbo'].prompt, + ); + expect(getMultiplier({ model: 'kimi-k2-thinking-turbo', tokenType: 'completion' })).toBe( + tokenValues['kimi-k2-thinking-turbo'].completion, + ); + }); + + it('should handle Kimi model variations with provider prefixes', () => { + const modelVariations = ['openrouter/kimi-k2', 'openrouter/kimi-k2.5', 'openrouter/kimi']; + + modelVariations.forEach((model) => { + const promptMultiplier = getMultiplier({ model, tokenType: 'prompt' }); + const completionMultiplier = getMultiplier({ model, tokenType: 'completion' }); + expect(promptMultiplier).toBe(tokenValues['kimi'].prompt); + expect([tokenValues['kimi'].completion, tokenValues['kimi-k2.5'].completion]).toContain( + completionMultiplier, + ); + }); + }); + }); + + describe('Moonshot Models', () => { + it('should return correct pricing for moonshot base pattern (128k pricing)', () => { + expect(getMultiplier({ model: 'moonshot', tokenType: 'prompt' })).toBe( + tokenValues['moonshot'].prompt, + ); + expect(getMultiplier({ model: 'moonshot', 
tokenType: 'completion' })).toBe( + tokenValues['moonshot'].completion, + ); + }); + + it('should return correct pricing for moonshot-v1-8k', () => { + expect(getMultiplier({ model: 'moonshot-v1-8k', tokenType: 'prompt' })).toBe( + tokenValues['moonshot-v1-8k'].prompt, + ); + expect(getMultiplier({ model: 'moonshot-v1-8k', tokenType: 'completion' })).toBe( + tokenValues['moonshot-v1-8k'].completion, + ); + }); + + it('should return correct pricing for moonshot-v1-32k', () => { + expect(getMultiplier({ model: 'moonshot-v1-32k', tokenType: 'prompt' })).toBe( + tokenValues['moonshot-v1-32k'].prompt, + ); + expect(getMultiplier({ model: 'moonshot-v1-32k', tokenType: 'completion' })).toBe( + tokenValues['moonshot-v1-32k'].completion, + ); + }); + + it('should return correct pricing for moonshot-v1-128k', () => { + expect(getMultiplier({ model: 'moonshot-v1-128k', tokenType: 'prompt' })).toBe( + tokenValues['moonshot-v1-128k'].prompt, + ); + expect(getMultiplier({ model: 'moonshot-v1-128k', tokenType: 'completion' })).toBe( + tokenValues['moonshot-v1-128k'].completion, + ); + }); + + it('should return correct pricing for moonshot-v1 vision models', () => { + expect(getMultiplier({ model: 'moonshot-v1-8k-vision', tokenType: 'prompt' })).toBe( + tokenValues['moonshot-v1-8k-vision'].prompt, + ); + expect(getMultiplier({ model: 'moonshot-v1-8k-vision', tokenType: 'completion' })).toBe( + tokenValues['moonshot-v1-8k-vision'].completion, + ); + expect(getMultiplier({ model: 'moonshot-v1-32k-vision', tokenType: 'prompt' })).toBe( + tokenValues['moonshot-v1-32k-vision'].prompt, + ); + expect(getMultiplier({ model: 'moonshot-v1-32k-vision', tokenType: 'completion' })).toBe( + tokenValues['moonshot-v1-32k-vision'].completion, + ); + expect(getMultiplier({ model: 'moonshot-v1-128k-vision', tokenType: 'prompt' })).toBe( + tokenValues['moonshot-v1-128k-vision'].prompt, + ); + expect(getMultiplier({ model: 'moonshot-v1-128k-vision', tokenType: 'completion' })).toBe( + 
tokenValues['moonshot-v1-128k-vision'].completion, + ); + }); + }); + + describe('Kimi Cache Multipliers', () => { + it('should return correct cache multipliers for kimi-k2 models', () => { + expect(getCacheMultiplier({ model: 'kimi', cacheType: 'write' })).toBe( + cacheTokenValues['kimi'].write, + ); + expect(getCacheMultiplier({ model: 'kimi', cacheType: 'read' })).toBe( + cacheTokenValues['kimi'].read, + ); + }); + + it('should return correct cache multipliers for kimi-k2.5 (lower read price)', () => { + expect(getCacheMultiplier({ model: 'kimi-k2.5', cacheType: 'write' })).toBe( + cacheTokenValues['kimi-k2.5'].write, + ); + expect(getCacheMultiplier({ model: 'kimi-k2.5', cacheType: 'read' })).toBe( + cacheTokenValues['kimi-k2.5'].read, + ); + }); + + it('should return correct cache multipliers for kimi-k2-turbo', () => { + expect(getCacheMultiplier({ model: 'kimi-k2-turbo', cacheType: 'write' })).toBe( + cacheTokenValues['kimi-k2-turbo'].write, + ); + expect(getCacheMultiplier({ model: 'kimi-k2-turbo', cacheType: 'read' })).toBe( + cacheTokenValues['kimi-k2-turbo'].read, + ); + }); + + it('should handle Kimi cache multipliers with model variations', () => { + const modelVariations = ['openrouter/kimi-k2', 'openrouter/kimi']; + + modelVariations.forEach((model) => { + const writeMultiplier = getCacheMultiplier({ model, cacheType: 'write' }); + const readMultiplier = getCacheMultiplier({ model, cacheType: 'read' }); + expect(writeMultiplier).toBe(cacheTokenValues['kimi'].write); + expect(readMultiplier).toBe(cacheTokenValues['kimi'].read); + }); + }); + }); + + describe('Bedrock Moonshot Models', () => { + it('should return correct pricing for Bedrock moonshot models', () => { + expect(getMultiplier({ model: 'moonshot.kimi', tokenType: 'prompt' })).toBe( + tokenValues['moonshot.kimi'].prompt, + ); + expect(getMultiplier({ model: 'moonshot.kimi', tokenType: 'completion' })).toBe( + tokenValues['moonshot.kimi'].completion, + ); + expect(getMultiplier({ model: 
'moonshot.kimi-k2', tokenType: 'prompt' })).toBe( + tokenValues['moonshot.kimi-k2'].prompt, + ); + expect(getMultiplier({ model: 'moonshot.kimi-k2.5', tokenType: 'prompt' })).toBe( + tokenValues['moonshot.kimi-k2.5'].prompt, + ); + expect(getMultiplier({ model: 'moonshot.kimi-k2.5', tokenType: 'completion' })).toBe( + tokenValues['moonshot.kimi-k2.5'].completion, + ); }); }); }); @@ -1689,6 +1880,201 @@ describe('Claude Model Tests', () => { ); }); }); + + it('should return correct prompt and completion rates for Claude Opus 4.6', () => { + expect(getMultiplier({ model: 'claude-opus-4-6', tokenType: 'prompt' })).toBe( + tokenValues['claude-opus-4-6'].prompt, + ); + expect(getMultiplier({ model: 'claude-opus-4-6', tokenType: 'completion' })).toBe( + tokenValues['claude-opus-4-6'].completion, + ); + }); + + it('should handle Claude Opus 4.6 model name variations', () => { + const modelVariations = [ + 'claude-opus-4-6', + 'claude-opus-4-6-20250801', + 'claude-opus-4-6-latest', + 'anthropic/claude-opus-4-6', + 'claude-opus-4-6/anthropic', + 'claude-opus-4-6-preview', + ]; + + modelVariations.forEach((model) => { + const valueKey = getValueKey(model); + expect(valueKey).toBe('claude-opus-4-6'); + expect(getMultiplier({ model, tokenType: 'prompt' })).toBe( + tokenValues['claude-opus-4-6'].prompt, + ); + expect(getMultiplier({ model, tokenType: 'completion' })).toBe( + tokenValues['claude-opus-4-6'].completion, + ); + }); + }); + + it('should return correct cache rates for Claude Opus 4.6', () => { + expect(getCacheMultiplier({ model: 'claude-opus-4-6', cacheType: 'write' })).toBe( + cacheTokenValues['claude-opus-4-6'].write, + ); + expect(getCacheMultiplier({ model: 'claude-opus-4-6', cacheType: 'read' })).toBe( + cacheTokenValues['claude-opus-4-6'].read, + ); + }); + + it('should handle Claude Opus 4.6 cache rates with model name variations', () => { + const modelVariations = [ + 'claude-opus-4-6', + 'claude-opus-4-6-20250801', + 'claude-opus-4-6-latest', + 
'anthropic/claude-opus-4-6', + 'claude-opus-4-6/anthropic', + 'claude-opus-4-6-preview', + ]; + + modelVariations.forEach((model) => { + expect(getCacheMultiplier({ model, cacheType: 'write' })).toBe( + cacheTokenValues['claude-opus-4-6'].write, + ); + expect(getCacheMultiplier({ model, cacheType: 'read' })).toBe( + cacheTokenValues['claude-opus-4-6'].read, + ); + }); + }); +}); + +describe('Premium Token Pricing', () => { + const premiumModel = 'claude-opus-4-6'; + const premiumEntry = premiumTokenValues[premiumModel]; + const { threshold } = premiumEntry; + const belowThreshold = threshold - 1; + const aboveThreshold = threshold + 1; + const wellAboveThreshold = threshold * 2; + + it('should have premium pricing defined for claude-opus-4-6', () => { + expect(premiumEntry).toBeDefined(); + expect(premiumEntry.threshold).toBeDefined(); + expect(premiumEntry.prompt).toBeDefined(); + expect(premiumEntry.completion).toBeDefined(); + expect(premiumEntry.prompt).toBeGreaterThan(tokenValues[premiumModel].prompt); + expect(premiumEntry.completion).toBeGreaterThan(tokenValues[premiumModel].completion); + }); + + it('should return null from getPremiumRate when inputTokenCount is below threshold', () => { + expect(getPremiumRate(premiumModel, 'prompt', belowThreshold)).toBeNull(); + expect(getPremiumRate(premiumModel, 'completion', belowThreshold)).toBeNull(); + expect(getPremiumRate(premiumModel, 'prompt', threshold)).toBeNull(); + }); + + it('should return premium rate from getPremiumRate when inputTokenCount exceeds threshold', () => { + expect(getPremiumRate(premiumModel, 'prompt', aboveThreshold)).toBe(premiumEntry.prompt); + expect(getPremiumRate(premiumModel, 'completion', aboveThreshold)).toBe( + premiumEntry.completion, + ); + expect(getPremiumRate(premiumModel, 'prompt', wellAboveThreshold)).toBe(premiumEntry.prompt); + }); + + it('should return null from getPremiumRate when inputTokenCount is undefined or null', () => { + expect(getPremiumRate(premiumModel, 
'prompt', undefined)).toBeNull(); + expect(getPremiumRate(premiumModel, 'prompt', null)).toBeNull(); + }); + + it('should return null from getPremiumRate for models without premium pricing', () => { + expect(getPremiumRate('claude-opus-4-5', 'prompt', wellAboveThreshold)).toBeNull(); + expect(getPremiumRate('claude-sonnet-4', 'prompt', wellAboveThreshold)).toBeNull(); + expect(getPremiumRate('gpt-4o', 'prompt', wellAboveThreshold)).toBeNull(); + }); + + it('should return standard rate from getMultiplier when inputTokenCount is below threshold', () => { + expect( + getMultiplier({ + model: premiumModel, + tokenType: 'prompt', + inputTokenCount: belowThreshold, + }), + ).toBe(tokenValues[premiumModel].prompt); + expect( + getMultiplier({ + model: premiumModel, + tokenType: 'completion', + inputTokenCount: belowThreshold, + }), + ).toBe(tokenValues[premiumModel].completion); + }); + + it('should return premium rate from getMultiplier when inputTokenCount exceeds threshold', () => { + expect( + getMultiplier({ + model: premiumModel, + tokenType: 'prompt', + inputTokenCount: aboveThreshold, + }), + ).toBe(premiumEntry.prompt); + expect( + getMultiplier({ + model: premiumModel, + tokenType: 'completion', + inputTokenCount: aboveThreshold, + }), + ).toBe(premiumEntry.completion); + }); + + it('should return standard rate from getMultiplier when inputTokenCount is exactly at threshold', () => { + expect( + getMultiplier({ model: premiumModel, tokenType: 'prompt', inputTokenCount: threshold }), + ).toBe(tokenValues[premiumModel].prompt); + }); + + it('should return premium rate from getMultiplier when inputTokenCount is one above threshold', () => { + expect( + getMultiplier({ model: premiumModel, tokenType: 'prompt', inputTokenCount: aboveThreshold }), + ).toBe(premiumEntry.prompt); + }); + + it('should not apply premium pricing to models without premium entries', () => { + expect( + getMultiplier({ + model: 'claude-opus-4-5', + tokenType: 'prompt', + inputTokenCount: 
wellAboveThreshold, + }), + ).toBe(tokenValues['claude-opus-4-5'].prompt); + expect( + getMultiplier({ + model: 'claude-sonnet-4', + tokenType: 'prompt', + inputTokenCount: wellAboveThreshold, + }), + ).toBe(tokenValues['claude-sonnet-4'].prompt); + }); + + it('should use standard rate when inputTokenCount is not provided', () => { + expect(getMultiplier({ model: premiumModel, tokenType: 'prompt' })).toBe( + tokenValues[premiumModel].prompt, + ); + expect(getMultiplier({ model: premiumModel, tokenType: 'completion' })).toBe( + tokenValues[premiumModel].completion, + ); + }); + + it('should apply premium pricing through getMultiplier with valueKey path', () => { + const valueKey = getValueKey(premiumModel); + expect(getMultiplier({ valueKey, tokenType: 'prompt', inputTokenCount: aboveThreshold })).toBe( + premiumEntry.prompt, + ); + expect( + getMultiplier({ valueKey, tokenType: 'completion', inputTokenCount: aboveThreshold }), + ).toBe(premiumEntry.completion); + }); + + it('should apply standard pricing through getMultiplier with valueKey path when below threshold', () => { + const valueKey = getValueKey(premiumModel); + expect(getMultiplier({ valueKey, tokenType: 'prompt', inputTokenCount: belowThreshold })).toBe( + tokenValues[premiumModel].prompt, + ); + expect( + getMultiplier({ valueKey, tokenType: 'completion', inputTokenCount: belowThreshold }), + ).toBe(tokenValues[premiumModel].completion); + }); }); describe('tokens.ts and tx.js sync validation', () => { diff --git a/api/package.json b/api/package.json index 0881070652..bc212227d3 100644 --- a/api/package.json +++ b/api/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/backend", - "version": "v0.8.2-rc2", + "version": "v0.8.2", "description": "", "scripts": { "start": "echo 'please run this from the root directory'", @@ -34,25 +34,24 @@ }, "homepage": "https://librechat.ai", "dependencies": { - "@anthropic-ai/sdk": "^0.71.0", - "@anthropic-ai/vertex-sdk": "^0.14.0", - "@aws-sdk/client-bedrock-runtime": 
"^3.941.0", - "@aws-sdk/client-s3": "^3.758.0", + "@anthropic-ai/vertex-sdk": "^0.14.3", + "@aws-sdk/client-bedrock-runtime": "^3.980.0", + "@aws-sdk/client-s3": "^3.980.0", "@aws-sdk/s3-request-presigner": "^3.758.0", "@azure/identity": "^4.7.0", "@azure/search-documents": "^12.0.0", - "@azure/storage-blob": "^12.27.0", + "@azure/storage-blob": "^12.30.0", "@google/genai": "^1.19.0", "@keyv/redis": "^4.3.3", "@langchain/core": "^0.3.80", - "@librechat/agents": "^3.0.77", + "@librechat/agents": "^3.1.50", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", - "@modelcontextprotocol/sdk": "^1.25.2", + "@modelcontextprotocol/sdk": "^1.26.0", "@node-saml/passport-saml": "^5.1.0", "@smithy/node-http-handler": "^4.4.5", - "axios": "^1.12.1", + "axios": "^1.13.5", "bcryptjs": "^2.4.3", "compression": "^1.8.1", "connect-redis": "^8.1.0", @@ -80,7 +79,7 @@ "keyv-file": "^5.1.2", "klona": "^2.0.6", "librechat-data-provider": "*", - "lodash": "^4.17.21", + "lodash": "^4.17.23", "mathjs": "^15.1.0", "meilisearch": "^0.38.0", "memorystore": "^1.6.7", @@ -108,7 +107,7 @@ "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", - "undici": "^7.10.0", + "undici": "^7.18.2", "winston": "^3.11.0", "winston-daily-rotate-file": "^5.0.0", "zod": "^3.22.4" diff --git a/api/server/controllers/AuthController.js b/api/server/controllers/AuthController.js index 22e53dcfc9..588391b535 100644 --- a/api/server/controllers/AuthController.js +++ b/api/server/controllers/AuthController.js @@ -18,7 +18,6 @@ const { findUser, } = require('~/models'); const { getGraphApiToken } = require('~/server/services/GraphTokenService'); -const { getOAuthReconnectionManager } = require('~/config'); const { getOpenIdConfig } = require('~/strategies'); const registrationController = async (req, res) => { @@ -79,7 +78,12 @@ const refreshController = async (req, res) => { try { const openIdConfig = getOpenIdConfig(); - const tokenset = await 
openIdClient.refreshTokenGrant(openIdConfig, refreshToken); + const refreshParams = process.env.OPENID_SCOPE ? { scope: process.env.OPENID_SCOPE } : {}; + const tokenset = await openIdClient.refreshTokenGrant( + openIdConfig, + refreshToken, + refreshParams, + ); const claims = tokenset.claims(); const { user, error, migration } = await findOpenIDUser({ findUser, @@ -161,17 +165,6 @@ const refreshController = async (req, res) => { if (session && session.expiration > new Date()) { const token = await setAuthTokens(userId, res, session); - // trigger OAuth MCP server reconnection asynchronously (best effort) - try { - void getOAuthReconnectionManager() - .reconnectServers(userId) - .catch((err) => { - logger.error('[refreshController] Error reconnecting OAuth MCP servers:', err); - }); - } catch (err) { - logger.warn(`[refreshController] Cannot attempt OAuth MCP servers reconnection:`, err); - } - res.status(200).send({ token, user }); } else if (req?.query?.retry) { // Retrying from a refresh token request that failed (401) diff --git a/api/server/controllers/PermissionsController.js b/api/server/controllers/PermissionsController.js index e22e9532c9..51993d083c 100644 --- a/api/server/controllers/PermissionsController.js +++ b/api/server/controllers/PermissionsController.js @@ -5,6 +5,7 @@ const mongoose = require('mongoose'); const { logger } = require('@librechat/data-schemas'); const { ResourceType, PrincipalType, PermissionBits } = require('librechat-data-provider'); +const { enrichRemoteAgentPrincipals, backfillRemoteAgentPermissions } = require('@librechat/api'); const { bulkUpdateResourcePermissions, ensureGroupPrincipalExists, @@ -14,7 +15,6 @@ const { findAccessibleResources, getResourcePermissionsMap, } = require('~/server/services/PermissionService'); -const { AclEntry } = require('~/db/models'); const { searchPrincipals: searchLocalPrincipals, sortPrincipalsByRelevance, @@ -24,6 +24,7 @@ const { entraIdPrincipalFeatureEnabled, searchEntraIdPrincipals, } 
= require('~/server/services/GraphApiService'); +const { AclEntry, AccessRole } = require('~/db/models'); /** * Generic controller for resource permission endpoints @@ -234,7 +235,7 @@ const getResourcePermissions = async (req, res) => { }, ]); - const principals = []; + let principals = []; let publicPermission = null; // Process aggregation results @@ -280,6 +281,13 @@ const getResourcePermissions = async (req, res) => { } } + if (resourceType === ResourceType.REMOTE_AGENT) { + const enricherDeps = { AclEntry, AccessRole, logger }; + const enrichResult = await enrichRemoteAgentPrincipals(enricherDeps, resourceId, principals); + principals = enrichResult.principals; + backfillRemoteAgentPermissions(enricherDeps, resourceId, enrichResult.entriesToBackfill); + } + // Return response in format expected by frontend const response = { resourceType, diff --git a/api/server/controllers/PluginController.js b/api/server/controllers/PluginController.js index c5e074b8ff..279ffb15fd 100644 --- a/api/server/controllers/PluginController.js +++ b/api/server/controllers/PluginController.js @@ -8,7 +8,7 @@ const { getLogStores } = require('~/cache'); const getAvailablePluginsController = async (req, res) => { try { - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); const cachedPlugins = await cache.get(CacheKeys.PLUGINS); if (cachedPlugins) { res.status(200).json(cachedPlugins); @@ -63,7 +63,7 @@ const getAvailableTools = async (req, res) => { logger.warn('[getAvailableTools] User ID not found in request'); return res.status(401).json({ message: 'Unauthorized' }); } - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); const cachedToolsArray = await cache.get(CacheKeys.TOOLS); const appConfig = req.config ?? 
(await getAppConfig({ role: req.user?.role })); diff --git a/api/server/controllers/PluginController.spec.js b/api/server/controllers/PluginController.spec.js index d7d3f83a8b..06a51a3bd6 100644 --- a/api/server/controllers/PluginController.spec.js +++ b/api/server/controllers/PluginController.spec.js @@ -1,3 +1,4 @@ +const { CacheKeys } = require('librechat-data-provider'); const { getCachedTools, getAppConfig } = require('~/server/services/Config'); const { getLogStores } = require('~/cache'); @@ -63,6 +64,28 @@ describe('PluginController', () => { }); }); + describe('cache namespace', () => { + it('getAvailablePluginsController should use TOOL_CACHE namespace', async () => { + mockCache.get.mockResolvedValue([]); + await getAvailablePluginsController(mockReq, mockRes); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + }); + + it('getAvailableTools should use TOOL_CACHE namespace', async () => { + mockCache.get.mockResolvedValue([]); + await getAvailableTools(mockReq, mockRes); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + }); + + it('should NOT use CONFIG_STORE namespace for tool/plugin operations', async () => { + mockCache.get.mockResolvedValue([]); + await getAvailablePluginsController(mockReq, mockRes); + await getAvailableTools(mockReq, mockRes); + const allCalls = getLogStores.mock.calls.flat(); + expect(allCalls).not.toContain(CacheKeys.CONFIG_STORE); + }); + }); + describe('getAvailablePluginsController', () => { it('should use filterUniquePlugins to remove duplicate plugins', async () => { // Add plugins with duplicates to availableTools diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index b0cfd7ede2..7a9dd8125e 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -22,6 +22,7 @@ const { } = require('~/models'); const { ConversationTag, + AgentApiKey, Transaction, MemoryEntry, Assistant, @@ -35,6 +36,7 @@ const { 
const { updateUserPluginAuth, deleteUserPluginAuth } = require('~/server/services/PluginService'); const { verifyEmail, resendVerificationEmail } = require('~/server/services/AuthService'); const { getMCPManager, getFlowStateManager, getMCPServersRegistry } = require('~/config'); +const { invalidateCachedTools } = require('~/server/services/Config/getCachedTools'); const { needsRefresh, getNewS3URL } = require('~/server/services/Files/S3/crud'); const { processDeleteRequest } = require('~/server/services/Files/process'); const { getAppConfig } = require('~/server/services/Config'); @@ -214,6 +216,7 @@ const updateUserPluginsController = async (req, res) => { `[updateUserPluginsController] Attempting disconnect of MCP server "${serverName}" for user ${user.id} after plugin auth update.`, ); await mcpManager.disconnectUserConnection(user.id, serverName); + await invalidateCachedTools({ userId: user.id, serverName }); } } catch (disconnectError) { logger.error( @@ -256,6 +259,7 @@ const deleteUserController = async (req, res) => { await deleteFiles(null, user.id); // delete database files in case of orphaned files from previous steps await deleteToolCalls(user.id); // delete user tool calls await deleteUserAgents(user.id); // delete user agents + await AgentApiKey.deleteMany({ user: user._id }); // delete user agent API keys await Assistant.deleteMany({ user: user.id }); // delete user assistants await ConversationTag.deleteMany({ user: user.id }); // delete user conversation tags await MemoryEntry.deleteMany({ userId: user.id }); // delete user memory entries diff --git a/api/server/controllers/agents/__tests__/callbacks.spec.js b/api/server/controllers/agents/__tests__/callbacks.spec.js index 7922c31efa..8bd711f9c6 100644 --- a/api/server/controllers/agents/__tests__/callbacks.spec.js +++ b/api/server/controllers/agents/__tests__/callbacks.spec.js @@ -16,13 +16,10 @@ jest.mock('@librechat/data-schemas', () => ({ })); jest.mock('@librechat/agents', () => ({ - EnvVar: 
{ CODE_API_KEY: 'CODE_API_KEY' }, - Providers: { GOOGLE: 'google' }, - GraphEvents: {}, + ...jest.requireActual('@librechat/agents'), getMessageId: jest.fn(), ToolEndHandler: jest.fn(), handleToolCalls: jest.fn(), - ChatModelStreamHandler: jest.fn(), })); jest.mock('~/server/services/Files/Citations', () => ({ diff --git a/api/server/controllers/agents/__tests__/jobReplacement.spec.js b/api/server/controllers/agents/__tests__/jobReplacement.spec.js new file mode 100644 index 0000000000..efa79ca4ba --- /dev/null +++ b/api/server/controllers/agents/__tests__/jobReplacement.spec.js @@ -0,0 +1,281 @@ +/** + * Tests for job replacement detection in ResumableAgentController + * + * Tests the following fixes from PR #11462: + * 1. Job creation timestamp tracking + * 2. Stale job detection and event skipping + * 3. Response message saving before final event emission + */ + +const mockLogger = { + debug: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), +}; + +const mockGenerationJobManager = { + createJob: jest.fn(), + getJob: jest.fn(), + emitDone: jest.fn(), + emitChunk: jest.fn(), + completeJob: jest.fn(), + updateMetadata: jest.fn(), + setContentParts: jest.fn(), + subscribe: jest.fn(), +}; + +const mockSaveMessage = jest.fn(); +const mockDecrementPendingRequest = jest.fn(); + +jest.mock('@librechat/data-schemas', () => ({ + logger: mockLogger, +})); + +jest.mock('@librechat/api', () => ({ + isEnabled: jest.fn().mockReturnValue(false), + GenerationJobManager: mockGenerationJobManager, + checkAndIncrementPendingRequest: jest.fn().mockResolvedValue({ allowed: true }), + decrementPendingRequest: (...args) => mockDecrementPendingRequest(...args), + getViolationInfo: jest.fn(), + sanitizeMessageForTransmit: jest.fn((msg) => msg), + sanitizeFileForTransmit: jest.fn((file) => file), + Constants: { NO_PARENT: '00000000-0000-0000-0000-000000000000' }, +})); + +jest.mock('~/models', () => ({ + saveMessage: (...args) => mockSaveMessage(...args), +})); + 
+describe('Job Replacement Detection', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('Job Creation Timestamp Tracking', () => { + it('should capture createdAt when job is created', async () => { + const streamId = 'test-stream-123'; + const createdAt = Date.now(); + + mockGenerationJobManager.createJob.mockResolvedValue({ + createdAt, + readyPromise: Promise.resolve(), + abortController: new AbortController(), + emitter: { on: jest.fn() }, + }); + + const job = await mockGenerationJobManager.createJob(streamId, 'user-123', streamId); + + expect(job.createdAt).toBe(createdAt); + }); + }); + + describe('Job Replacement Detection Logic', () => { + /** + * Simulates the job replacement detection logic from request.js + * This is extracted for unit testing since the full controller is complex + */ + const detectJobReplacement = async (streamId, originalCreatedAt) => { + const currentJob = await mockGenerationJobManager.getJob(streamId); + return !currentJob || currentJob.createdAt !== originalCreatedAt; + }; + + it('should detect when job was replaced (different createdAt)', async () => { + const streamId = 'test-stream-123'; + const originalCreatedAt = 1000; + const newCreatedAt = 2000; + + mockGenerationJobManager.getJob.mockResolvedValue({ + createdAt: newCreatedAt, + }); + + const wasReplaced = await detectJobReplacement(streamId, originalCreatedAt); + + expect(wasReplaced).toBe(true); + }); + + it('should detect when job was deleted', async () => { + const streamId = 'test-stream-123'; + const originalCreatedAt = 1000; + + mockGenerationJobManager.getJob.mockResolvedValue(null); + + const wasReplaced = await detectJobReplacement(streamId, originalCreatedAt); + + expect(wasReplaced).toBe(true); + }); + + it('should not detect replacement when same job (same createdAt)', async () => { + const streamId = 'test-stream-123'; + const originalCreatedAt = 1000; + + mockGenerationJobManager.getJob.mockResolvedValue({ + createdAt: originalCreatedAt, 
+ }); + + const wasReplaced = await detectJobReplacement(streamId, originalCreatedAt); + + expect(wasReplaced).toBe(false); + }); + }); + + describe('Event Emission Behavior', () => { + /** + * Simulates the final event emission logic from request.js + */ + const emitFinalEventIfNotReplaced = async ({ + streamId, + originalCreatedAt, + finalEvent, + userId, + }) => { + const currentJob = await mockGenerationJobManager.getJob(streamId); + const jobWasReplaced = !currentJob || currentJob.createdAt !== originalCreatedAt; + + if (jobWasReplaced) { + mockLogger.debug('Skipping FINAL emit - job was replaced', { + streamId, + originalCreatedAt, + currentCreatedAt: currentJob?.createdAt, + }); + await mockDecrementPendingRequest(userId); + return false; + } + + mockGenerationJobManager.emitDone(streamId, finalEvent); + mockGenerationJobManager.completeJob(streamId); + await mockDecrementPendingRequest(userId); + return true; + }; + + it('should skip emitting when job was replaced', async () => { + const streamId = 'test-stream-123'; + const originalCreatedAt = 1000; + const newCreatedAt = 2000; + const userId = 'user-123'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + createdAt: newCreatedAt, + }); + + const emitted = await emitFinalEventIfNotReplaced({ + streamId, + originalCreatedAt, + finalEvent: { final: true }, + userId, + }); + + expect(emitted).toBe(false); + expect(mockGenerationJobManager.emitDone).not.toHaveBeenCalled(); + expect(mockGenerationJobManager.completeJob).not.toHaveBeenCalled(); + expect(mockDecrementPendingRequest).toHaveBeenCalledWith(userId); + expect(mockLogger.debug).toHaveBeenCalledWith( + 'Skipping FINAL emit - job was replaced', + expect.objectContaining({ + streamId, + originalCreatedAt, + currentCreatedAt: newCreatedAt, + }), + ); + }); + + it('should emit when job was not replaced', async () => { + const streamId = 'test-stream-123'; + const originalCreatedAt = 1000; + const userId = 'user-123'; + const finalEvent = { final: 
true, conversation: { conversationId: streamId } }; + + mockGenerationJobManager.getJob.mockResolvedValue({ + createdAt: originalCreatedAt, + }); + + const emitted = await emitFinalEventIfNotReplaced({ + streamId, + originalCreatedAt, + finalEvent, + userId, + }); + + expect(emitted).toBe(true); + expect(mockGenerationJobManager.emitDone).toHaveBeenCalledWith(streamId, finalEvent); + expect(mockGenerationJobManager.completeJob).toHaveBeenCalledWith(streamId); + expect(mockDecrementPendingRequest).toHaveBeenCalledWith(userId); + }); + }); + + describe('Response Message Saving Order', () => { + /** + * Tests that response messages are saved BEFORE final events are emitted + * This prevents race conditions where clients send follow-up messages + * before the response is in the database + */ + it('should save message before emitting final event', async () => { + const callOrder = []; + + mockSaveMessage.mockImplementation(async () => { + callOrder.push('saveMessage'); + }); + + mockGenerationJobManager.emitDone.mockImplementation(() => { + callOrder.push('emitDone'); + }); + + mockGenerationJobManager.getJob.mockResolvedValue({ + createdAt: 1000, + }); + + // Simulate the order of operations from request.js + const streamId = 'test-stream-123'; + const originalCreatedAt = 1000; + const response = { messageId: 'response-123' }; + const userId = 'user-123'; + + // Step 1: Save message + await mockSaveMessage({}, { ...response, user: userId }, { context: 'test' }); + + // Step 2: Check for replacement + const currentJob = await mockGenerationJobManager.getJob(streamId); + const jobWasReplaced = !currentJob || currentJob.createdAt !== originalCreatedAt; + + // Step 3: Emit if not replaced + if (!jobWasReplaced) { + mockGenerationJobManager.emitDone(streamId, { final: true }); + } + + expect(callOrder).toEqual(['saveMessage', 'emitDone']); + }); + }); + + describe('Aborted Request Handling', () => { + it('should use unfinished: true instead of error: true for aborted 
requests', () => { + const response = { messageId: 'response-123', content: [] }; + + // The new format for aborted responses + const abortedResponse = { ...response, unfinished: true }; + + expect(abortedResponse.unfinished).toBe(true); + expect(abortedResponse.error).toBeUndefined(); + }); + + it('should include unfinished flag in final event for aborted requests', () => { + const response = { messageId: 'response-123', content: [] }; + + // Old format (deprecated) + const _oldFinalEvent = { + final: true, + responseMessage: { ...response, error: true }, + error: { message: 'Request was aborted' }, + }; + + // New format (PR #11462) + const newFinalEvent = { + final: true, + responseMessage: { ...response, unfinished: true }, + }; + + expect(newFinalEvent.responseMessage.unfinished).toBe(true); + expect(newFinalEvent.error).toBeUndefined(); + expect(newFinalEvent.responseMessage.error).toBeUndefined(); + }); + }); +}); diff --git a/api/server/controllers/agents/__tests__/openai.spec.js b/api/server/controllers/agents/__tests__/openai.spec.js new file mode 100644 index 0000000000..8592c79a2d --- /dev/null +++ b/api/server/controllers/agents/__tests__/openai.spec.js @@ -0,0 +1,204 @@ +/** + * Unit tests for OpenAI-compatible API controller + * Tests that recordCollectedUsage is called correctly for token spending + */ + +const mockSpendTokens = jest.fn().mockResolvedValue({}); +const mockSpendStructuredTokens = jest.fn().mockResolvedValue({}); +const mockRecordCollectedUsage = jest + .fn() + .mockResolvedValue({ input_tokens: 100, output_tokens: 50 }); +const mockGetBalanceConfig = jest.fn().mockReturnValue({ enabled: true }); +const mockGetTransactionsConfig = jest.fn().mockReturnValue({ enabled: true }); + +jest.mock('nanoid', () => ({ + nanoid: jest.fn(() => 'mock-nanoid-123'), +})); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + debug: jest.fn(), + error: jest.fn(), + warn: jest.fn(), + }, +})); + +jest.mock('@librechat/agents', () => ({ + 
Callback: { TOOL_ERROR: 'TOOL_ERROR' }, + ToolEndHandler: jest.fn(), + formatAgentMessages: jest.fn().mockReturnValue({ + messages: [], + indexTokenCountMap: {}, + }), +})); + +jest.mock('@librechat/api', () => ({ + writeSSE: jest.fn(), + createRun: jest.fn().mockResolvedValue({ + processStream: jest.fn().mockResolvedValue(undefined), + }), + createChunk: jest.fn().mockReturnValue({}), + buildToolSet: jest.fn().mockReturnValue(new Set()), + sendFinalChunk: jest.fn(), + createSafeUser: jest.fn().mockReturnValue({ id: 'user-123' }), + validateRequest: jest + .fn() + .mockReturnValue({ request: { model: 'agent-123', messages: [], stream: false } }), + initializeAgent: jest.fn().mockResolvedValue({ + model: 'gpt-4', + model_parameters: {}, + toolRegistry: {}, + }), + getBalanceConfig: mockGetBalanceConfig, + createErrorResponse: jest.fn(), + getTransactionsConfig: mockGetTransactionsConfig, + recordCollectedUsage: mockRecordCollectedUsage, + buildNonStreamingResponse: jest.fn().mockReturnValue({ id: 'resp-123' }), + createOpenAIStreamTracker: jest.fn().mockReturnValue({ + addText: jest.fn(), + addReasoning: jest.fn(), + toolCalls: new Map(), + usage: { promptTokens: 0, completionTokens: 0, reasoningTokens: 0 }, + }), + createOpenAIContentAggregator: jest.fn().mockReturnValue({ + addText: jest.fn(), + addReasoning: jest.fn(), + getText: jest.fn().mockReturnValue(''), + getReasoning: jest.fn().mockReturnValue(''), + toolCalls: new Map(), + usage: { promptTokens: 100, completionTokens: 50, reasoningTokens: 0 }, + }), + createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }), + isChatCompletionValidationFailure: jest.fn().mockReturnValue(false), +})); + +jest.mock('~/server/services/ToolService', () => ({ + loadAgentTools: jest.fn().mockResolvedValue([]), + loadToolsForExecution: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/spendTokens', () => ({ + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, +})); + 
+jest.mock('~/server/controllers/agents/callbacks', () => ({ + createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), +})); + +jest.mock('~/server/services/PermissionService', () => ({ + findAccessibleResources: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/Conversation', () => ({ + getConvoFiles: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/Agent', () => ({ + getAgent: jest.fn().mockResolvedValue({ + id: 'agent-123', + provider: 'openAI', + model_parameters: { model: 'gpt-4' }, + }), + getAgents: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models', () => ({ + getFiles: jest.fn(), + getUserKey: jest.fn(), + getMessages: jest.fn(), + updateFilesUsage: jest.fn(), + getUserKeyValues: jest.fn(), + getUserCodeFiles: jest.fn(), + getToolFilesByIds: jest.fn(), + getCodeGeneratedFiles: jest.fn(), +})); + +describe('OpenAIChatCompletionController', () => { + let OpenAIChatCompletionController; + let req, res; + + beforeEach(() => { + jest.clearAllMocks(); + + const controller = require('../openai'); + OpenAIChatCompletionController = controller.OpenAIChatCompletionController; + + req = { + body: { + model: 'agent-123', + messages: [{ role: 'user', content: 'Hello' }], + stream: false, + }, + user: { id: 'user-123' }, + config: { + endpoints: { + agents: { allowedProviders: ['openAI'] }, + }, + }, + on: jest.fn(), + }; + + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + setHeader: jest.fn(), + flushHeaders: jest.fn(), + end: jest.fn(), + write: jest.fn(), + }; + }); + + describe('token usage recording', () => { + it('should call recordCollectedUsage after successful non-streaming completion', async () => { + await OpenAIChatCompletionController(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + expect.objectContaining({ + user: 'user-123', + 
conversationId: expect.any(String), + collectedUsage: expect.any(Array), + context: 'message', + balance: { enabled: true }, + transactions: { enabled: true }, + }), + ); + }); + + it('should pass balance and transactions config to recordCollectedUsage', async () => { + mockGetBalanceConfig.mockReturnValue({ enabled: true, startBalance: 1000 }); + mockGetTransactionsConfig.mockReturnValue({ enabled: true, rateLimit: 100 }); + + await OpenAIChatCompletionController(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + balance: { enabled: true, startBalance: 1000 }, + transactions: { enabled: true, rateLimit: 100 }, + }), + ); + }); + + it('should pass spendTokens and spendStructuredTokens as dependencies', async () => { + await OpenAIChatCompletionController(req, res); + + const [deps] = mockRecordCollectedUsage.mock.calls[0]; + expect(deps).toHaveProperty('spendTokens', mockSpendTokens); + expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens); + }); + + it('should include model from primaryConfig in recordCollectedUsage params', async () => { + await OpenAIChatCompletionController(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + model: 'gpt-4', + }), + ); + }); + }); +}); diff --git a/api/server/controllers/agents/__tests__/responses.unit.spec.js b/api/server/controllers/agents/__tests__/responses.unit.spec.js new file mode 100644 index 0000000000..e16ca394b2 --- /dev/null +++ b/api/server/controllers/agents/__tests__/responses.unit.spec.js @@ -0,0 +1,312 @@ +/** + * Unit tests for Open Responses API controller + * Tests that recordCollectedUsage is called correctly for token spending + */ + +const mockSpendTokens = jest.fn().mockResolvedValue({}); +const mockSpendStructuredTokens = jest.fn().mockResolvedValue({}); +const mockRecordCollectedUsage = jest + .fn() + .mockResolvedValue({ input_tokens: 
100, output_tokens: 50 }); +const mockGetBalanceConfig = jest.fn().mockReturnValue({ enabled: true }); +const mockGetTransactionsConfig = jest.fn().mockReturnValue({ enabled: true }); + +jest.mock('nanoid', () => ({ + nanoid: jest.fn(() => 'mock-nanoid-123'), +})); + +jest.mock('uuid', () => ({ + v4: jest.fn(() => 'mock-uuid-456'), +})); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + debug: jest.fn(), + error: jest.fn(), + warn: jest.fn(), + }, +})); + +jest.mock('@librechat/agents', () => ({ + Callback: { TOOL_ERROR: 'TOOL_ERROR' }, + ToolEndHandler: jest.fn(), + formatAgentMessages: jest.fn().mockReturnValue({ + messages: [], + indexTokenCountMap: {}, + }), +})); + +jest.mock('@librechat/api', () => ({ + createRun: jest.fn().mockResolvedValue({ + processStream: jest.fn().mockResolvedValue(undefined), + }), + buildToolSet: jest.fn().mockReturnValue(new Set()), + createSafeUser: jest.fn().mockReturnValue({ id: 'user-123' }), + initializeAgent: jest.fn().mockResolvedValue({ + model: 'claude-3', + model_parameters: {}, + toolRegistry: {}, + }), + getBalanceConfig: mockGetBalanceConfig, + getTransactionsConfig: mockGetTransactionsConfig, + recordCollectedUsage: mockRecordCollectedUsage, + createToolExecuteHandler: jest.fn().mockReturnValue({ handle: jest.fn() }), + // Responses API + writeDone: jest.fn(), + buildResponse: jest.fn().mockReturnValue({ id: 'resp_123', output: [] }), + generateResponseId: jest.fn().mockReturnValue('resp_mock-123'), + isValidationFailure: jest.fn().mockReturnValue(false), + emitResponseCreated: jest.fn(), + createResponseContext: jest.fn().mockReturnValue({ responseId: 'resp_123' }), + createResponseTracker: jest.fn().mockReturnValue({ + usage: { promptTokens: 100, completionTokens: 50 }, + }), + setupStreamingResponse: jest.fn(), + emitResponseInProgress: jest.fn(), + convertInputToMessages: jest.fn().mockReturnValue([]), + validateResponseRequest: jest.fn().mockReturnValue({ + request: { model: 'agent-123', input: 
'Hello', stream: false }, + }), + buildAggregatedResponse: jest.fn().mockReturnValue({ + id: 'resp_123', + status: 'completed', + output: [], + usage: { input_tokens: 100, output_tokens: 50, total_tokens: 150 }, + }), + createResponseAggregator: jest.fn().mockReturnValue({ + usage: { promptTokens: 100, completionTokens: 50 }, + }), + sendResponsesErrorResponse: jest.fn(), + createResponsesEventHandlers: jest.fn().mockReturnValue({ + handlers: { + on_message_delta: { handle: jest.fn() }, + on_reasoning_delta: { handle: jest.fn() }, + on_run_step: { handle: jest.fn() }, + on_run_step_delta: { handle: jest.fn() }, + on_chat_model_end: { handle: jest.fn() }, + }, + finalizeStream: jest.fn(), + }), + createAggregatorEventHandlers: jest.fn().mockReturnValue({ + on_message_delta: { handle: jest.fn() }, + on_reasoning_delta: { handle: jest.fn() }, + on_run_step: { handle: jest.fn() }, + on_run_step_delta: { handle: jest.fn() }, + on_chat_model_end: { handle: jest.fn() }, + }), +})); + +jest.mock('~/server/services/ToolService', () => ({ + loadAgentTools: jest.fn().mockResolvedValue([]), + loadToolsForExecution: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/spendTokens', () => ({ + spendTokens: mockSpendTokens, + spendStructuredTokens: mockSpendStructuredTokens, +})); + +jest.mock('~/server/controllers/agents/callbacks', () => ({ + createToolEndCallback: jest.fn().mockReturnValue(jest.fn()), + createResponsesToolEndCallback: jest.fn().mockReturnValue(jest.fn()), +})); + +jest.mock('~/server/services/PermissionService', () => ({ + findAccessibleResources: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models/Conversation', () => ({ + getConvoFiles: jest.fn().mockResolvedValue([]), + saveConvo: jest.fn().mockResolvedValue({}), + getConvo: jest.fn().mockResolvedValue(null), +})); + +jest.mock('~/models/Agent', () => ({ + getAgent: jest.fn().mockResolvedValue({ + id: 'agent-123', + name: 'Test Agent', + provider: 'anthropic', + model_parameters: { model: 
'claude-3' }, + }), + getAgents: jest.fn().mockResolvedValue([]), +})); + +jest.mock('~/models', () => ({ + getFiles: jest.fn(), + getUserKey: jest.fn(), + getMessages: jest.fn().mockResolvedValue([]), + saveMessage: jest.fn().mockResolvedValue({}), + updateFilesUsage: jest.fn(), + getUserKeyValues: jest.fn(), + getUserCodeFiles: jest.fn(), + getToolFilesByIds: jest.fn(), + getCodeGeneratedFiles: jest.fn(), +})); + +describe('createResponse controller', () => { + let createResponse; + let req, res; + + beforeEach(() => { + jest.clearAllMocks(); + + const controller = require('../responses'); + createResponse = controller.createResponse; + + req = { + body: { + model: 'agent-123', + input: 'Hello', + stream: false, + }, + user: { id: 'user-123' }, + config: { + endpoints: { + agents: { allowedProviders: ['anthropic'] }, + }, + }, + on: jest.fn(), + }; + + res = { + status: jest.fn().mockReturnThis(), + json: jest.fn(), + setHeader: jest.fn(), + flushHeaders: jest.fn(), + end: jest.fn(), + write: jest.fn(), + }; + }); + + describe('token usage recording - non-streaming', () => { + it('should call recordCollectedUsage after successful non-streaming completion', async () => { + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + expect.objectContaining({ + user: 'user-123', + conversationId: expect.any(String), + collectedUsage: expect.any(Array), + context: 'message', + }), + ); + }); + + it('should pass balance and transactions config to recordCollectedUsage', async () => { + mockGetBalanceConfig.mockReturnValue({ enabled: true, startBalance: 2000 }); + mockGetTransactionsConfig.mockReturnValue({ enabled: true }); + + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + balance: { enabled: 
true, startBalance: 2000 }, + transactions: { enabled: true }, + }), + ); + }); + + it('should pass spendTokens and spendStructuredTokens as dependencies', async () => { + await createResponse(req, res); + + const [deps] = mockRecordCollectedUsage.mock.calls[0]; + expect(deps).toHaveProperty('spendTokens', mockSpendTokens); + expect(deps).toHaveProperty('spendStructuredTokens', mockSpendStructuredTokens); + }); + + it('should include model from primaryConfig in recordCollectedUsage params', async () => { + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + model: 'claude-3', + }), + ); + }); + }); + + describe('token usage recording - streaming', () => { + beforeEach(() => { + req.body.stream = true; + + const api = require('@librechat/api'); + api.validateResponseRequest.mockReturnValue({ + request: { model: 'agent-123', input: 'Hello', stream: true }, + }); + }); + + it('should call recordCollectedUsage after successful streaming completion', async () => { + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledTimes(1); + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + { spendTokens: mockSpendTokens, spendStructuredTokens: mockSpendStructuredTokens }, + expect.objectContaining({ + user: 'user-123', + context: 'message', + }), + ); + }); + }); + + describe('collectedUsage population', () => { + it('should collect usage from on_chat_model_end events', async () => { + const api = require('@librechat/api'); + + let capturedOnChatModelEnd; + api.createAggregatorEventHandlers.mockImplementation(() => { + return { + on_message_delta: { handle: jest.fn() }, + on_reasoning_delta: { handle: jest.fn() }, + on_run_step: { handle: jest.fn() }, + on_run_step_delta: { handle: jest.fn() }, + on_chat_model_end: { + handle: jest.fn((event, data) => { + if (capturedOnChatModelEnd) { + capturedOnChatModelEnd(event, data); + } + }), + }, + }; + }); + + 
api.createRun.mockImplementation(async ({ customHandlers }) => { + capturedOnChatModelEnd = (event, data) => { + customHandlers.on_chat_model_end.handle(event, data); + }; + + return { + processStream: jest.fn().mockImplementation(async () => { + customHandlers.on_chat_model_end.handle('on_chat_model_end', { + output: { + usage_metadata: { + input_tokens: 150, + output_tokens: 75, + model: 'claude-3', + }, + }, + }); + }), + }; + }); + + await createResponse(req, res); + + expect(mockRecordCollectedUsage).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ + collectedUsage: expect.arrayContaining([ + expect.objectContaining({ + input_tokens: 150, + output_tokens: 75, + }), + ]), + }), + ); + }); + }); +}); diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 0d2a7bc317..0bb935795d 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,16 +1,13 @@ const { nanoid } = require('nanoid'); -const { sendEvent, GenerationJobManager } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); +const { Constants, EnvVar, GraphEvents, ToolEndHandler } = require('@librechat/agents'); const { Tools, StepTypes, FileContext, ErrorTypes } = require('librechat-data-provider'); const { - EnvVar, - Providers, - GraphEvents, - getMessageId, - ToolEndHandler, - handleToolCalls, - ChatModelStreamHandler, -} = require('@librechat/agents'); + sendEvent, + GenerationJobManager, + writeAttachmentEvent, + createToolExecuteHandler, +} = require('@librechat/api'); const { processFileCitations } = require('~/server/services/Files/Citations'); const { processCodeOutput } = require('~/server/services/Files/Code/process'); const { loadAuthValues } = require('~/server/services/Tools/credentials'); @@ -51,8 +48,6 @@ class ModelEndHandler { let errorMessage; try { const agentContext = graph.getAgentContext(metadata); - const isGoogle = 
agentContext.provider === Providers.GOOGLE; - const streamingDisabled = !!agentContext.clientOptions?.disableStreaming; if (data?.output?.additional_kwargs?.stop_reason === 'refusal') { const info = { ...data.output.additional_kwargs }; errorMessage = JSON.stringify({ @@ -67,21 +62,6 @@ class ModelEndHandler { }); } - const toolCalls = data?.output?.tool_calls; - let hasUnprocessedToolCalls = false; - if (Array.isArray(toolCalls) && toolCalls.length > 0 && graph?.toolCallStepIds?.has) { - try { - hasUnprocessedToolCalls = toolCalls.some( - (tc) => tc?.id && !graph.toolCallStepIds.has(tc.id), - ); - } catch { - hasUnprocessedToolCalls = false; - } - } - if (isGoogle || streamingDisabled || hasUnprocessedToolCalls) { - await handleToolCalls(toolCalls, metadata, graph); - } - const usage = data?.output?.usage_metadata; if (!usage) { return this.finalize(errorMessage); @@ -92,38 +72,6 @@ class ModelEndHandler { } this.collectedUsage.push(usage); - if (!streamingDisabled) { - return this.finalize(errorMessage); - } - if (!data.output.content) { - return this.finalize(errorMessage); - } - const stepKey = graph.getStepKey(metadata); - const message_id = getMessageId(stepKey, graph) ?? ''; - if (message_id) { - await graph.dispatchRunStep(stepKey, { - type: StepTypes.MESSAGE_CREATION, - message_creation: { - message_id, - }, - }); - } - const stepId = graph.getStepIdByKey(stepKey); - const content = data.output.content; - if (typeof content === 'string') { - await graph.dispatchMessageDelta(stepId, { - content: [ - { - type: 'text', - text: content, - }, - ], - }); - } else if (content.every((c) => c.type?.startsWith('text'))) { - await graph.dispatchMessageDelta(stepId, { - content, - }); - } } catch (error) { logger.error('Error handling model end event:', error); return this.finalize(errorMessage); @@ -146,18 +94,26 @@ function checkIfLastAgent(last_agent_id, langgraph_node) { /** * Helper to emit events either to res (standard mode) or to job emitter (resumable mode). 
+ * In Redis mode, awaits the emit to guarantee event ordering (critical for streaming deltas). * @param {ServerResponse} res - The server response object * @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode * @param {Object} eventData - The event data to send + * @returns {Promise} */ -function emitEvent(res, streamId, eventData) { +async function emitEvent(res, streamId, eventData) { if (streamId) { - GenerationJobManager.emitChunk(streamId, eventData); + await GenerationJobManager.emitChunk(streamId, eventData); } else { sendEvent(res, eventData); } } +/** + * @typedef {Object} ToolExecuteOptions + * @property {(toolNames: string[]) => Promise<{loadedTools: StructuredTool[]}>} loadTools - Function to load tools by name + * @property {Object} configurable - Configurable context for tool invocation + */ + /** * Get default handlers for stream events. * @param {Object} options - The options object. @@ -166,6 +122,7 @@ function emitEvent(res, streamId, eventData) { * @param {ToolEndCallback} options.toolEndCallback - Callback to use when tool ends. * @param {Array} options.collectedUsage - The list of collected usage metadata. * @param {string | null} [options.streamId] - The stream ID for resumable mode, or null for standard mode. + * @param {ToolExecuteOptions} [options.toolExecuteOptions] - Options for event-driven tool execution. * @returns {Record} The default handlers. * @throws {Error} If the request is not found. 
*/ @@ -175,6 +132,7 @@ function getDefaultHandlers({ toolEndCallback, collectedUsage, streamId = null, + toolExecuteOptions = null, }) { if (!res || !aggregateContent) { throw new Error( @@ -184,7 +142,6 @@ function getDefaultHandlers({ const handlers = { [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(collectedUsage), [GraphEvents.TOOL_END]: new ToolEndHandler(toolEndCallback, logger), - [GraphEvents.CHAT_MODEL_STREAM]: new ChatModelStreamHandler(), [GraphEvents.ON_RUN_STEP]: { /** * Handle ON_RUN_STEP event. @@ -192,18 +149,19 @@ function getDefaultHandlers({ * @param {StreamEventData} data - The event data. * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. */ - handle: (event, data, metadata) => { + handle: async (event, data, metadata) => { + aggregateContent({ event, data }); if (data?.stepDetails.type === StepTypes.TOOL_CALLS) { - emitEvent(res, streamId, { event, data }); + await emitEvent(res, streamId, { event, data }); } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { - emitEvent(res, streamId, { event, data }); + await emitEvent(res, streamId, { event, data }); } else if (!metadata?.hide_sequential_outputs) { - emitEvent(res, streamId, { event, data }); + await emitEvent(res, streamId, { event, data }); } else { const agentName = metadata?.name ?? 'Agent'; const isToolCall = data?.stepDetails.type === StepTypes.TOOL_CALLS; const action = isToolCall ? 'performing a task...' : 'thinking...'; - emitEvent(res, streamId, { + await emitEvent(res, streamId, { event: 'on_agent_update', data: { runId: metadata?.run_id, @@ -211,7 +169,6 @@ function getDefaultHandlers({ }, }); } - aggregateContent({ event, data }); }, }, [GraphEvents.ON_RUN_STEP_DELTA]: { @@ -221,15 +178,15 @@ function getDefaultHandlers({ * @param {StreamEventData} data - The event data. * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. 
*/ - handle: (event, data, metadata) => { - if (data?.delta.type === StepTypes.TOOL_CALLS) { - emitEvent(res, streamId, { event, data }); - } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { - emitEvent(res, streamId, { event, data }); - } else if (!metadata?.hide_sequential_outputs) { - emitEvent(res, streamId, { event, data }); - } + handle: async (event, data, metadata) => { aggregateContent({ event, data }); + if (data?.delta.type === StepTypes.TOOL_CALLS) { + await emitEvent(res, streamId, { event, data }); + } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { + await emitEvent(res, streamId, { event, data }); + } else if (!metadata?.hide_sequential_outputs) { + await emitEvent(res, streamId, { event, data }); + } }, }, [GraphEvents.ON_RUN_STEP_COMPLETED]: { @@ -239,15 +196,15 @@ function getDefaultHandlers({ * @param {StreamEventData & { result: ToolEndData }} data - The event data. * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. */ - handle: (event, data, metadata) => { - if (data?.result != null) { - emitEvent(res, streamId, { event, data }); - } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { - emitEvent(res, streamId, { event, data }); - } else if (!metadata?.hide_sequential_outputs) { - emitEvent(res, streamId, { event, data }); - } + handle: async (event, data, metadata) => { aggregateContent({ event, data }); + if (data?.result != null) { + await emitEvent(res, streamId, { event, data }); + } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { + await emitEvent(res, streamId, { event, data }); + } else if (!metadata?.hide_sequential_outputs) { + await emitEvent(res, streamId, { event, data }); + } }, }, [GraphEvents.ON_MESSAGE_DELTA]: { @@ -257,13 +214,13 @@ function getDefaultHandlers({ * @param {StreamEventData} data - The event data. 
* @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. */ - handle: (event, data, metadata) => { - if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { - emitEvent(res, streamId, { event, data }); - } else if (!metadata?.hide_sequential_outputs) { - emitEvent(res, streamId, { event, data }); - } + handle: async (event, data, metadata) => { aggregateContent({ event, data }); + if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { + await emitEvent(res, streamId, { event, data }); + } else if (!metadata?.hide_sequential_outputs) { + await emitEvent(res, streamId, { event, data }); + } }, }, [GraphEvents.ON_REASONING_DELTA]: { @@ -273,22 +230,27 @@ function getDefaultHandlers({ * @param {StreamEventData} data - The event data. * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. */ - handle: (event, data, metadata) => { - if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { - emitEvent(res, streamId, { event, data }); - } else if (!metadata?.hide_sequential_outputs) { - emitEvent(res, streamId, { event, data }); - } + handle: async (event, data, metadata) => { aggregateContent({ event, data }); + if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { + await emitEvent(res, streamId, { event, data }); + } else if (!metadata?.hide_sequential_outputs) { + await emitEvent(res, streamId, { event, data }); + } }, }, }; + if (toolExecuteOptions) { + handlers[GraphEvents.ON_TOOL_EXECUTE] = createToolExecuteHandler(toolExecuteOptions); + } + return handlers; } /** * Helper to write attachment events either to res or to job emitter. + * Note: Attachments are not order-sensitive like deltas, so fire-and-forget is acceptable. 
* @param {ServerResponse} res - The server response object * @param {string | null} streamId - The stream ID for resumable mode, or null for standard mode * @param {Object} attachment - The attachment data @@ -441,10 +403,10 @@ function createToolEndCallback({ req, res, artifactPromises, streamId = null }) return; } - { - if (output.name !== Tools.execute_code) { - return; - } + const isCodeTool = + output.name === Tools.execute_code || output.name === Constants.PROGRAMMATIC_TOOL_CALLING; + if (!isCodeTool) { + return; } if (!output.artifact.files) { @@ -488,7 +450,226 @@ function createToolEndCallback({ req, res, artifactPromises, streamId = null }) }; } +/** + * Helper to write attachment events in Open Responses format (librechat:attachment) + * @param {ServerResponse} res - The server response object + * @param {Object} tracker - The response tracker with sequence number + * @param {Object} attachment - The attachment data + * @param {Object} metadata - Additional metadata (messageId, conversationId) + */ +function writeResponsesAttachment(res, tracker, attachment, metadata) { + const sequenceNumber = tracker.nextSequence(); + writeAttachmentEvent(res, sequenceNumber, attachment, { + messageId: metadata.run_id, + conversationId: metadata.thread_id, + }); +} + +/** + * Creates a tool end callback specifically for the Responses API. + * Emits attachments as `librechat:attachment` events per the Open Responses extension spec. + * + * @param {Object} params + * @param {ServerRequest} params.req + * @param {ServerResponse} params.res + * @param {Object} params.tracker - Response tracker with sequence number + * @param {Promise[]} params.artifactPromises + * @returns {ToolEndCallback} The tool end callback. 
+ */ +function createResponsesToolEndCallback({ req, res, tracker, artifactPromises }) { + /** + * @type {ToolEndCallback} + */ + return async (data, metadata) => { + const output = data?.output; + if (!output) { + return; + } + + if (!output.artifact) { + return; + } + + if (output.artifact[Tools.file_search]) { + artifactPromises.push( + (async () => { + const user = req.user; + const attachment = await processFileCitations({ + user, + metadata, + appConfig: req.config, + toolArtifact: output.artifact, + toolCallId: output.tool_call_id, + }); + if (!attachment) { + return null; + } + // For Responses API, emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + writeResponsesAttachment(res, tracker, attachment, metadata); + } + return attachment; + })().catch((error) => { + logger.error('Error processing file citations:', error); + return null; + }), + ); + } + + if (output.artifact[Tools.ui_resources]) { + artifactPromises.push( + (async () => { + const attachment = { + type: Tools.ui_resources, + toolCallId: output.tool_call_id, + [Tools.ui_resources]: output.artifact[Tools.ui_resources].data, + }; + // For Responses API, always emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + writeResponsesAttachment(res, tracker, attachment, metadata); + } + return attachment; + })().catch((error) => { + logger.error('Error processing artifact content:', error); + return null; + }), + ); + } + + if (output.artifact[Tools.web_search]) { + artifactPromises.push( + (async () => { + const attachment = { + type: Tools.web_search, + toolCallId: output.tool_call_id, + [Tools.web_search]: { ...output.artifact[Tools.web_search] }, + }; + // For Responses API, always emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + writeResponsesAttachment(res, tracker, attachment, metadata); + } + return attachment; + })().catch((error) => { + logger.error('Error processing artifact content:', error); + return 
null; + }), + ); + } + + if (output.artifact.content) { + /** @type {FormattedContent[]} */ + const content = output.artifact.content; + for (let i = 0; i < content.length; i++) { + const part = content[i]; + if (!part) { + continue; + } + if (part.type !== 'image_url') { + continue; + } + const { url } = part.image_url; + artifactPromises.push( + (async () => { + const filename = `${output.name}_img_${nanoid()}`; + const file_id = output.artifact.file_ids?.[i]; + const file = await saveBase64Image(url, { + req, + file_id, + filename, + endpoint: metadata.provider, + context: FileContext.image_generation, + }); + const fileMetadata = Object.assign(file, { + toolCallId: output.tool_call_id, + }); + + if (!fileMetadata) { + return null; + } + + // For Responses API, emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + const attachment = { + file_id: fileMetadata.file_id, + filename: fileMetadata.filename, + type: fileMetadata.type, + url: fileMetadata.filepath, + width: fileMetadata.width, + height: fileMetadata.height, + tool_call_id: output.tool_call_id, + }; + writeResponsesAttachment(res, tracker, attachment, metadata); + } + + return fileMetadata; + })().catch((error) => { + logger.error('Error processing artifact content:', error); + return null; + }), + ); + } + return; + } + + const isCodeTool = + output.name === Tools.execute_code || output.name === Constants.PROGRAMMATIC_TOOL_CALLING; + if (!isCodeTool) { + return; + } + + if (!output.artifact.files) { + return; + } + + for (const file of output.artifact.files) { + const { id, name } = file; + artifactPromises.push( + (async () => { + const result = await loadAuthValues({ + userId: req.user.id, + authFields: [EnvVar.CODE_API_KEY], + }); + const fileMetadata = await processCodeOutput({ + req, + id, + name, + apiKey: result[EnvVar.CODE_API_KEY], + messageId: metadata.run_id, + toolCallId: output.tool_call_id, + conversationId: metadata.thread_id, + session_id: 
output.artifact.session_id, + }); + + if (!fileMetadata) { + return null; + } + + // For Responses API, emit attachment during streaming + if (res.headersSent && !res.writableEnded) { + const attachment = { + file_id: fileMetadata.file_id, + filename: fileMetadata.filename, + type: fileMetadata.type, + url: fileMetadata.filepath, + width: fileMetadata.width, + height: fileMetadata.height, + tool_call_id: output.tool_call_id, + }; + writeResponsesAttachment(res, tracker, attachment, metadata); + } + + return fileMetadata; + })().catch((error) => { + logger.error('Error processing code output:', error); + return null; + }), + ); + } + }; +} + module.exports = { getDefaultHandlers, createToolEndCallback, + createResponsesToolEndCallback, }; diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 2b5872411b..49240a6b3b 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -1,22 +1,27 @@ require('events').EventEmitter.defaultMaxListeners = 100; const { logger } = require('@librechat/data-schemas'); -const { DynamicStructuredTool } = require('@langchain/core/tools'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); const { createRun, Tokenizer, checkAccess, - logAxiosError, + buildToolSet, sanitizeTitle, + logToolError, + payloadParser, resolveHeaders, createSafeUser, initializeAgent, getBalanceConfig, getProviderConfig, + omitTitleOptions, memoryInstructions, + applyContextToAgent, + createTokenCounter, GenerationJobManager, getTransactionsConfig, createMemoryProcessor, + createMultiAgentMapper, filterMalformedContentParts, } = require('@librechat/api'); const { @@ -24,9 +29,7 @@ const { Providers, TitleMethod, formatMessage, - labelContentByAgent, formatAgentMessages, - getTokenCountForMessage, createMetadataAggregator, } = require('@librechat/agents'); const { @@ -38,7 +41,6 @@ const { PermissionTypes, isAgentsEndpoint, isEphemeralAgentId, - 
bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); @@ -51,183 +53,6 @@ const { loadAgent } = require('~/models/Agent'); const { getMCPManager } = require('~/config'); const db = require('~/models'); -const omitTitleOptions = new Set([ - 'stream', - 'thinking', - 'streaming', - 'clientOptions', - 'thinkingConfig', - 'thinkingBudget', - 'includeThoughts', - 'maxOutputTokens', - 'additionalModelRequestFields', -]); - -/** - * @param {ServerRequest} req - * @param {Agent} agent - * @param {string} endpoint - */ -const payloadParser = ({ req, agent, endpoint }) => { - if (isAgentsEndpoint(endpoint)) { - return { model: undefined }; - } else if (endpoint === EModelEndpoint.bedrock) { - const parsedValues = bedrockInputSchema.parse(agent.model_parameters); - if (parsedValues.thinking == null) { - parsedValues.thinking = false; - } - return parsedValues; - } - return req.body.endpointOption.model_parameters; -}; - -function createTokenCounter(encoding) { - return function (message) { - const countTokens = (text) => Tokenizer.getTokenCount(text, encoding); - return getTokenCountForMessage(message, countTokens); - }; -} - -function logToolError(graph, error, toolId) { - logAxiosError({ - error, - message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`, - }); -} - -/** Regex pattern to match agent ID suffix (____N) */ -const AGENT_SUFFIX_PATTERN = /____(\d+)$/; - -/** - * Finds the primary agent ID within a set of agent IDs. - * Primary = no suffix (____N) or lowest suffix number. 
- * @param {Set} agentIds - * @returns {string | null} - */ -function findPrimaryAgentId(agentIds) { - let primaryAgentId = null; - let lowestSuffixIndex = Infinity; - - for (const agentId of agentIds) { - const suffixMatch = agentId.match(AGENT_SUFFIX_PATTERN); - if (!suffixMatch) { - return agentId; - } - const suffixIndex = parseInt(suffixMatch[1], 10); - if (suffixIndex < lowestSuffixIndex) { - lowestSuffixIndex = suffixIndex; - primaryAgentId = agentId; - } - } - - return primaryAgentId; -} - -/** - * Creates a mapMethod for getMessagesForConversation that processes agent content. - * - Strips agentId/groupId metadata from all content - * - For parallel agents (addedConvo with groupId): filters each group to its primary agent - * - For handoffs (agentId without groupId): keeps all content from all agents - * - For multi-agent: applies agent labels to content - * - * The key distinction: - * - Parallel execution (addedConvo): Parts have both agentId AND groupId - * - Handoffs: Parts only have agentId, no groupId - * - * @param {Agent} primaryAgent - Primary agent configuration - * @param {Map} [agentConfigs] - Additional agent configurations - * @returns {(message: TMessage) => TMessage} Map method for processing messages - */ -function createMultiAgentMapper(primaryAgent, agentConfigs) { - const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 
0) > 0; - - /** @type {Record | null} */ - let agentNames = null; - if (hasMultipleAgents) { - agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' }; - if (agentConfigs) { - for (const [agentId, agentConfig] of agentConfigs.entries()) { - agentNames[agentId] = agentConfig.name || agentConfig.id; - } - } - } - - return (message) => { - if (message.isCreatedByUser || !Array.isArray(message.content)) { - return message; - } - - // Check for metadata - const hasAgentMetadata = message.content.some((part) => part?.agentId || part?.groupId != null); - if (!hasAgentMetadata) { - return message; - } - - try { - // Build a map of groupId -> Set of agentIds, to find primary per group - /** @type {Map>} */ - const groupAgentMap = new Map(); - - for (const part of message.content) { - const groupId = part?.groupId; - const agentId = part?.agentId; - if (groupId != null && agentId) { - if (!groupAgentMap.has(groupId)) { - groupAgentMap.set(groupId, new Set()); - } - groupAgentMap.get(groupId).add(agentId); - } - } - - // For each group, find the primary agent - /** @type {Map} */ - const groupPrimaryMap = new Map(); - for (const [groupId, agentIds] of groupAgentMap) { - const primary = findPrimaryAgentId(agentIds); - if (primary) { - groupPrimaryMap.set(groupId, primary); - } - } - - /** @type {Array} */ - const filteredContent = []; - /** @type {Record} */ - const agentIdMap = {}; - - for (const part of message.content) { - const agentId = part?.agentId; - const groupId = part?.groupId; - - // Filtering logic: - // - No groupId (handoffs): always include - // - Has groupId (parallel): only include if it's the primary for that group - const isParallelPart = groupId != null; - const groupPrimary = isParallelPart ? 
groupPrimaryMap.get(groupId) : null; - const shouldInclude = !isParallelPart || !agentId || agentId === groupPrimary; - - if (shouldInclude) { - const newIndex = filteredContent.length; - const { agentId: _a, groupId: _g, ...cleanPart } = part; - filteredContent.push(cleanPart); - if (agentId && hasMultipleAgents) { - agentIdMap[newIndex] = agentId; - } - } - } - - const finalContent = - Object.keys(agentIdMap).length > 0 && agentNames - ? labelContentByAgent(filteredContent, agentIdMap, agentNames) - : filteredContent; - - return { ...message, content: finalContent }; - } catch (error) { - logger.error('[AgentClient] Error processing multi-agent message:', error); - return message; - } - }; -} - class AgentClient extends BaseClient { constructor(options = {}) { super(null, options); @@ -295,14 +120,9 @@ class AgentClient extends BaseClient { checkVisionRequest() {} getSaveOptions() { - // TODO: - // would need to be override settings; otherwise, model needs to be undefined - // model: this.override.model, - // instructions: this.override.instructions, - // additional_instructions: this.override.additional_instructions, let runOptions = {}; try { - runOptions = payloadParser(this.options); + runOptions = payloadParser(this.options) ?? 
{}; } catch (error) { logger.error( '[api/server/controllers/agents/client.js #getSaveOptions] Error parsing options', @@ -313,14 +133,14 @@ class AgentClient extends BaseClient { return removeNullishValues( Object.assign( { + spec: this.options.spec, + iconURL: this.options.iconURL, endpoint: this.options.endpoint, agent_id: this.options.agent.id, modelLabel: this.options.modelLabel, - maxContextTokens: this.options.maxContextTokens, resendFiles: this.options.resendFiles, imageDetail: this.options.imageDetail, - spec: this.options.spec, - iconURL: this.options.iconURL, + maxContextTokens: this.maxContextTokens, }, // TODO: PARSE OPTIONS BY PROVIDER, MAY CONTAIN SENSITIVE DATA runOptions, @@ -328,11 +148,13 @@ class AgentClient extends BaseClient { ); } + /** + * Returns build message options. For AgentClient, agent-specific instructions + * are retrieved directly from agent objects in buildMessages, so this returns empty. + * @returns {Object} Empty options object + */ getBuildMessagesOptions() { - return { - instructions: this.options.agent.instructions, - additional_instructions: this.options.agent.additional_instructions, - }; + return {}; } /** @@ -355,12 +177,7 @@ class AgentClient extends BaseClient { return files; } - async buildMessages( - messages, - parentMessageId, - { instructions = null, additional_instructions = null }, - opts, - ) { + async buildMessages(messages, parentMessageId, _buildOptions, opts) { /** Always pass mapMethod; getMessagesForConversation applies it only to messages with addedConvo flag */ const orderedMessages = this.constructor.getMessagesForConversation({ messages, @@ -374,11 +191,29 @@ class AgentClient extends BaseClient { /** @type {number | undefined} */ let promptTokens; - /** @type {string} */ - let systemContent = [instructions ?? '', additional_instructions ?? ''] - .filter(Boolean) - .join('\n') - .trim(); + /** + * Extract base instructions for all agents (combines instructions + additional_instructions). 
+ * This must be done before applying context to preserve the original agent configuration. + */ + const extractBaseInstructions = (agent) => { + const baseInstructions = [agent.instructions ?? '', agent.additional_instructions ?? ''] + .filter(Boolean) + .join('\n') + .trim(); + agent.instructions = baseInstructions; + return agent; + }; + + /** Collect all agents for unified processing, extracting base instructions during collection */ + const allAgents = [ + { agent: extractBaseInstructions(this.options.agent), agentId: this.options.agent.id }, + ...(this.agentConfigs?.size > 0 + ? Array.from(this.agentConfigs.entries()).map(([agentId, agent]) => ({ + agent: extractBaseInstructions(agent), + agentId, + })) + : []), + ]; if (this.options.attachments) { const attachments = await this.options.attachments; @@ -413,6 +248,7 @@ class AgentClient extends BaseClient { assistantName: this.options?.modelLabel, }); + /** For non-latest messages, prepend file context directly to message content */ if (message.fileContext && i !== orderedMessages.length - 1) { if (typeof formattedMessage.content === 'string') { formattedMessage.content = message.fileContext + '\n' + formattedMessage.content; @@ -422,8 +258,6 @@ class AgentClient extends BaseClient { ? (textPart.text = message.fileContext + '\n' + textPart.text) : formattedMessage.content.unshift({ type: 'text', text: message.fileContext }); } - } else if (message.fileContext && i === orderedMessages.length - 1) { - systemContent = [systemContent, message.fileContext].join('\n'); } const needsTokenCount = @@ -456,46 +290,35 @@ class AgentClient extends BaseClient { return formattedMessage; }); + /** + * Build shared run context - applies to ALL agents in the run. + * This includes: file context (latest message), augmented prompt (RAG), memory context. 
+ */ + const sharedRunContextParts = []; + + /** File context from the latest message (attachments) */ + const latestMessage = orderedMessages[orderedMessages.length - 1]; + if (latestMessage?.fileContext) { + sharedRunContextParts.push(latestMessage.fileContext); + } + + /** Augmented prompt from RAG/context handlers */ if (this.contextHandlers) { this.augmentedPrompt = await this.contextHandlers.createContext(); - systemContent = this.augmentedPrompt + systemContent; - } - - // Inject MCP server instructions if available - const ephemeralAgent = this.options.req.body.ephemeralAgent; - let mcpServers = []; - - // Check for ephemeral agent MCP servers - if (ephemeralAgent && ephemeralAgent.mcp && ephemeralAgent.mcp.length > 0) { - mcpServers = ephemeralAgent.mcp; - } - // Check for regular agent MCP tools - else if (this.options.agent && this.options.agent.tools) { - mcpServers = this.options.agent.tools - .filter( - (tool) => - tool instanceof DynamicStructuredTool && tool.name.includes(Constants.mcp_delimiter), - ) - .map((tool) => tool.name.split(Constants.mcp_delimiter).pop()) - .filter(Boolean); - } - - if (mcpServers.length > 0) { - try { - const mcpInstructions = await getMCPManager().formatInstructionsForContext(mcpServers); - if (mcpInstructions) { - systemContent = [systemContent, mcpInstructions].filter(Boolean).join('\n\n'); - logger.debug('[AgentClient] Injected MCP instructions for servers:', mcpServers); - } - } catch (error) { - logger.error('[AgentClient] Failed to inject MCP instructions:', error); + if (this.augmentedPrompt) { + sharedRunContextParts.push(this.augmentedPrompt); } } - if (systemContent) { - this.options.agent.instructions = systemContent; + /** Memory context (user preferences/memories) */ + const withoutKeys = await this.useMemory(); + if (withoutKeys) { + const memoryContext = `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`; + sharedRunContextParts.push(memoryContext); } + const sharedRunContext = 
sharedRunContextParts.join('\n\n'); + /** @type {Record | undefined} */ let tokenCountMap; @@ -521,14 +344,27 @@ class AgentClient extends BaseClient { opts.getReqData({ promptTokens }); } - const withoutKeys = await this.useMemory(); - if (withoutKeys) { - systemContent += `${memoryInstructions}\n\n# Existing memory about the user:\n${withoutKeys}`; - } - - if (systemContent) { - this.options.agent.instructions = systemContent; - } + /** + * Apply context to all agents. + * Each agent gets: shared run context + their own base instructions + their own MCP instructions. + * + * NOTE: This intentionally mutates agent objects in place. The agentConfigs Map + * holds references to config objects that will be passed to the graph runtime. + */ + const ephemeralAgent = this.options.req.body.ephemeralAgent; + const mcpManager = getMCPManager(); + await Promise.all( + allAgents.map(({ agent, agentId }) => + applyContextToAgent({ + agent, + agentId, + logger, + mcpManager, + sharedRunContext, + ephemeralAgent: agentId === this.options.agent.id ? ephemeralAgent : undefined, + }), + ), + ); return result; } @@ -600,6 +436,8 @@ class AgentClient extends BaseClient { agent_id: memoryConfig.agent.id, endpoint: EModelEndpoint.agents, }); + } else if (memoryConfig.agent?.id != null) { + prelimAgent = this.options.agent; } else if ( memoryConfig.agent?.id == null && memoryConfig.agent?.model != null && @@ -614,6 +452,10 @@ class AgentClient extends BaseClient { ); } + if (!prelimAgent) { + return; + } + const agent = await initializeAgent( { req: this.options.req, @@ -633,6 +475,7 @@ class AgentClient extends BaseClient { updateFilesUsage: db.updateFilesUsage, getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); @@ -945,13 +788,13 @@ class AgentClient extends BaseClient { }, user: createSafeUser(this.options.req.user), }, - recursionLimit: agentsEConfig?.recursionLimit ?? 
25, + recursionLimit: agentsEConfig?.recursionLimit ?? 50, signal: abortController.signal, streamMode: 'values', version: 'v2', }; - const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name)); + const toolSet = buildToolSet(this.options.agent); let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages( payload, this.indexTokenCountMap, @@ -1012,6 +855,7 @@ class AgentClient extends BaseClient { run = await createRun({ agents, + messages, indexTokenCountMap, runId: this.responseMessageId, signal: abortController.signal, @@ -1084,11 +928,20 @@ class AgentClient extends BaseClient { this.artifactPromises.push(...attachments); } - await this.recordCollectedUsage({ - context: 'message', - balance: balanceConfig, - transactions: transactionsConfig, - }); + /** Skip token spending if aborted - the abort handler (abortMiddleware.js) handles it + This prevents double-spending when user aborts via `/api/agents/chat/abort` */ + const wasAborted = abortController?.signal?.aborted; + if (!wasAborted) { + await this.recordCollectedUsage({ + context: 'message', + balance: balanceConfig, + transactions: transactionsConfig, + }); + } else { + logger.debug( + '[api/server/controllers/agents/client.js #chatCompletion] Skipping token spending - handled by abort middleware', + ); + } } catch (err) { logger.error( '[api/server/controllers/agents/client.js #chatCompletion] Error in cleanup phase', diff --git a/api/server/controllers/agents/client.test.js b/api/server/controllers/agents/client.test.js index 14f0df9bb0..9dd3567047 100644 --- a/api/server/controllers/agents/client.test.js +++ b/api/server/controllers/agents/client.test.js @@ -12,6 +12,17 @@ jest.mock('@librechat/agents', () => ({ jest.mock('@librechat/api', () => ({ ...jest.requireActual('@librechat/api'), + checkAccess: jest.fn(), + initializeAgent: jest.fn(), + createMemoryProcessor: jest.fn(), +})); + +jest.mock('~/models/Agent', () => ({ + loadAgent: jest.fn(), +})); + 
+jest.mock('~/models/Role', () => ({ + getRoleByName: jest.fn(), })); // Mock getMCPManager @@ -1310,8 +1321,8 @@ describe('AgentClient - titleConvo', () => { expect(client.options.agent.instructions).toContain('# MCP Server Instructions'); expect(client.options.agent.instructions).toContain('Use these tools carefully'); - // Verify the base instructions are also included - expect(client.options.agent.instructions).toContain('Base instructions'); + // Verify the base instructions are also included (from agent config, not buildOptions) + expect(client.options.agent.instructions).toContain('Base agent instructions'); }); it('should handle MCP instructions with ephemeral agent', async () => { @@ -1373,8 +1384,8 @@ describe('AgentClient - titleConvo', () => { additional_instructions: null, }); - // Verify the instructions still work without MCP content - expect(client.options.agent.instructions).toBe('Base instructions only'); + // Verify the instructions still work without MCP content (from agent config, not buildOptions) + expect(client.options.agent.instructions).toBe('Base agent instructions'); expect(client.options.agent.instructions).not.toContain('[object Promise]'); }); @@ -1398,8 +1409,8 @@ describe('AgentClient - titleConvo', () => { additional_instructions: null, }); - // Should still have base instructions without MCP content - expect(client.options.agent.instructions).toContain('Base instructions'); + // Should still have base instructions without MCP content (from agent config, not buildOptions) + expect(client.options.agent.instructions).toContain('Base agent instructions'); expect(client.options.agent.instructions).not.toContain('[object Promise]'); }); }); @@ -1849,4 +1860,400 @@ describe('AgentClient - titleConvo', () => { }); }); }); + + describe('buildMessages - memory context for parallel agents', () => { + let client; + let mockReq; + let mockRes; + let mockAgent; + let mockOptions; + + beforeEach(() => { + jest.clearAllMocks(); + + mockAgent = { 
+ id: 'primary-agent', + name: 'Primary Agent', + endpoint: EModelEndpoint.openAI, + provider: EModelEndpoint.openAI, + instructions: 'Primary agent instructions', + model_parameters: { + model: 'gpt-4', + }, + tools: [], + }; + + mockReq = { + user: { + id: 'user-123', + personalization: { + memories: true, + }, + }, + body: { + endpoint: EModelEndpoint.openAI, + }, + config: { + memory: { + disabled: false, + }, + }, + }; + + mockRes = {}; + + mockOptions = { + req: mockReq, + res: mockRes, + agent: mockAgent, + endpoint: EModelEndpoint.agents, + }; + + client = new AgentClient(mockOptions); + client.conversationId = 'convo-123'; + client.responseMessageId = 'response-123'; + client.shouldSummarize = false; + client.maxContextTokens = 4096; + }); + + it('should pass memory context to parallel agents (addedConvo)', async () => { + const memoryContent = 'User prefers dark mode. User is a software developer.'; + client.useMemory = jest.fn().mockResolvedValue(memoryContent); + + const parallelAgent1 = { + id: 'parallel-agent-1', + name: 'Parallel Agent 1', + instructions: 'Parallel agent 1 instructions', + provider: EModelEndpoint.openAI, + }; + + const parallelAgent2 = { + id: 'parallel-agent-2', + name: 'Parallel Agent 2', + instructions: 'Parallel agent 2 instructions', + provider: EModelEndpoint.anthropic, + }; + + client.agentConfigs = new Map([ + ['parallel-agent-1', parallelAgent1], + ['parallel-agent-2', parallelAgent2], + ]); + + const messages = [ + { + messageId: 'msg-1', + parentMessageId: null, + sender: 'User', + text: 'Hello', + isCreatedByUser: true, + }, + ]; + + await client.buildMessages(messages, null, { + instructions: 'Base instructions', + additional_instructions: null, + }); + + expect(client.useMemory).toHaveBeenCalled(); + + // Verify primary agent has its configured instructions (not from buildOptions) and memory context + expect(client.options.agent.instructions).toContain('Primary agent instructions'); + 
expect(client.options.agent.instructions).toContain(memoryContent); + + expect(parallelAgent1.instructions).toContain('Parallel agent 1 instructions'); + expect(parallelAgent1.instructions).toContain(memoryContent); + + expect(parallelAgent2.instructions).toContain('Parallel agent 2 instructions'); + expect(parallelAgent2.instructions).toContain(memoryContent); + }); + + it('should not modify parallel agents when no memory context is available', async () => { + client.useMemory = jest.fn().mockResolvedValue(undefined); + + const parallelAgent = { + id: 'parallel-agent-1', + name: 'Parallel Agent 1', + instructions: 'Original parallel instructions', + provider: EModelEndpoint.openAI, + }; + + client.agentConfigs = new Map([['parallel-agent-1', parallelAgent]]); + + const messages = [ + { + messageId: 'msg-1', + parentMessageId: null, + sender: 'User', + text: 'Hello', + isCreatedByUser: true, + }, + ]; + + await client.buildMessages(messages, null, { + instructions: 'Base instructions', + additional_instructions: null, + }); + + expect(parallelAgent.instructions).toBe('Original parallel instructions'); + }); + + it('should handle parallel agents without existing instructions', async () => { + const memoryContent = 'User is a data scientist.'; + client.useMemory = jest.fn().mockResolvedValue(memoryContent); + + const parallelAgentNoInstructions = { + id: 'parallel-agent-no-instructions', + name: 'Parallel Agent No Instructions', + provider: EModelEndpoint.openAI, + }; + + client.agentConfigs = new Map([ + ['parallel-agent-no-instructions', parallelAgentNoInstructions], + ]); + + const messages = [ + { + messageId: 'msg-1', + parentMessageId: null, + sender: 'User', + text: 'Hello', + isCreatedByUser: true, + }, + ]; + + await client.buildMessages(messages, null, { + instructions: null, + additional_instructions: null, + }); + + expect(parallelAgentNoInstructions.instructions).toContain(memoryContent); + }); + + it('should not modify agentConfigs when none exist', 
async () => { + const memoryContent = 'User prefers concise responses.'; + client.useMemory = jest.fn().mockResolvedValue(memoryContent); + + client.agentConfigs = null; + + const messages = [ + { + messageId: 'msg-1', + parentMessageId: null, + sender: 'User', + text: 'Hello', + isCreatedByUser: true, + }, + ]; + + await expect( + client.buildMessages(messages, null, { + instructions: 'Base instructions', + additional_instructions: null, + }), + ).resolves.not.toThrow(); + + expect(client.options.agent.instructions).toContain(memoryContent); + }); + + it('should handle empty agentConfigs map', async () => { + const memoryContent = 'User likes detailed explanations.'; + client.useMemory = jest.fn().mockResolvedValue(memoryContent); + + client.agentConfigs = new Map(); + + const messages = [ + { + messageId: 'msg-1', + parentMessageId: null, + sender: 'User', + text: 'Hello', + isCreatedByUser: true, + }, + ]; + + await expect( + client.buildMessages(messages, null, { + instructions: 'Base instructions', + additional_instructions: null, + }), + ).resolves.not.toThrow(); + + expect(client.options.agent.instructions).toContain(memoryContent); + }); + }); + + describe('useMemory method - prelimAgent assignment', () => { + let client; + let mockReq; + let mockRes; + let mockAgent; + let mockOptions; + let mockCheckAccess; + let mockLoadAgent; + let mockInitializeAgent; + let mockCreateMemoryProcessor; + + beforeEach(() => { + jest.clearAllMocks(); + + mockAgent = { + id: 'agent-123', + endpoint: EModelEndpoint.openAI, + provider: EModelEndpoint.openAI, + instructions: 'Test instructions', + model: 'gpt-4', + model_parameters: { + model: 'gpt-4', + }, + }; + + mockReq = { + user: { + id: 'user-123', + personalization: { + memories: true, + }, + }, + config: { + memory: { + agent: { + id: 'agent-123', + }, + }, + endpoints: { + [EModelEndpoint.agents]: { + allowedProviders: [EModelEndpoint.openAI], + }, + }, + }, + }; + + mockRes = {}; + + mockOptions = { + req: mockReq, 
+ res: mockRes, + agent: mockAgent, + }; + + mockCheckAccess = require('@librechat/api').checkAccess; + mockLoadAgent = require('~/models/Agent').loadAgent; + mockInitializeAgent = require('@librechat/api').initializeAgent; + mockCreateMemoryProcessor = require('@librechat/api').createMemoryProcessor; + }); + + it('should use current agent when memory config agent.id matches current agent id', async () => { + mockCheckAccess.mockResolvedValue(true); + mockInitializeAgent.mockResolvedValue({ + ...mockAgent, + provider: EModelEndpoint.openAI, + }); + mockCreateMemoryProcessor.mockResolvedValue([undefined, jest.fn()]); + + client = new AgentClient(mockOptions); + client.conversationId = 'convo-123'; + client.responseMessageId = 'response-123'; + + await client.useMemory(); + + expect(mockLoadAgent).not.toHaveBeenCalled(); + expect(mockInitializeAgent).toHaveBeenCalledWith( + expect.objectContaining({ + agent: mockAgent, + }), + expect.any(Object), + ); + }); + + it('should load different agent when memory config agent.id differs from current agent id', async () => { + const differentAgentId = 'different-agent-456'; + const differentAgent = { + id: differentAgentId, + provider: EModelEndpoint.openAI, + model: 'gpt-4', + instructions: 'Different agent instructions', + }; + + mockReq.config.memory.agent.id = differentAgentId; + + mockCheckAccess.mockResolvedValue(true); + mockLoadAgent.mockResolvedValue(differentAgent); + mockInitializeAgent.mockResolvedValue({ + ...differentAgent, + provider: EModelEndpoint.openAI, + }); + mockCreateMemoryProcessor.mockResolvedValue([undefined, jest.fn()]); + + client = new AgentClient(mockOptions); + client.conversationId = 'convo-123'; + client.responseMessageId = 'response-123'; + + await client.useMemory(); + + expect(mockLoadAgent).toHaveBeenCalledWith( + expect.objectContaining({ + agent_id: differentAgentId, + }), + ); + expect(mockInitializeAgent).toHaveBeenCalledWith( + expect.objectContaining({ + agent: differentAgent, + }), + 
expect.any(Object), + ); + }); + + it('should return early when prelimAgent is undefined (no valid memory agent config)', async () => { + mockReq.config.memory = { + agent: {}, + }; + + mockCheckAccess.mockResolvedValue(true); + + client = new AgentClient(mockOptions); + client.conversationId = 'convo-123'; + client.responseMessageId = 'response-123'; + + const result = await client.useMemory(); + + expect(result).toBeUndefined(); + expect(mockInitializeAgent).not.toHaveBeenCalled(); + expect(mockCreateMemoryProcessor).not.toHaveBeenCalled(); + }); + + it('should create ephemeral agent when no id but model and provider are specified', async () => { + mockReq.config.memory = { + agent: { + model: 'gpt-4', + provider: EModelEndpoint.openAI, + }, + }; + + mockCheckAccess.mockResolvedValue(true); + mockInitializeAgent.mockResolvedValue({ + id: Constants.EPHEMERAL_AGENT_ID, + model: 'gpt-4', + provider: EModelEndpoint.openAI, + }); + mockCreateMemoryProcessor.mockResolvedValue([undefined, jest.fn()]); + + client = new AgentClient(mockOptions); + client.conversationId = 'convo-123'; + client.responseMessageId = 'response-123'; + + await client.useMemory(); + + expect(mockLoadAgent).not.toHaveBeenCalled(); + expect(mockInitializeAgent).toHaveBeenCalledWith( + expect.objectContaining({ + agent: expect.objectContaining({ + id: Constants.EPHEMERAL_AGENT_ID, + model: 'gpt-4', + provider: EModelEndpoint.openAI, + }), + }), + expect.any(Object), + ); + }); + }); }); diff --git a/api/server/controllers/agents/openai.js b/api/server/controllers/agents/openai.js new file mode 100644 index 0000000000..b334580eb1 --- /dev/null +++ b/api/server/controllers/agents/openai.js @@ -0,0 +1,701 @@ +const { nanoid } = require('nanoid'); +const { logger } = require('@librechat/data-schemas'); +const { Callback, ToolEndHandler, formatAgentMessages } = require('@librechat/agents'); +const { EModelEndpoint, ResourceType, PermissionBits } = require('librechat-data-provider'); +const { + writeSSE, 
+ createRun, + createChunk, + buildToolSet, + sendFinalChunk, + createSafeUser, + validateRequest, + initializeAgent, + getBalanceConfig, + createErrorResponse, + recordCollectedUsage, + getTransactionsConfig, + createToolExecuteHandler, + buildNonStreamingResponse, + createOpenAIStreamTracker, + createOpenAIContentAggregator, + isChatCompletionValidationFailure, +} = require('@librechat/api'); +const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); +const { createToolEndCallback } = require('~/server/controllers/agents/callbacks'); +const { findAccessibleResources } = require('~/server/services/PermissionService'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { getConvoFiles } = require('~/models/Conversation'); +const { getAgent, getAgents } = require('~/models/Agent'); +const db = require('~/models'); + +/** + * Creates a tool loader function for the agent. + * @param {AbortSignal} signal - The abort signal + * @param {boolean} [definitionsOnly=true] - When true, returns only serializable + * tool definitions without creating full tool instances (for event-driven mode) + */ +function createToolLoader(signal, definitionsOnly = true) { + return async function loadTools({ + req, + res, + tools, + model, + agentId, + provider, + tool_options, + tool_resources, + }) { + const agent = { id: agentId, tools, provider, model, tool_options }; + try { + return await loadAgentTools({ + req, + res, + agent, + signal, + tool_resources, + definitionsOnly, + streamId: null, // No resumable stream for OpenAI compat + }); + } catch (error) { + logger.error('Error loading tools for agent ' + agentId, error); + } + }; +} + +/** + * Convert content part to internal format + * @param {Object} part - Content part + * @returns {Object} Converted part + */ +function convertContentPart(part) { + if (part.type === 'text') { + return { type: 'text', text: part.text }; + } + if (part.type === 'image_url') { + 
return { type: 'image_url', image_url: part.image_url }; + } + return part; +} + +/** + * Convert OpenAI messages to internal format + * @param {Array} messages - OpenAI format messages + * @returns {Array} Internal format messages + */ +function convertMessages(messages) { + return messages.map((msg) => { + let content; + if (typeof msg.content === 'string') { + content = msg.content; + } else if (msg.content) { + content = msg.content.map(convertContentPart); + } else { + content = ''; + } + + return { + role: msg.role, + content, + ...(msg.name && { name: msg.name }), + ...(msg.tool_calls && { tool_calls: msg.tool_calls }), + ...(msg.tool_call_id && { tool_call_id: msg.tool_call_id }), + }; + }); +} + +/** + * Send an error response in OpenAI format + */ +function sendErrorResponse(res, statusCode, message, type = 'invalid_request_error', code = null) { + res.status(statusCode).json(createErrorResponse(message, type, code)); +} + +/** + * OpenAI-compatible chat completions controller for agents. + * + * POST /v1/chat/completions + * + * Request format: + * { + * "model": "agent_id_here", + * "messages": [{"role": "user", "content": "Hello!"}], + * "stream": true, + * "conversation_id": "optional", + * "parent_message_id": "optional" + * } + */ +const OpenAIChatCompletionController = async (req, res) => { + const appConfig = req.config; + const requestStartTime = Date.now(); + + // Validate request + const validation = validateRequest(req.body); + if (isChatCompletionValidationFailure(validation)) { + return sendErrorResponse(res, 400, validation.error); + } + + const request = validation.request; + const agentId = request.model; + + // Look up the agent + const agent = await getAgent({ id: agentId }); + if (!agent) { + return sendErrorResponse( + res, + 404, + `Agent not found: ${agentId}`, + 'invalid_request_error', + 'model_not_found', + ); + } + + // Generate IDs + const requestId = `chatcmpl-${nanoid()}`; + const conversationId = request.conversation_id ?? 
nanoid(); + const parentMessageId = request.parent_message_id ?? null; + const created = Math.floor(Date.now() / 1000); + + const context = { + created, + requestId, + model: agentId, + }; + + logger.debug( + `[OpenAI API] Request ${requestId} started for agent ${agentId}, stream: ${request.stream}`, + ); + + // Set up abort controller + const abortController = new AbortController(); + + // Handle client disconnect + req.on('close', () => { + if (!abortController.signal.aborted) { + abortController.abort(); + logger.debug('[OpenAI API] Client disconnected, aborting'); + } + }); + + try { + // Build allowed providers set + const allowedProviders = new Set( + appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders, + ); + + // Create tool loader + const loadTools = createToolLoader(abortController.signal); + + // Initialize the agent first to check for disableStreaming + const endpointOption = { + endpoint: agent.provider, + model_parameters: agent.model_parameters ?? {}, + }; + + const primaryConfig = await initializeAgent( + { + req, + res, + loadTools, + requestFiles: [], + conversationId, + parentMessageId, + agent, + endpointOption, + allowedProviders, + isInitialAgent: true, + }, + { + getConvoFiles, + getFiles: db.getFiles, + getUserKey: db.getUserKey, + getMessages: db.getMessages, + updateFilesUsage: db.updateFilesUsage, + getUserKeyValues: db.getUserKeyValues, + getUserCodeFiles: db.getUserCodeFiles, + getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, + }, + ); + + // Determine if streaming is enabled (check both request and agent config) + const streamingDisabled = !!primaryConfig.model_parameters?.disableStreaming; + const isStreaming = request.stream === true && !streamingDisabled; + + // Create tracker for streaming or aggregator for non-streaming + const tracker = isStreaming ? createOpenAIStreamTracker() : null; + const aggregator = isStreaming ? 
null : createOpenAIContentAggregator(); + + // Set up response for streaming + if (isStreaming) { + res.setHeader('Content-Type', 'text/event-stream'); + res.setHeader('Cache-Control', 'no-cache'); + res.setHeader('Connection', 'keep-alive'); + res.setHeader('X-Accel-Buffering', 'no'); + res.flushHeaders(); + + // Send initial chunk with role + const initialChunk = createChunk(context, { role: 'assistant' }); + writeSSE(res, initialChunk); + } + + // Create handler config for OpenAI streaming (only used when streaming) + const handlerConfig = isStreaming + ? { + res, + context, + tracker, + } + : null; + + const collectedUsage = []; + /** @type {Promise[]} */ + const artifactPromises = []; + + const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId: null }); + + const toolExecuteOptions = { + loadTools: async (toolNames) => { + return loadToolsForExecution({ + req, + res, + agent, + toolNames, + signal: abortController.signal, + toolRegistry: primaryConfig.toolRegistry, + userMCPAuthMap: primaryConfig.userMCPAuthMap, + tool_resources: primaryConfig.tool_resources, + }); + }, + toolEndCallback, + }; + + const openaiMessages = convertMessages(request.messages); + + const toolSet = buildToolSet(primaryConfig); + const { messages: formattedMessages, indexTokenCountMap } = formatAgentMessages( + openaiMessages, + {}, + toolSet, + ); + + /** + * Create a simple handler that processes data + */ + const createHandler = (processor) => ({ + handle: (_event, data) => { + if (processor) { + processor(data); + } + }, + }); + + /** + * Stream text content in OpenAI format + */ + const streamText = (text) => { + if (!text) { + return; + } + if (isStreaming) { + tracker.addText(); + writeSSE(res, createChunk(context, { content: text })); + } else { + aggregator.addText(text); + } + }; + + /** + * Stream reasoning content in OpenAI format (OpenRouter convention) + */ + const streamReasoning = (text) => { + if (!text) { + return; + } + if (isStreaming) { 
+ tracker.addReasoning(); + writeSSE(res, createChunk(context, { reasoning: text })); + } else { + aggregator.addReasoning(text); + } + }; + + // Event handlers for OpenAI-compatible streaming + const handlers = { + // Text content streaming + on_message_delta: createHandler((data) => { + const content = data?.delta?.content; + if (Array.isArray(content)) { + for (const part of content) { + if (part.type === 'text' && part.text) { + streamText(part.text); + } + } + } + }), + + // Reasoning/thinking content streaming + on_reasoning_delta: createHandler((data) => { + const content = data?.delta?.content; + if (Array.isArray(content)) { + for (const part of content) { + const text = part.think || part.text; + if (text) { + streamReasoning(text); + } + } + } + }), + + // Tool call initiation - streams id and name (from on_run_step) + on_run_step: createHandler((data) => { + const stepDetails = data?.stepDetails; + if (stepDetails?.type === 'tool_calls' && stepDetails.tool_calls) { + for (const tc of stepDetails.tool_calls) { + const toolIndex = data.index ?? 0; + const toolId = tc.id ?? ''; + const toolName = tc.name ?? ''; + const toolCall = { + id: toolId, + type: 'function', + function: { name: toolName, arguments: '' }, + }; + + // Track tool call in tracker or aggregator + if (isStreaming) { + if (!tracker.toolCalls.has(toolIndex)) { + tracker.toolCalls.set(toolIndex, toolCall); + } + // Stream initial tool call chunk (like OpenAI does) + writeSSE( + res, + createChunk(context, { + tool_calls: [{ index: toolIndex, ...toolCall }], + }), + ); + } else { + if (!aggregator.toolCalls.has(toolIndex)) { + aggregator.toolCalls.set(toolIndex, toolCall); + } + } + } + } + }), + + // Tool call argument streaming (from on_run_step_delta) + on_run_step_delta: createHandler((data) => { + const delta = data?.delta; + if (delta?.type === 'tool_calls' && delta.tool_calls) { + for (const tc of delta.tool_calls) { + const args = tc.args ?? 
''; + if (!args) { + continue; + } + + const toolIndex = tc.index ?? 0; + + // Update tool call arguments + const targetMap = isStreaming ? tracker.toolCalls : aggregator.toolCalls; + const tracked = targetMap.get(toolIndex); + if (tracked) { + tracked.function.arguments += args; + } + + // Stream argument delta (only for streaming) + if (isStreaming) { + writeSSE( + res, + createChunk(context, { + tool_calls: [ + { + index: toolIndex, + function: { arguments: args }, + }, + ], + }), + ); + } + } + } + }), + + // Usage tracking + on_chat_model_end: createHandler((data) => { + const usage = data?.output?.usage_metadata; + if (usage) { + collectedUsage.push(usage); + const target = isStreaming ? tracker : aggregator; + target.usage.promptTokens += usage.input_tokens ?? 0; + target.usage.completionTokens += usage.output_tokens ?? 0; + } + }), + on_run_step_completed: createHandler(), + // Use proper ToolEndHandler for processing artifacts (images, file citations, code output) + on_tool_end: new ToolEndHandler(toolEndCallback, logger), + on_chain_stream: createHandler(), + on_chain_end: createHandler(), + on_agent_update: createHandler(), + on_custom_event: createHandler(), + // Event-driven tool execution handler + on_tool_execute: createToolExecuteHandler(toolExecuteOptions), + }; + + // Create and run the agent + const userId = req.user?.id ?? 
'api-user'; + + // Extract userMCPAuthMap from primaryConfig (needed for MCP tool connections) + const userMCPAuthMap = primaryConfig.userMCPAuthMap; + + const run = await createRun({ + agents: [primaryConfig], + messages: formattedMessages, + indexTokenCountMap, + runId: requestId, + signal: abortController.signal, + customHandlers: handlers, + requestBody: { + messageId: requestId, + conversationId, + }, + user: { id: userId }, + }); + + if (!run) { + throw new Error('Failed to create agent run'); + } + + // Process the stream + const config = { + runName: 'AgentRun', + configurable: { + thread_id: conversationId, + user_id: userId, + user: createSafeUser(req.user), + ...(userMCPAuthMap != null && { userMCPAuthMap }), + }, + signal: abortController.signal, + streamMode: 'values', + version: 'v2', + }; + + await run.processStream({ messages: formattedMessages }, config, { + callbacks: { + [Callback.TOOL_ERROR]: (graph, error, toolId) => { + logger.error(`[OpenAI API] Tool Error "${toolId}"`, error); + }, + }, + }); + + // Record token usage against balance + const balanceConfig = getBalanceConfig(appConfig); + const transactionsConfig = getTransactionsConfig(appConfig); + recordCollectedUsage( + { spendTokens, spendStructuredTokens }, + { + user: userId, + conversationId, + collectedUsage, + context: 'message', + balance: balanceConfig, + transactions: transactionsConfig, + model: primaryConfig.model || agent.model_parameters?.model, + }, + ).catch((err) => { + logger.error('[OpenAI API] Error recording usage:', err); + }); + + // Finalize response + const duration = Date.now() - requestStartTime; + if (isStreaming) { + sendFinalChunk(handlerConfig); + res.end(); + logger.debug(`[OpenAI API] Request ${requestId} completed in ${duration}ms (streaming)`); + + // Wait for artifact processing after response ends (non-blocking) + if (artifactPromises.length > 0) { + Promise.all(artifactPromises).catch((artifactError) => { + logger.warn('[OpenAI API] Error processing 
artifacts:', artifactError); + }); + } + } else { + // For non-streaming, wait for artifacts before sending response + if (artifactPromises.length > 0) { + try { + await Promise.all(artifactPromises); + } catch (artifactError) { + logger.warn('[OpenAI API] Error processing artifacts:', artifactError); + } + } + + // Build usage from aggregated data + const usage = { + prompt_tokens: aggregator.usage.promptTokens, + completion_tokens: aggregator.usage.completionTokens, + total_tokens: aggregator.usage.promptTokens + aggregator.usage.completionTokens, + }; + + if (aggregator.usage.reasoningTokens > 0) { + usage.completion_tokens_details = { + reasoning_tokens: aggregator.usage.reasoningTokens, + }; + } + + const response = buildNonStreamingResponse( + context, + aggregator.getText(), + aggregator.getReasoning(), + aggregator.toolCalls, + usage, + ); + res.json(response); + logger.debug(`[OpenAI API] Request ${requestId} completed in ${duration}ms (non-streaming)`); + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'An error occurred'; + logger.error('[OpenAI API] Error:', error); + + // Check if we already started streaming (headers sent) + if (res.headersSent) { + // Headers already sent, send error in stream + const errorChunk = createChunk(context, { content: `\n\nError: ${errorMessage}` }, 'stop'); + writeSSE(res, errorChunk); + writeSSE(res, '[DONE]'); + res.end(); + } else { + // Forward upstream provider status codes (e.g., Anthropic 400s) instead of masking as 500 + const statusCode = + typeof error?.status === 'number' && error.status >= 400 && error.status < 600 + ? error.status + : 500; + const errorType = + statusCode >= 400 && statusCode < 500 ? 
'invalid_request_error' : 'server_error'; + sendErrorResponse(res, statusCode, errorMessage, errorType); + } + } +}; + +/** + * List available agents as models (filtered by remote access permissions) + * + * GET /v1/models + */ +const ListModelsController = async (req, res) => { + try { + const userId = req.user?.id; + const userRole = req.user?.role; + + if (!userId) { + return sendErrorResponse(res, 401, 'Authentication required', 'auth_error'); + } + + // Find agents the user has remote access to (VIEW permission on REMOTE_AGENT) + const accessibleAgentIds = await findAccessibleResources({ + userId, + role: userRole, + resourceType: ResourceType.REMOTE_AGENT, + requiredPermissions: PermissionBits.VIEW, + }); + + // Get the accessible agents + let agents = []; + if (accessibleAgentIds.length > 0) { + agents = await getAgents({ _id: { $in: accessibleAgentIds } }); + } + + const models = agents.map((agent) => ({ + id: agent.id, + object: 'model', + created: Math.floor(new Date(agent.createdAt || Date.now()).getTime() / 1000), + owned_by: 'librechat', + permission: [], + root: agent.id, + parent: null, + // LibreChat extensions + name: agent.name, + description: agent.description, + provider: agent.provider, + })); + + res.json({ + object: 'list', + data: models, + }); + } catch (error) { + const errorMessage = error instanceof Error ? 
error.message : 'Failed to list models'; + logger.error('[OpenAI API] Error listing models:', error); + sendErrorResponse(res, 500, errorMessage, 'server_error'); + } +}; + +/** + * Get a specific model/agent (with remote access permission check) + * + * GET /v1/models/:model + */ +const GetModelController = async (req, res) => { + try { + const { model } = req.params; + const userId = req.user?.id; + const userRole = req.user?.role; + + if (!userId) { + return sendErrorResponse(res, 401, 'Authentication required', 'auth_error'); + } + + const agent = await getAgent({ id: model }); + + if (!agent) { + return sendErrorResponse( + res, + 404, + `Model not found: ${model}`, + 'invalid_request_error', + 'model_not_found', + ); + } + + // Check if user has remote access to this agent + const accessibleAgentIds = await findAccessibleResources({ + userId, + role: userRole, + resourceType: ResourceType.REMOTE_AGENT, + requiredPermissions: PermissionBits.VIEW, + }); + + const hasAccess = accessibleAgentIds.some((id) => id.toString() === agent._id.toString()); + + if (!hasAccess) { + return sendErrorResponse( + res, + 403, + `No remote access to model: ${model}`, + 'permission_error', + 'access_denied', + ); + } + + res.json({ + id: agent.id, + object: 'model', + created: Math.floor(new Date(agent.createdAt || Date.now()).getTime() / 1000), + owned_by: 'librechat', + permission: [], + root: agent.id, + parent: null, + // LibreChat extensions + name: agent.name, + description: agent.description, + provider: agent.provider, + }); + } catch (error) { + const errorMessage = error instanceof Error ? 
error.message : 'Failed to get model'; + logger.error('[OpenAI API] Error getting model:', error); + sendErrorResponse(res, 500, errorMessage, 'server_error'); + } +}; + +module.exports = { + OpenAIChatCompletionController, + ListModelsController, + GetModelController, +}; diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index cf706ef89c..79387b6e89 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -67,7 +67,15 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit let client = null; try { + logger.debug(`[ResumableAgentController] Creating job`, { + streamId, + conversationId, + reqConversationId, + userId, + }); + const job = await GenerationJobManager.createJob(streamId, userId, conversationId); + const jobCreatedAt = job.createdAt; // Capture creation time to detect job replacement req._resumableStreamId = streamId; // Send JSON response IMMEDIATELY so client can connect to SSE stream @@ -272,6 +280,33 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit }); } + // CRITICAL: Save response message BEFORE emitting final event. + // This prevents race conditions where the client sends a follow-up message + // before the response is saved to the database, causing orphaned parentMessageIds. 
+ if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) { + await saveMessage( + req, + { ...response, user: userId, unfinished: wasAbortedBeforeComplete }, + { context: 'api/server/controllers/agents/request.js - resumable response end' }, + ); + } + + // Check if our job was replaced by a new request before emitting + // This prevents stale requests from emitting events to newer jobs + const currentJob = await GenerationJobManager.getJob(streamId); + const jobWasReplaced = !currentJob || currentJob.createdAt !== jobCreatedAt; + + if (jobWasReplaced) { + logger.debug(`[ResumableAgentController] Skipping FINAL emit - job was replaced`, { + streamId, + originalCreatedAt: jobCreatedAt, + currentCreatedAt: currentJob?.createdAt, + }); + // Still decrement pending request since we incremented at start + await decrementPendingRequest(userId); + return; + } + if (!wasAbortedBeforeComplete) { const finalEvent = { final: true, @@ -281,27 +316,35 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit responseMessage: { ...response }, }; - GenerationJobManager.emitDone(streamId, finalEvent); + logger.debug(`[ResumableAgentController] Emitting FINAL event`, { + streamId, + wasAbortedBeforeComplete, + userMessageId: userMessage?.messageId, + responseMessageId: response?.messageId, + conversationId: conversation?.conversationId, + }); + + await GenerationJobManager.emitDone(streamId, finalEvent); GenerationJobManager.completeJob(streamId); await decrementPendingRequest(userId); - - if (client.savedMessageIds && !client.savedMessageIds.has(messageId)) { - await saveMessage( - req, - { ...response, user: userId }, - { context: 'api/server/controllers/agents/request.js - resumable response end' }, - ); - } } else { const finalEvent = { final: true, conversation, title: conversation.title, requestMessage: sanitizeMessageForTransmit(userMessage), - responseMessage: { ...response, error: true }, - error: { message: 'Request was aborted' }, 
+ responseMessage: { ...response, unfinished: true }, }; - GenerationJobManager.emitDone(streamId, finalEvent); + + logger.debug(`[ResumableAgentController] Emitting ABORTED FINAL event`, { + streamId, + wasAbortedBeforeComplete, + userMessageId: userMessage?.messageId, + responseMessageId: response?.messageId, + conversationId: conversation?.conversationId, + }); + + await GenerationJobManager.emitDone(streamId, finalEvent); GenerationJobManager.completeJob(streamId, 'Request aborted'); await decrementPendingRequest(userId); } @@ -334,7 +377,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit // abortJob already handled emitDone and completeJob } else { logger.error(`[ResumableAgentController] Generation error for ${streamId}:`, error); - GenerationJobManager.emitError(streamId, error.message || 'Generation failed'); + await GenerationJobManager.emitError(streamId, error.message || 'Generation failed'); GenerationJobManager.completeJob(streamId, error.message); } @@ -363,7 +406,7 @@ const ResumableAgentController = async (req, res, next, initializeClient, addTit res.status(500).json({ error: error.message || 'Failed to start generation' }); } else { // JSON already sent, emit error to stream so client can receive it - GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation'); + await GenerationJobManager.emitError(streamId, error.message || 'Failed to start generation'); } GenerationJobManager.completeJob(streamId, error.message); await decrementPendingRequest(userId); diff --git a/api/server/controllers/agents/responses.js b/api/server/controllers/agents/responses.js new file mode 100644 index 0000000000..afdb96be9f --- /dev/null +++ b/api/server/controllers/agents/responses.js @@ -0,0 +1,889 @@ +const { nanoid } = require('nanoid'); +const { v4: uuidv4 } = require('uuid'); +const { logger } = require('@librechat/data-schemas'); +const { Callback, ToolEndHandler, formatAgentMessages } = 
require('@librechat/agents'); +const { EModelEndpoint, ResourceType, PermissionBits } = require('librechat-data-provider'); +const { + createRun, + buildToolSet, + createSafeUser, + initializeAgent, + getBalanceConfig, + recordCollectedUsage, + getTransactionsConfig, + createToolExecuteHandler, + // Responses API + writeDone, + buildResponse, + generateResponseId, + isValidationFailure, + emitResponseCreated, + createResponseContext, + createResponseTracker, + setupStreamingResponse, + emitResponseInProgress, + convertInputToMessages, + validateResponseRequest, + buildAggregatedResponse, + createResponseAggregator, + sendResponsesErrorResponse, + createResponsesEventHandlers, + createAggregatorEventHandlers, +} = require('@librechat/api'); +const { + createResponsesToolEndCallback, + createToolEndCallback, +} = require('~/server/controllers/agents/callbacks'); +const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); +const { findAccessibleResources } = require('~/server/services/PermissionService'); +const { getConvoFiles, saveConvo, getConvo } = require('~/models/Conversation'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); +const { getAgent, getAgents } = require('~/models/Agent'); +const db = require('~/models'); + +/** @type {import('@librechat/api').AppConfig | null} */ +let appConfig = null; + +/** + * Set the app config for the controller + * @param {import('@librechat/api').AppConfig} config + */ +function setAppConfig(config) { + appConfig = config; +} + +/** + * Creates a tool loader function for the agent. 
+ * @param {AbortSignal} signal - The abort signal + * @param {boolean} [definitionsOnly=true] - When true, returns only serializable + * tool definitions without creating full tool instances (for event-driven mode) + */ +function createToolLoader(signal, definitionsOnly = true) { + return async function loadTools({ + req, + res, + tools, + model, + agentId, + provider, + tool_options, + tool_resources, + }) { + const agent = { id: agentId, tools, provider, model, tool_options }; + try { + return await loadAgentTools({ + req, + res, + agent, + signal, + tool_resources, + definitionsOnly, + streamId: null, + }); + } catch (error) { + logger.error('Error loading tools for agent ' + agentId, error); + } + }; +} + +/** + * Convert Open Responses input items to internal messages + * @param {import('@librechat/api').InputItem[]} input + * @returns {Array} Internal messages + */ +function convertToInternalMessages(input) { + return convertInputToMessages(input); +} + +/** + * Load messages from a previous response/conversation + * @param {string} conversationId - The conversation/response ID + * @param {string} userId - The user ID + * @returns {Promise} Messages from the conversation + */ +async function loadPreviousMessages(conversationId, userId) { + try { + const messages = await db.getMessages({ conversationId, user: userId }); + if (!messages || messages.length === 0) { + return []; + } + + // Convert stored messages to internal format + return messages.map((msg) => { + const internalMsg = { + role: msg.isCreatedByUser ? 
'user' : 'assistant', + content: '', + messageId: msg.messageId, + }; + + // Handle content - could be string or array + if (typeof msg.text === 'string') { + internalMsg.content = msg.text; + } else if (Array.isArray(msg.content)) { + // Handle content parts + internalMsg.content = msg.content; + } else if (msg.text) { + internalMsg.content = String(msg.text); + } + + return internalMsg; + }); + } catch (error) { + logger.error('[Responses API] Error loading previous messages:', error); + return []; + } +} + +/** + * Save input messages to database + * @param {import('express').Request} req + * @param {string} conversationId + * @param {Array} inputMessages - Internal format messages + * @param {string} agentId + * @returns {Promise} + */ +async function saveInputMessages(req, conversationId, inputMessages, agentId) { + for (const msg of inputMessages) { + if (msg.role === 'user') { + await db.saveMessage( + req, + { + messageId: msg.messageId || nanoid(), + conversationId, + parentMessageId: null, + isCreatedByUser: true, + text: typeof msg.content === 'string' ? 
msg.content : JSON.stringify(msg.content), + sender: 'User', + endpoint: EModelEndpoint.agents, + model: agentId, + }, + { context: 'Responses API - save user input' }, + ); + } + } +} + +/** + * Save response output to database + * @param {import('express').Request} req + * @param {string} conversationId + * @param {string} responseId + * @param {import('@librechat/api').Response} response + * @param {string} agentId + * @returns {Promise} + */ +async function saveResponseOutput(req, conversationId, responseId, response, agentId) { + // Extract text content from output items + let responseText = ''; + for (const item of response.output) { + if (item.type === 'message' && item.content) { + for (const part of item.content) { + if (part.type === 'output_text' && part.text) { + responseText += part.text; + } + } + } + } + + // Save the assistant message + await db.saveMessage( + req, + { + messageId: responseId, + conversationId, + parentMessageId: null, + isCreatedByUser: false, + text: responseText, + sender: 'Agent', + endpoint: EModelEndpoint.agents, + model: agentId, + finish_reason: response.status === 'completed' ? 
'stop' : response.status, + tokenCount: response.usage?.output_tokens, + }, + { context: 'Responses API - save assistant response' }, + ); +} + +/** + * Save or update conversation + * @param {import('express').Request} req + * @param {string} conversationId + * @param {string} agentId + * @param {object} agent + * @returns {Promise} + */ +async function saveConversation(req, conversationId, agentId, agent) { + await saveConvo( + req, + { + conversationId, + endpoint: EModelEndpoint.agents, + agentId, + title: agent?.name || 'Open Responses Conversation', + model: agent?.model, + }, + { context: 'Responses API - save conversation' }, + ); +} + +/** + * Convert stored messages to Open Responses output format + * @param {Array} messages - Stored messages + * @returns {Array} Output items + */ +function convertMessagesToOutputItems(messages) { + const output = []; + + for (const msg of messages) { + if (!msg.isCreatedByUser) { + output.push({ + type: 'message', + id: msg.messageId, + role: 'assistant', + status: 'completed', + content: [ + { + type: 'output_text', + text: msg.text || '', + annotations: [], + }, + ], + }); + } + } + + return output; +} + +/** + * Create Response - POST /v1/responses + * + * Creates a model response following the Open Responses API specification. + * Supports both streaming and non-streaming responses. 
+ * + * @param {import('express').Request} req + * @param {import('express').Response} res + */ +const createResponse = async (req, res) => { + const requestStartTime = Date.now(); + + // Validate request + const validation = validateResponseRequest(req.body); + if (isValidationFailure(validation)) { + return sendResponsesErrorResponse(res, 400, validation.error); + } + + const request = validation.request; + const agentId = request.model; + const isStreaming = request.stream === true; + + // Look up the agent + const agent = await getAgent({ id: agentId }); + if (!agent) { + return sendResponsesErrorResponse( + res, + 404, + `Agent not found: ${agentId}`, + 'not_found', + 'model_not_found', + ); + } + + // Generate IDs + const responseId = generateResponseId(); + const conversationId = request.previous_response_id ?? uuidv4(); + const parentMessageId = null; + + // Create response context + const context = createResponseContext(request, responseId); + + logger.debug( + `[Responses API] Request ${responseId} started for agent ${agentId}, stream: ${isStreaming}`, + ); + + // Set up abort controller + const abortController = new AbortController(); + + // Handle client disconnect + req.on('close', () => { + if (!abortController.signal.aborted) { + abortController.abort(); + logger.debug('[Responses API] Client disconnected, aborting'); + } + }); + + try { + // Build allowed providers set + const allowedProviders = new Set( + appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders, + ); + + // Create tool loader + const loadTools = createToolLoader(abortController.signal); + + // Initialize the agent first to check for disableStreaming + const endpointOption = { + endpoint: agent.provider, + model_parameters: agent.model_parameters ?? 
{},
+    };
+
+    const primaryConfig = await initializeAgent(
+      {
+        req,
+        res,
+        loadTools,
+        requestFiles: [],
+        conversationId,
+        parentMessageId,
+        agent,
+        endpointOption,
+        allowedProviders,
+        isInitialAgent: true,
+      },
+      {
+        getConvoFiles,
+        getFiles: db.getFiles,
+        getUserKey: db.getUserKey,
+        getMessages: db.getMessages,
+        updateFilesUsage: db.updateFilesUsage,
+        getUserKeyValues: db.getUserKeyValues,
+        getUserCodeFiles: db.getUserCodeFiles,
+        getToolFilesByIds: db.getToolFilesByIds,
+        getCodeGeneratedFiles: db.getCodeGeneratedFiles,
+      },
+    );
+
+    // Determine if streaming is enabled (check both request and agent config)
+    const streamingDisabled = !!primaryConfig.model_parameters?.disableStreaming;
+    const actuallyStreaming = isStreaming && !streamingDisabled;
+
+    // Load previous messages if previous_response_id is provided
+    let previousMessages = [];
+    if (request.previous_response_id) {
+      const userId = req.user?.id ?? 'api-user';
+      previousMessages = await loadPreviousMessages(request.previous_response_id, userId);
+    }
+
+    // Convert input to internal messages
+    const inputMessages = convertToInternalMessages(
+      request.input,
+    );
+
+    // Merge previous messages with new input
+    const allMessages = [...previousMessages, ...inputMessages];
+
+    const toolSet = buildToolSet(primaryConfig);
+    const { messages: formattedMessages, indexTokenCountMap } = formatAgentMessages(
+      allMessages,
+      {},
+      toolSet,
+    );
+
+    // Create tracker for streaming or aggregator for non-streaming
+    const tracker = actuallyStreaming ? createResponseTracker() : null;
+    const aggregator = actuallyStreaming ?
null : createResponseAggregator(); + + // Set up response for streaming + if (actuallyStreaming) { + setupStreamingResponse(res); + + // Create handler config + const handlerConfig = { + res, + context, + tracker, + }; + + // Emit response.created then response.in_progress per Open Responses spec + emitResponseCreated(handlerConfig); + emitResponseInProgress(handlerConfig); + + // Create event handlers + const { handlers: responsesHandlers, finalizeStream } = + createResponsesEventHandlers(handlerConfig); + + // Collect usage for balance tracking + const collectedUsage = []; + + // Artifact promises for processing tool outputs + /** @type {Promise[]} */ + const artifactPromises = []; + // Use Responses API-specific callback that emits librechat:attachment events + const toolEndCallback = createResponsesToolEndCallback({ + req, + res, + tracker, + artifactPromises, + }); + + // Create tool execute options for event-driven tool execution + const toolExecuteOptions = { + loadTools: async (toolNames) => { + return loadToolsForExecution({ + req, + res, + agent, + toolNames, + signal: abortController.signal, + toolRegistry: primaryConfig.toolRegistry, + userMCPAuthMap: primaryConfig.userMCPAuthMap, + tool_resources: primaryConfig.tool_resources, + }); + }, + toolEndCallback, + }; + + // Combine handlers + const handlers = { + on_message_delta: responsesHandlers.on_message_delta, + on_reasoning_delta: responsesHandlers.on_reasoning_delta, + on_run_step: responsesHandlers.on_run_step, + on_run_step_delta: responsesHandlers.on_run_step_delta, + on_chat_model_end: { + handle: (event, data) => { + responsesHandlers.on_chat_model_end.handle(event, data); + const usage = data?.output?.usage_metadata; + if (usage) { + collectedUsage.push(usage); + } + }, + }, + on_tool_end: new ToolEndHandler(toolEndCallback, logger), + on_run_step_completed: { handle: () => {} }, + on_chain_stream: { handle: () => {} }, + on_chain_end: { handle: () => {} }, + on_agent_update: { handle: () => {} 
}, + on_custom_event: { handle: () => {} }, + on_tool_execute: createToolExecuteHandler(toolExecuteOptions), + }; + + // Create and run the agent + const userId = req.user?.id ?? 'api-user'; + const userMCPAuthMap = primaryConfig.userMCPAuthMap; + + const run = await createRun({ + agents: [primaryConfig], + messages: formattedMessages, + indexTokenCountMap, + runId: responseId, + signal: abortController.signal, + customHandlers: handlers, + requestBody: { + messageId: responseId, + conversationId, + }, + user: { id: userId }, + }); + + if (!run) { + throw new Error('Failed to create agent run'); + } + + // Process the stream + const config = { + runName: 'AgentRun', + configurable: { + thread_id: conversationId, + user_id: userId, + user: createSafeUser(req.user), + ...(userMCPAuthMap != null && { userMCPAuthMap }), + }, + signal: abortController.signal, + streamMode: 'values', + version: 'v2', + }; + + await run.processStream({ messages: formattedMessages }, config, { + callbacks: { + [Callback.TOOL_ERROR]: (graph, error, toolId) => { + logger.error(`[Responses API] Tool Error "${toolId}"`, error); + }, + }, + }); + + // Record token usage against balance + const balanceConfig = getBalanceConfig(req.config); + const transactionsConfig = getTransactionsConfig(req.config); + recordCollectedUsage( + { spendTokens, spendStructuredTokens }, + { + user: userId, + conversationId, + collectedUsage, + context: 'message', + balance: balanceConfig, + transactions: transactionsConfig, + model: primaryConfig.model || agent.model_parameters?.model, + }, + ).catch((err) => { + logger.error('[Responses API] Error recording usage:', err); + }); + + // Finalize the stream + finalizeStream(); + res.end(); + + const duration = Date.now() - requestStartTime; + logger.debug(`[Responses API] Request ${responseId} completed in ${duration}ms (streaming)`); + + // Save to database if store: true + if (request.store === true) { + try { + // Save conversation + await saveConversation(req, 
conversationId, agentId, agent); + + // Save input messages + await saveInputMessages(req, conversationId, inputMessages, agentId); + + // Build response for saving (use tracker with buildResponse for streaming) + const finalResponse = buildResponse(context, tracker, 'completed'); + await saveResponseOutput(req, conversationId, responseId, finalResponse, agentId); + + logger.debug( + `[Responses API] Stored response ${responseId} in conversation ${conversationId}`, + ); + } catch (saveError) { + logger.error('[Responses API] Error saving response:', saveError); + // Don't fail the request if saving fails + } + } + + // Wait for artifact processing after response ends (non-blocking) + if (artifactPromises.length > 0) { + Promise.all(artifactPromises).catch((artifactError) => { + logger.warn('[Responses API] Error processing artifacts:', artifactError); + }); + } + } else { + const aggregatorHandlers = createAggregatorEventHandlers(aggregator); + + // Collect usage for balance tracking + const collectedUsage = []; + + /** @type {Promise[]} */ + const artifactPromises = []; + const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId: null }); + + const toolExecuteOptions = { + loadTools: async (toolNames) => { + return loadToolsForExecution({ + req, + res, + agent, + toolNames, + signal: abortController.signal, + toolRegistry: primaryConfig.toolRegistry, + userMCPAuthMap: primaryConfig.userMCPAuthMap, + tool_resources: primaryConfig.tool_resources, + }); + }, + toolEndCallback, + }; + + const handlers = { + on_message_delta: aggregatorHandlers.on_message_delta, + on_reasoning_delta: aggregatorHandlers.on_reasoning_delta, + on_run_step: aggregatorHandlers.on_run_step, + on_run_step_delta: aggregatorHandlers.on_run_step_delta, + on_chat_model_end: { + handle: (event, data) => { + aggregatorHandlers.on_chat_model_end.handle(event, data); + const usage = data?.output?.usage_metadata; + if (usage) { + collectedUsage.push(usage); + } + }, + }, + 
on_tool_end: new ToolEndHandler(toolEndCallback, logger), + on_run_step_completed: { handle: () => {} }, + on_chain_stream: { handle: () => {} }, + on_chain_end: { handle: () => {} }, + on_agent_update: { handle: () => {} }, + on_custom_event: { handle: () => {} }, + on_tool_execute: createToolExecuteHandler(toolExecuteOptions), + }; + + const userId = req.user?.id ?? 'api-user'; + const userMCPAuthMap = primaryConfig.userMCPAuthMap; + + const run = await createRun({ + agents: [primaryConfig], + messages: formattedMessages, + indexTokenCountMap, + runId: responseId, + signal: abortController.signal, + customHandlers: handlers, + requestBody: { + messageId: responseId, + conversationId, + }, + user: { id: userId }, + }); + + if (!run) { + throw new Error('Failed to create agent run'); + } + + const config = { + runName: 'AgentRun', + configurable: { + thread_id: conversationId, + user_id: userId, + user: createSafeUser(req.user), + ...(userMCPAuthMap != null && { userMCPAuthMap }), + }, + signal: abortController.signal, + streamMode: 'values', + version: 'v2', + }; + + await run.processStream({ messages: formattedMessages }, config, { + callbacks: { + [Callback.TOOL_ERROR]: (graph, error, toolId) => { + logger.error(`[Responses API] Tool Error "${toolId}"`, error); + }, + }, + }); + + // Record token usage against balance + const balanceConfig = getBalanceConfig(req.config); + const transactionsConfig = getTransactionsConfig(req.config); + recordCollectedUsage( + { spendTokens, spendStructuredTokens }, + { + user: userId, + conversationId, + collectedUsage, + context: 'message', + balance: balanceConfig, + transactions: transactionsConfig, + model: primaryConfig.model || agent.model_parameters?.model, + }, + ).catch((err) => { + logger.error('[Responses API] Error recording usage:', err); + }); + + if (artifactPromises.length > 0) { + try { + await Promise.all(artifactPromises); + } catch (artifactError) { + logger.warn('[Responses API] Error processing artifacts:', 
artifactError); + } + } + + const response = buildAggregatedResponse(context, aggregator); + + if (request.store === true) { + try { + await saveConversation(req, conversationId, agentId, agent); + + await saveInputMessages(req, conversationId, inputMessages, agentId); + + await saveResponseOutput(req, conversationId, responseId, response, agentId); + + logger.debug( + `[Responses API] Stored response ${responseId} in conversation ${conversationId}`, + ); + } catch (saveError) { + logger.error('[Responses API] Error saving response:', saveError); + // Don't fail the request if saving fails + } + } + + res.json(response); + + const duration = Date.now() - requestStartTime; + logger.debug( + `[Responses API] Request ${responseId} completed in ${duration}ms (non-streaming)`, + ); + } + } catch (error) { + const errorMessage = error instanceof Error ? error.message : 'An error occurred'; + logger.error('[Responses API] Error:', error); + + // Check if we already started streaming (headers sent) + if (res.headersSent) { + // Headers already sent, write error event and close + writeDone(res); + res.end(); + } else { + // Forward upstream provider status codes (e.g., Anthropic 400s) instead of masking as 500 + const statusCode = + typeof error?.status === 'number' && error.status >= 400 && error.status < 600 + ? error.status + : 500; + const errorType = statusCode >= 400 && statusCode < 500 ? 'invalid_request' : 'server_error'; + sendResponsesErrorResponse(res, statusCode, errorMessage, errorType); + } + } +}; + +/** + * List available agents as models - GET /v1/models (also works with /v1/responses/models) + * + * Returns a list of available agents the user has remote access to. 
+ * + * @param {import('express').Request} req + * @param {import('express').Response} res + */ +const listModels = async (req, res) => { + try { + const userId = req.user?.id; + const userRole = req.user?.role; + + if (!userId) { + return sendResponsesErrorResponse(res, 401, 'Authentication required', 'auth_error'); + } + + // Find agents the user has remote access to (VIEW permission on REMOTE_AGENT) + const accessibleAgentIds = await findAccessibleResources({ + userId, + role: userRole, + resourceType: ResourceType.REMOTE_AGENT, + requiredPermissions: PermissionBits.VIEW, + }); + + // Get the accessible agents + let agents = []; + if (accessibleAgentIds.length > 0) { + agents = await getAgents({ _id: { $in: accessibleAgentIds } }); + } + + // Convert to models format + const models = agents.map((agent) => ({ + id: agent.id, + object: 'model', + created: Math.floor(new Date(agent.createdAt).getTime() / 1000), + owned_by: agent.author ?? 'librechat', + // Additional metadata + name: agent.name, + description: agent.description, + provider: agent.provider, + })); + + res.json({ + object: 'list', + data: models, + }); + } catch (error) { + logger.error('[Responses API] Error listing models:', error); + sendResponsesErrorResponse( + res, + 500, + error instanceof Error ? error.message : 'Failed to list models', + 'server_error', + ); + } +}; + +/** + * Get Response - GET /v1/responses/:id + * + * Retrieves a stored response by its ID. + * The response ID maps to a conversationId in LibreChat's storage. 
+ * + * @param {import('express').Request} req + * @param {import('express').Response} res + */ +const getResponse = async (req, res) => { + try { + const responseId = req.params.id; + const userId = req.user?.id; + + if (!responseId) { + return sendResponsesErrorResponse(res, 400, 'Response ID is required'); + } + + // The responseId could be either the response ID or the conversation ID + // Try to find a conversation with this ID + const conversation = await getConvo(userId, responseId); + + if (!conversation) { + return sendResponsesErrorResponse( + res, + 404, + `Response not found: ${responseId}`, + 'not_found', + 'response_not_found', + ); + } + + // Load messages for this conversation + const messages = await db.getMessages({ conversationId: responseId, user: userId }); + + if (!messages || messages.length === 0) { + return sendResponsesErrorResponse( + res, + 404, + `No messages found for response: ${responseId}`, + 'not_found', + 'response_not_found', + ); + } + + // Convert messages to Open Responses output format + const output = convertMessagesToOutputItems(messages); + + // Find the last assistant message for usage info + const lastAssistantMessage = messages.filter((m) => !m.isCreatedByUser).pop(); + + // Build the response object + const response = { + id: responseId, + object: 'response', + created_at: Math.floor(new Date(conversation.createdAt || Date.now()).getTime() / 1000), + completed_at: Math.floor(new Date(conversation.updatedAt || Date.now()).getTime() / 1000), + status: 'completed', + incomplete_details: null, + model: conversation.agentId || conversation.model || 'unknown', + previous_response_id: null, + instructions: null, + output, + error: null, + tools: [], + tool_choice: 'auto', + truncation: 'disabled', + parallel_tool_calls: true, + text: { format: { type: 'text' } }, + temperature: 1, + top_p: 1, + presence_penalty: 0, + frequency_penalty: 0, + top_logprobs: null, + reasoning: null, + user: userId, + usage: 
lastAssistantMessage?.tokenCount + ? { + input_tokens: 0, + output_tokens: lastAssistantMessage.tokenCount, + total_tokens: lastAssistantMessage.tokenCount, + } + : null, + max_output_tokens: null, + max_tool_calls: null, + store: true, + background: false, + service_tier: 'default', + metadata: {}, + safety_identifier: null, + prompt_cache_key: null, + }; + + res.json(response); + } catch (error) { + logger.error('[Responses API] Error getting response:', error); + sendResponsesErrorResponse( + res, + 500, + error instanceof Error ? error.message : 'Failed to get response', + 'server_error', + ); + } +}; + +module.exports = { + createResponse, + getResponse, + listModels, + setAppConfig, +}; diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 9f0a4a2279..34078b2250 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -11,7 +11,9 @@ const { convertOcrToContextInPlace, } = require('@librechat/api'); const { + Time, Tools, + CacheKeys, Constants, FileSources, ResourceType, @@ -21,8 +23,6 @@ const { PermissionBits, actionDelimiter, removeNullishValues, - CacheKeys, - Time, } = require('librechat-data-provider'); const { getListAgentsByAccess, @@ -94,16 +94,25 @@ const createAgentHandler = async (req, res) => { const agent = await createAgent(agentData); - // Automatically grant owner permissions to the creator try { - await grantPermission({ - principalType: PrincipalType.USER, - principalId: userId, - resourceType: ResourceType.AGENT, - resourceId: agent._id, - accessRoleId: AccessRoleIds.AGENT_OWNER, - grantedBy: userId, - }); + await Promise.all([ + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.AGENT, + resourceId: agent._id, + accessRoleId: AccessRoleIds.AGENT_OWNER, + grantedBy: userId, + }), + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.REMOTE_AGENT, + 
resourceId: agent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + grantedBy: userId, + }), + ]); logger.debug( `[createAgent] Granted owner permissions to user ${userId} for agent ${agent.id}`, ); @@ -396,16 +405,25 @@ const duplicateAgentHandler = async (req, res) => { newAgentData.actions = agentActions; const newAgent = await createAgent(newAgentData); - // Automatically grant owner permissions to the duplicator try { - await grantPermission({ - principalType: PrincipalType.USER, - principalId: userId, - resourceType: ResourceType.AGENT, - resourceId: newAgent._id, - accessRoleId: AccessRoleIds.AGENT_OWNER, - grantedBy: userId, - }); + await Promise.all([ + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.AGENT, + resourceId: newAgent._id, + accessRoleId: AccessRoleIds.AGENT_OWNER, + grantedBy: userId, + }), + grantPermission({ + principalType: PrincipalType.USER, + principalId: userId, + resourceType: ResourceType.REMOTE_AGENT, + resourceId: newAgent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + grantedBy: userId, + }), + ]); logger.debug( `[duplicateAgent] Granted owner permissions to user ${userId} for duplicated agent ${newAgent.id}`, ); diff --git a/api/server/controllers/auth/LogoutController.js b/api/server/controllers/auth/LogoutController.js index ec66316285..0b3cf262b8 100644 --- a/api/server/controllers/auth/LogoutController.js +++ b/api/server/controllers/auth/LogoutController.js @@ -22,6 +22,7 @@ const logoutController = async (req, res) => { res.clearCookie('refreshToken'); res.clearCookie('openid_access_token'); + res.clearCookie('openid_id_token'); res.clearCookie('openid_user_id'); res.clearCookie('token_provider'); const response = { message }; diff --git a/api/server/controllers/auth/oauth.js b/api/server/controllers/auth/oauth.js new file mode 100644 index 0000000000..80c2ced002 --- /dev/null +++ b/api/server/controllers/auth/oauth.js @@ -0,0 +1,79 @@ +const { 
CacheKeys } = require('librechat-data-provider'); +const { logger, DEFAULT_SESSION_EXPIRY } = require('@librechat/data-schemas'); +const { + isEnabled, + getAdminPanelUrl, + isAdminPanelRedirect, + generateAdminExchangeCode, +} = require('@librechat/api'); +const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService'); +const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService'); +const getLogStores = require('~/cache/getLogStores'); +const { checkBan } = require('~/server/middleware'); +const { generateToken } = require('~/models'); + +const domains = { + client: process.env.DOMAIN_CLIENT, + server: process.env.DOMAIN_SERVER, +}; + +function createOAuthHandler(redirectUri = domains.client) { + /** + * A handler to process OAuth authentication results. + * @type {Function} + * @param {ServerRequest} req - Express request object. + * @param {ServerResponse} res - Express response object. + * @param {NextFunction} next - Express next middleware function. 
+ */ + return async (req, res, next) => { + try { + if (res.headersSent) { + return; + } + + await checkBan(req, res); + if (req.banned) { + return; + } + + /** Check if this is an admin panel redirect (cross-origin) */ + if (isAdminPanelRedirect(redirectUri, getAdminPanelUrl(), domains.client)) { + /** For admin panel, generate exchange code instead of setting cookies */ + const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE); + const sessionExpiry = Number(process.env.SESSION_EXPIRY) || DEFAULT_SESSION_EXPIRY; + const token = await generateToken(req.user, sessionExpiry); + + /** Get refresh token from tokenset for OpenID users */ + const refreshToken = + req.user.tokenset?.refresh_token || req.user.federatedTokens?.refresh_token; + + const exchangeCode = await generateAdminExchangeCode(cache, req.user, token, refreshToken); + + const callbackUrl = new URL(redirectUri); + callbackUrl.searchParams.set('code', exchangeCode); + logger.info(`[OAuth] Admin panel redirect with exchange code for user: ${req.user.email}`); + return res.redirect(callbackUrl.toString()); + } + + /** Standard OAuth flow - set cookies and redirect */ + if ( + req.user && + req.user.provider == 'openid' && + isEnabled(process.env.OPENID_REUSE_TOKENS) === true + ) { + await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token); + setOpenIDAuthTokens(req.user.tokenset, req, res, req.user._id.toString()); + } else { + await setAuthTokens(req.user._id, res); + } + res.redirect(redirectUri); + } catch (err) { + logger.error('Error in setting authentication tokens:', err); + next(err); + } + }; +} + +module.exports = { + createOAuthHandler, +}; diff --git a/api/server/experimental.js b/api/server/experimental.js index 91ef9ef286..4a457abf61 100644 --- a/api/server/experimental.js +++ b/api/server/experimental.js @@ -299,6 +299,7 @@ if (cluster.isMaster) { app.use('/api/auth', routes.auth); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); + 
app.use('/api/api-keys', routes.apiKeys); app.use('/api/user', routes.user); app.use('/api/search', routes.search); app.use('/api/messages', routes.messages); diff --git a/api/server/index.js b/api/server/index.js index a7ddd47f37..193eb423ad 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -134,8 +134,10 @@ const startServer = async () => { app.use('/oauth', routes.oauth); /* API Endpoints */ app.use('/api/auth', routes.auth); + app.use('/api/admin', routes.adminAuth); app.use('/api/actions', routes.actions); app.use('/api/keys', routes.keys); + app.use('/api/api-keys', routes.apiKeys); app.use('/api/user', routes.user); app.use('/api/search', routes.search); app.use('/api/messages', routes.messages); @@ -249,6 +251,15 @@ process.on('uncaughtException', (err) => { return; } + if (isEnabled(process.env.CONTINUE_ON_UNCAUGHT_EXCEPTION)) { + logger.error('Unhandled error encountered. The app will continue running.', { + name: err?.name, + message: err?.message, + stack: err?.stack, + }); + return; + } + process.exit(1); }); diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index b85f1439cc..d07a09682d 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -7,13 +7,89 @@ const { sanitizeMessageForTransmit, } = require('@librechat/api'); const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider'); +const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); const clearPendingReq = require('~/cache/clearPendingReq'); const { sendError } = require('~/server/middleware/error'); -const { spendTokens } = require('~/models/spendTokens'); const { saveMessage, getConvo } = require('~/models'); const { abortRun } = require('./abortRun'); +/** + * Spend tokens for all models from collected usage. 
+ * This handles both sequential and parallel agent execution. + * + * IMPORTANT: After spending, this function clears the collectedUsage array + * to prevent double-spending. The array is shared with AgentClient.collectedUsage, + * so clearing it here prevents the finally block from also spending tokens. + * + * @param {Object} params + * @param {string} params.userId - User ID + * @param {string} params.conversationId - Conversation ID + * @param {Array} params.collectedUsage - Usage metadata from all models + * @param {string} [params.fallbackModel] - Fallback model name if not in usage + */ +async function spendCollectedUsage({ userId, conversationId, collectedUsage, fallbackModel }) { + if (!collectedUsage || collectedUsage.length === 0) { + return; + } + + const spendPromises = []; + + for (const usage of collectedUsage) { + if (!usage) { + continue; + } + + // Support both OpenAI format (input_token_details) and Anthropic format (cache_*_input_tokens) + const cache_creation = + Number(usage.input_token_details?.cache_creation) || + Number(usage.cache_creation_input_tokens) || + 0; + const cache_read = + Number(usage.input_token_details?.cache_read) || Number(usage.cache_read_input_tokens) || 0; + + const txMetadata = { + context: 'abort', + conversationId, + user: userId, + model: usage.model ?? 
fallbackModel, + }; + + if (cache_creation > 0 || cache_read > 0) { + spendPromises.push( + spendStructuredTokens(txMetadata, { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }).catch((err) => { + logger.error('[abortMiddleware] Error spending structured tokens for abort', err); + }), + ); + continue; + } + + spendPromises.push( + spendTokens(txMetadata, { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }).catch((err) => { + logger.error('[abortMiddleware] Error spending tokens for abort', err); + }), + ); + } + + // Wait for all token spending to complete + await Promise.all(spendPromises); + + // Clear the array to prevent double-spending from the AgentClient finally block. + // The collectedUsage array is shared by reference with AgentClient.collectedUsage, + // so clearing it here ensures recordCollectedUsage() sees an empty array and returns early. + collectedUsage.length = 0; +} + /** * Abort an active message generation. * Uses GenerationJobManager for all agent requests. @@ -39,9 +115,8 @@ async function abortMessage(req, res) { return; } - const { jobData, content, text } = abortResult; + const { jobData, content, text, collectedUsage } = abortResult; - // Count tokens and spend them const completionTokens = await countTokens(text); const promptTokens = jobData?.promptTokens ?? 
0; @@ -62,10 +137,21 @@ async function abortMessage(req, res) { tokenCount: completionTokens, }; - await spendTokens( - { ...responseMessage, context: 'incomplete', user: userId }, - { promptTokens, completionTokens }, - ); + // Spend tokens for ALL models from collectedUsage (handles parallel agents/addedConvo) + if (collectedUsage && collectedUsage.length > 0) { + await spendCollectedUsage({ + userId, + conversationId: jobData?.conversationId, + collectedUsage, + fallbackModel: jobData?.model, + }); + } else { + // Fallback: no collected usage, use text-based token counting for primary model only + await spendTokens( + { ...responseMessage, context: 'incomplete', user: userId }, + { promptTokens, completionTokens }, + ); + } await saveMessage( req, diff --git a/api/server/middleware/abortMiddleware.spec.js b/api/server/middleware/abortMiddleware.spec.js new file mode 100644 index 0000000000..93f2ce558b --- /dev/null +++ b/api/server/middleware/abortMiddleware.spec.js @@ -0,0 +1,428 @@ +/** + * Tests for abortMiddleware - spendCollectedUsage function + * + * This tests the token spending logic for abort scenarios, + * particularly for parallel agents (addedConvo) where multiple + * models need their tokens spent. 
+ */ + +const mockSpendTokens = jest.fn().mockResolvedValue(); +const mockSpendStructuredTokens = jest.fn().mockResolvedValue(); + +jest.mock('~/models/spendTokens', () => ({ + spendTokens: (...args) => mockSpendTokens(...args), + spendStructuredTokens: (...args) => mockSpendStructuredTokens(...args), +})); + +jest.mock('@librechat/data-schemas', () => ({ + logger: { + debug: jest.fn(), + error: jest.fn(), + warn: jest.fn(), + info: jest.fn(), + }, +})); + +jest.mock('@librechat/api', () => ({ + countTokens: jest.fn().mockResolvedValue(100), + isEnabled: jest.fn().mockReturnValue(false), + sendEvent: jest.fn(), + GenerationJobManager: { + abortJob: jest.fn(), + }, + sanitizeMessageForTransmit: jest.fn((msg) => msg), +})); + +jest.mock('librechat-data-provider', () => ({ + isAssistantsEndpoint: jest.fn().mockReturnValue(false), + ErrorTypes: { INVALID_REQUEST: 'INVALID_REQUEST', NO_SYSTEM_MESSAGES: 'NO_SYSTEM_MESSAGES' }, +})); + +jest.mock('~/app/clients/prompts', () => ({ + truncateText: jest.fn((text) => text), + smartTruncateText: jest.fn((text) => text), +})); + +jest.mock('~/cache/clearPendingReq', () => jest.fn().mockResolvedValue()); + +jest.mock('~/server/middleware/error', () => ({ + sendError: jest.fn(), +})); + +jest.mock('~/models', () => ({ + saveMessage: jest.fn().mockResolvedValue(), + getConvo: jest.fn().mockResolvedValue({ title: 'Test Chat' }), +})); + +jest.mock('./abortRun', () => ({ + abortRun: jest.fn(), +})); + +// Import the module after mocks are set up +// We need to extract the spendCollectedUsage function for testing +// Since it's not exported, we'll test it through the handleAbort flow + +describe('abortMiddleware - spendCollectedUsage', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('spendCollectedUsage logic', () => { + // Since spendCollectedUsage is not exported, we test the logic directly + // by replicating the function here for unit testing + + const spendCollectedUsage = async ({ + userId, + 
conversationId, + collectedUsage, + fallbackModel, + }) => { + if (!collectedUsage || collectedUsage.length === 0) { + return; + } + + const spendPromises = []; + + for (const usage of collectedUsage) { + if (!usage) { + continue; + } + + const cache_creation = + Number(usage.input_token_details?.cache_creation) || + Number(usage.cache_creation_input_tokens) || + 0; + const cache_read = + Number(usage.input_token_details?.cache_read) || + Number(usage.cache_read_input_tokens) || + 0; + + const txMetadata = { + context: 'abort', + conversationId, + user: userId, + model: usage.model ?? fallbackModel, + }; + + if (cache_creation > 0 || cache_read > 0) { + spendPromises.push( + mockSpendStructuredTokens(txMetadata, { + promptTokens: { + input: usage.input_tokens, + write: cache_creation, + read: cache_read, + }, + completionTokens: usage.output_tokens, + }).catch(() => { + // Log error but don't throw + }), + ); + continue; + } + + spendPromises.push( + mockSpendTokens(txMetadata, { + promptTokens: usage.input_tokens, + completionTokens: usage.output_tokens, + }).catch(() => { + // Log error but don't throw + }), + ); + } + + // Wait for all token spending to complete + await Promise.all(spendPromises); + + // Clear the array to prevent double-spending + collectedUsage.length = 0; + }; + + it('should return early if collectedUsage is empty', async () => { + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage: [], + fallbackModel: 'gpt-4', + }); + + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + }); + + it('should return early if collectedUsage is null', async () => { + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage: null, + fallbackModel: 'gpt-4', + }); + + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).not.toHaveBeenCalled(); + }); + + it('should skip null entries in 
collectedUsage', async () => { + const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + null, + { input_tokens: 200, output_tokens: 60, model: 'gpt-4' }, + ]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gpt-4', + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(2); + }); + + it('should spend tokens for single model', async () => { + const collectedUsage = [{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' }]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gpt-4', + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).toHaveBeenCalledWith( + expect.objectContaining({ + context: 'abort', + conversationId: 'convo-123', + user: 'user-123', + model: 'gpt-4', + }), + { promptTokens: 100, completionTokens: 50 }, + ); + }); + + it('should spend tokens for multiple models (parallel agents)', async () => { + const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, + { input_tokens: 120, output_tokens: 60, model: 'gemini-pro' }, + ]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gpt-4', + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(3); + + // Verify each model was called + expect(mockSpendTokens).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ model: 'gpt-4' }), + { promptTokens: 100, completionTokens: 50 }, + ); + expect(mockSpendTokens).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ model: 'claude-3' }), + { promptTokens: 80, completionTokens: 40 }, + ); + expect(mockSpendTokens).toHaveBeenNthCalledWith( + 3, + expect.objectContaining({ model: 'gemini-pro' }), + { promptTokens: 120, completionTokens: 60 }, + ); + }); + + it('should use fallbackModel when 
usage.model is missing', async () => { + const collectedUsage = [{ input_tokens: 100, output_tokens: 50 }]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'fallback-model', + }); + + expect(mockSpendTokens).toHaveBeenCalledWith( + expect.objectContaining({ model: 'fallback-model' }), + expect.any(Object), + ); + }); + + it('should use spendStructuredTokens for OpenAI format cache tokens', async () => { + const collectedUsage = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'gpt-4', + input_token_details: { + cache_creation: 20, + cache_read: 10, + }, + }, + ]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gpt-4', + }); + + expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).toHaveBeenCalledWith( + expect.objectContaining({ model: 'gpt-4', context: 'abort' }), + { + promptTokens: { + input: 100, + write: 20, + read: 10, + }, + completionTokens: 50, + }, + ); + }); + + it('should use spendStructuredTokens for Anthropic format cache tokens', async () => { + const collectedUsage = [ + { + input_tokens: 100, + output_tokens: 50, + model: 'claude-3', + cache_creation_input_tokens: 25, + cache_read_input_tokens: 15, + }, + ]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'claude-3', + }); + + expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); + expect(mockSpendTokens).not.toHaveBeenCalled(); + expect(mockSpendStructuredTokens).toHaveBeenCalledWith( + expect.objectContaining({ model: 'claude-3' }), + { + promptTokens: { + input: 100, + write: 25, + read: 15, + }, + completionTokens: 50, + }, + ); + }); + + it('should handle mixed cache and non-cache entries', async () => { + const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, 
model: 'gpt-4' }, + { + input_tokens: 150, + output_tokens: 30, + model: 'claude-3', + cache_creation_input_tokens: 20, + cache_read_input_tokens: 10, + }, + { input_tokens: 200, output_tokens: 20, model: 'gemini-pro' }, + ]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gpt-4', + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(2); + expect(mockSpendStructuredTokens).toHaveBeenCalledTimes(1); + }); + + it('should handle real-world parallel agent abort scenario', async () => { + // Simulates: Primary agent (gemini) + addedConvo agent (gpt-5) aborted mid-stream + const collectedUsage = [ + { input_tokens: 31596, output_tokens: 151, model: 'gemini-3-flash-preview' }, + { input_tokens: 28000, output_tokens: 120, model: 'gpt-5.2' }, + ]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gemini-3-flash-preview', + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(2); + + // Primary model + expect(mockSpendTokens).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ model: 'gemini-3-flash-preview' }), + { promptTokens: 31596, completionTokens: 151 }, + ); + + // Parallel model (addedConvo) + expect(mockSpendTokens).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ model: 'gpt-5.2' }), + { promptTokens: 28000, completionTokens: 120 }, + ); + }); + + it('should clear collectedUsage array after spending to prevent double-spending', async () => { + // This tests the race condition fix: after abort middleware spends tokens, + // the collectedUsage array is cleared so AgentClient.recordCollectedUsage() + // (which shares the same array reference) sees an empty array and returns early. 
+ const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, + ]; + + expect(collectedUsage.length).toBe(2); + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gpt-4', + }); + + expect(mockSpendTokens).toHaveBeenCalledTimes(2); + + // The array should be cleared after spending + expect(collectedUsage.length).toBe(0); + }); + + it('should await all token spending operations before clearing array', async () => { + // Ensure we don't clear the array before spending completes + let spendCallCount = 0; + mockSpendTokens.mockImplementation(async () => { + spendCallCount++; + // Simulate async delay + await new Promise((resolve) => setTimeout(resolve, 10)); + }); + + const collectedUsage = [ + { input_tokens: 100, output_tokens: 50, model: 'gpt-4' }, + { input_tokens: 80, output_tokens: 40, model: 'claude-3' }, + ]; + + await spendCollectedUsage({ + userId: 'user-123', + conversationId: 'convo-123', + collectedUsage, + fallbackModel: 'gpt-4', + }); + + // Both spend calls should have completed + expect(spendCallCount).toBe(2); + + // Array should be cleared after awaiting + expect(collectedUsage.length).toBe(0); + }); + }); +}); diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index f56d850120..64ed8e7466 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -5,9 +5,11 @@ const { EModelEndpoint, isAgentsEndpoint, parseCompactConvo, + getDefaultParamsEndpoint, } = require('librechat-data-provider'); const azureAssistants = require('~/server/services/Endpoints/azureAssistants'); const assistants = require('~/server/services/Endpoints/assistants'); +const { getEndpointsConfig } = require('~/server/services/Config'); const agents = require('~/server/services/Endpoints/agents'); const { updateFilesUsage } = 
require('~/models'); @@ -19,9 +21,24 @@ const buildFunction = { async function buildEndpointOption(req, res, next) { const { endpoint, endpointType } = req.body; + + let endpointsConfig; + try { + endpointsConfig = await getEndpointsConfig(req); + } catch (error) { + logger.error('Error fetching endpoints config in buildEndpointOption', error); + } + + const defaultParamsEndpoint = getDefaultParamsEndpoint(endpointsConfig, endpoint); + let parsedBody; try { - parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body }); + parsedBody = parseCompactConvo({ + endpoint, + endpointType, + conversation: req.body, + defaultParamsEndpoint, + }); } catch (error) { logger.error(`Error parsing compact conversation for endpoint ${endpoint}`, error); logger.debug({ @@ -55,6 +72,7 @@ async function buildEndpointOption(req, res, next) { endpoint, endpointType, conversation: currentModelSpec.preset, + defaultParamsEndpoint, }); if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') { parsedBody.iconURL = currentModelSpec.iconURL; diff --git a/api/server/middleware/buildEndpointOption.spec.js b/api/server/middleware/buildEndpointOption.spec.js new file mode 100644 index 0000000000..eab5e2666b --- /dev/null +++ b/api/server/middleware/buildEndpointOption.spec.js @@ -0,0 +1,237 @@ +/** + * Wrap parseCompactConvo: the REAL function runs, but jest can observe + * calls and return values. Must be declared before require('./buildEndpointOption') + * so the destructured reference in the middleware captures the wrapper. 
+ */ +jest.mock('librechat-data-provider', () => { + const actual = jest.requireActual('librechat-data-provider'); + return { + ...actual, + parseCompactConvo: jest.fn((...args) => actual.parseCompactConvo(...args)), + }; +}); + +const { EModelEndpoint, parseCompactConvo } = require('librechat-data-provider'); + +const mockBuildOptions = jest.fn((_endpoint, parsedBody) => ({ + ...parsedBody, + endpoint: _endpoint, +})); + +jest.mock('~/server/services/Endpoints/azureAssistants', () => ({ + buildOptions: mockBuildOptions, +})); +jest.mock('~/server/services/Endpoints/assistants', () => ({ + buildOptions: mockBuildOptions, +})); +jest.mock('~/server/services/Endpoints/agents', () => ({ + buildOptions: mockBuildOptions, +})); + +jest.mock('~/models', () => ({ + updateFilesUsage: jest.fn(), +})); + +const mockGetEndpointsConfig = jest.fn(); +jest.mock('~/server/services/Config', () => ({ + getEndpointsConfig: (...args) => mockGetEndpointsConfig(...args), +})); + +jest.mock('@librechat/api', () => ({ + handleError: jest.fn(), +})); + +const buildEndpointOption = require('./buildEndpointOption'); + +const createReq = (body, config = {}) => ({ + body, + config, + baseUrl: '/api/chat', +}); + +const createRes = () => ({ + status: jest.fn().mockReturnThis(), + json: jest.fn().mockReturnThis(), +}); + +describe('buildEndpointOption - defaultParamsEndpoint parsing', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should pass defaultParamsEndpoint to parseCompactConvo and preserve maxOutputTokens', async () => { + mockGetEndpointsConfig.mockResolvedValue({ + AnthropicClaude: { + type: EModelEndpoint.custom, + customParams: { + defaultParamsEndpoint: EModelEndpoint.anthropic, + }, + }, + }); + + const req = createReq( + { + endpoint: 'AnthropicClaude', + endpointType: EModelEndpoint.custom, + model: 'anthropic/claude-opus-4.5', + temperature: 0.7, + maxOutputTokens: 8192, + topP: 0.9, + maxContextTokens: 50000, + }, + { modelSpecs: null }, + ); + + await 
buildEndpointOption(req, createRes(), jest.fn()); + + expect(parseCompactConvo).toHaveBeenCalledWith( + expect.objectContaining({ + defaultParamsEndpoint: EModelEndpoint.anthropic, + }), + ); + + const parsedResult = parseCompactConvo.mock.results[0].value; + expect(parsedResult.maxOutputTokens).toBe(8192); + expect(parsedResult.topP).toBe(0.9); + expect(parsedResult.temperature).toBe(0.7); + expect(parsedResult.maxContextTokens).toBe(50000); + }); + + it('should strip maxOutputTokens when no defaultParamsEndpoint is configured', async () => { + mockGetEndpointsConfig.mockResolvedValue({ + MyOpenRouter: { + type: EModelEndpoint.custom, + }, + }); + + const req = createReq( + { + endpoint: 'MyOpenRouter', + endpointType: EModelEndpoint.custom, + model: 'gpt-4o', + temperature: 0.7, + maxOutputTokens: 8192, + max_tokens: 4096, + }, + { modelSpecs: null }, + ); + + await buildEndpointOption(req, createRes(), jest.fn()); + + expect(parseCompactConvo).toHaveBeenCalledWith( + expect.objectContaining({ + defaultParamsEndpoint: undefined, + }), + ); + + const parsedResult = parseCompactConvo.mock.results[0].value; + expect(parsedResult.maxOutputTokens).toBeUndefined(); + expect(parsedResult.max_tokens).toBe(4096); + expect(parsedResult.temperature).toBe(0.7); + }); + + it('should strip bedrock region from custom endpoint without defaultParamsEndpoint', async () => { + mockGetEndpointsConfig.mockResolvedValue({ + MyEndpoint: { + type: EModelEndpoint.custom, + }, + }); + + const req = createReq( + { + endpoint: 'MyEndpoint', + endpointType: EModelEndpoint.custom, + model: 'gpt-4o', + temperature: 0.7, + region: 'us-east-1', + }, + { modelSpecs: null }, + ); + + await buildEndpointOption(req, createRes(), jest.fn()); + + const parsedResult = parseCompactConvo.mock.results[0].value; + expect(parsedResult.region).toBeUndefined(); + expect(parsedResult.temperature).toBe(0.7); + }); + + it('should pass defaultParamsEndpoint when re-parsing enforced model spec', async () => { + 
mockGetEndpointsConfig.mockResolvedValue({ + AnthropicClaude: { + type: EModelEndpoint.custom, + customParams: { + defaultParamsEndpoint: EModelEndpoint.anthropic, + }, + }, + }); + + const modelSpec = { + name: 'claude-opus-4.5', + preset: { + endpoint: 'AnthropicClaude', + endpointType: EModelEndpoint.custom, + model: 'anthropic/claude-opus-4.5', + temperature: 0.7, + maxOutputTokens: 8192, + maxContextTokens: 50000, + }, + }; + + const req = createReq( + { + endpoint: 'AnthropicClaude', + endpointType: EModelEndpoint.custom, + spec: 'claude-opus-4.5', + model: 'anthropic/claude-opus-4.5', + }, + { + modelSpecs: { + enforce: true, + list: [modelSpec], + }, + }, + ); + + await buildEndpointOption(req, createRes(), jest.fn()); + + const enforcedCall = parseCompactConvo.mock.calls[1]; + expect(enforcedCall[0]).toEqual( + expect.objectContaining({ + defaultParamsEndpoint: EModelEndpoint.anthropic, + }), + ); + + const enforcedResult = parseCompactConvo.mock.results[1].value; + expect(enforcedResult.maxOutputTokens).toBe(8192); + expect(enforcedResult.temperature).toBe(0.7); + expect(enforcedResult.maxContextTokens).toBe(50000); + }); + + it('should fall back to OpenAI schema when getEndpointsConfig fails', async () => { + mockGetEndpointsConfig.mockRejectedValue(new Error('Config unavailable')); + + const req = createReq( + { + endpoint: 'AnthropicClaude', + endpointType: EModelEndpoint.custom, + model: 'anthropic/claude-opus-4.5', + temperature: 0.7, + maxOutputTokens: 8192, + max_tokens: 4096, + }, + { modelSpecs: null }, + ); + + await buildEndpointOption(req, createRes(), jest.fn()); + + expect(parseCompactConvo).toHaveBeenCalledWith( + expect.objectContaining({ + defaultParamsEndpoint: undefined, + }), + ); + + const parsedResult = parseCompactConvo.mock.results[0].value; + expect(parsedResult.maxOutputTokens).toBeUndefined(); + expect(parsedResult.max_tokens).toBe(4096); + }); +}); diff --git a/api/server/middleware/checkSharePublicAccess.js 
b/api/server/middleware/checkSharePublicAccess.js index c094d54acb..0e95b9f6f8 100644 --- a/api/server/middleware/checkSharePublicAccess.js +++ b/api/server/middleware/checkSharePublicAccess.js @@ -9,6 +9,7 @@ const resourceToPermissionType = { [ResourceType.AGENT]: PermissionTypes.AGENTS, [ResourceType.PROMPTGROUP]: PermissionTypes.PROMPTS, [ResourceType.MCPSERVER]: PermissionTypes.MCP_SERVERS, + [ResourceType.REMOTE_AGENT]: PermissionTypes.REMOTE_AGENTS, }; /** diff --git a/api/server/middleware/requireJwtAuth.js b/api/server/middleware/requireJwtAuth.js index ed83c4773e..16b107aefc 100644 --- a/api/server/middleware/requireJwtAuth.js +++ b/api/server/middleware/requireJwtAuth.js @@ -7,16 +7,13 @@ const { isEnabled } = require('@librechat/api'); * Switches between JWT and OpenID authentication based on cookies and environment settings */ const requireJwtAuth = (req, res, next) => { - // Check if token provider is specified in cookies const cookieHeader = req.headers.cookie; const tokenProvider = cookieHeader ? 
cookies.parse(cookieHeader).token_provider : null; - // Use OpenID authentication if token provider is OpenID and OPENID_REUSE_TOKENS is enabled if (tokenProvider === 'openid' && isEnabled(process.env.OPENID_REUSE_TOKENS)) { return passport.authenticate('openidJwt', { session: false })(req, res, next); } - // Default to standard JWT authentication return passport.authenticate('jwt', { session: false })(req, res, next); }; diff --git a/api/server/routes/__tests__/convos.spec.js b/api/server/routes/__tests__/convos.spec.js index ef11b3cbbb..931ef006d0 100644 --- a/api/server/routes/__tests__/convos.spec.js +++ b/api/server/routes/__tests__/convos.spec.js @@ -385,6 +385,40 @@ describe('Convos Routes', () => { expect(deleteConvoSharedLink).not.toHaveBeenCalled(); }); + it('should return 400 when request body is empty (DoS prevention)', async () => { + const response = await request(app).delete('/api/convos').send({}); + + expect(response.status).toBe(400); + expect(response.body).toEqual({ error: 'no parameters provided' }); + expect(deleteConvos).not.toHaveBeenCalled(); + }); + + it('should return 400 when arg is null (DoS prevention)', async () => { + const response = await request(app).delete('/api/convos').send({ arg: null }); + + expect(response.status).toBe(400); + expect(response.body).toEqual({ error: 'no parameters provided' }); + expect(deleteConvos).not.toHaveBeenCalled(); + }); + + it('should return 400 when arg is undefined (DoS prevention)', async () => { + const response = await request(app).delete('/api/convos').send({ arg: undefined }); + + expect(response.status).toBe(400); + expect(response.body).toEqual({ error: 'no parameters provided' }); + expect(deleteConvos).not.toHaveBeenCalled(); + }); + + it('should return 400 when request body is null (DoS prevention)', async () => { + const response = await request(app) + .delete('/api/convos') + .set('Content-Type', 'application/json') + .send('null'); + + expect(response.status).toBe(400); + 
expect(deleteConvos).not.toHaveBeenCalled(); + }); + it('should return 500 if deleteConvoSharedLink fails', async () => { const mockConversationId = 'conv-error'; diff --git a/api/server/routes/__tests__/keys.spec.js b/api/server/routes/__tests__/keys.spec.js new file mode 100644 index 0000000000..0c96dd3bcb --- /dev/null +++ b/api/server/routes/__tests__/keys.spec.js @@ -0,0 +1,174 @@ +const express = require('express'); +const request = require('supertest'); + +jest.mock('~/models', () => ({ + updateUserKey: jest.fn(), + deleteUserKey: jest.fn(), + getUserKeyExpiry: jest.fn(), +})); + +jest.mock('~/server/middleware/requireJwtAuth', () => (req, res, next) => next()); + +jest.mock('~/server/middleware', () => ({ + requireJwtAuth: (req, res, next) => next(), +})); + +describe('Keys Routes', () => { + let app; + const { updateUserKey, deleteUserKey, getUserKeyExpiry } = require('~/models'); + + beforeAll(() => { + const keysRouter = require('../keys'); + + app = express(); + app.use(express.json()); + + app.use((req, res, next) => { + req.user = { id: 'test-user-123' }; + next(); + }); + + app.use('/api/keys', keysRouter); + }); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('PUT /', () => { + it('should update a user key with the authenticated user ID', async () => { + updateUserKey.mockResolvedValue({}); + + const response = await request(app) + .put('/api/keys') + .send({ name: 'openAI', value: 'sk-test-key-123', expiresAt: '2026-12-31' }); + + expect(response.status).toBe(201); + expect(updateUserKey).toHaveBeenCalledWith({ + userId: 'test-user-123', + name: 'openAI', + value: 'sk-test-key-123', + expiresAt: '2026-12-31', + }); + expect(updateUserKey).toHaveBeenCalledTimes(1); + }); + + it('should not allow userId override via request body (IDOR prevention)', async () => { + updateUserKey.mockResolvedValue({}); + + const response = await request(app).put('/api/keys').send({ + userId: 'attacker-injected-id', + name: 'openAI', + value: 
'sk-attacker-key', + }); + + expect(response.status).toBe(201); + expect(updateUserKey).toHaveBeenCalledWith({ + userId: 'test-user-123', + name: 'openAI', + value: 'sk-attacker-key', + expiresAt: undefined, + }); + }); + + it('should ignore extraneous fields from request body', async () => { + updateUserKey.mockResolvedValue({}); + + const response = await request(app).put('/api/keys').send({ + name: 'openAI', + value: 'sk-test-key', + expiresAt: '2026-12-31', + _id: 'injected-mongo-id', + __v: 99, + extra: 'should-be-ignored', + }); + + expect(response.status).toBe(201); + expect(updateUserKey).toHaveBeenCalledWith({ + userId: 'test-user-123', + name: 'openAI', + value: 'sk-test-key', + expiresAt: '2026-12-31', + }); + }); + + it('should handle missing optional fields', async () => { + updateUserKey.mockResolvedValue({}); + + const response = await request(app) + .put('/api/keys') + .send({ name: 'anthropic', value: 'sk-ant-key' }); + + expect(response.status).toBe(201); + expect(updateUserKey).toHaveBeenCalledWith({ + userId: 'test-user-123', + name: 'anthropic', + value: 'sk-ant-key', + expiresAt: undefined, + }); + }); + + it('should return 400 when request body is null', async () => { + const response = await request(app) + .put('/api/keys') + .set('Content-Type', 'application/json') + .send('null'); + + expect(response.status).toBe(400); + expect(updateUserKey).not.toHaveBeenCalled(); + }); + }); + + describe('DELETE /:name', () => { + it('should delete a user key by name', async () => { + deleteUserKey.mockResolvedValue({}); + + const response = await request(app).delete('/api/keys/openAI'); + + expect(response.status).toBe(204); + expect(deleteUserKey).toHaveBeenCalledWith({ + userId: 'test-user-123', + name: 'openAI', + }); + expect(deleteUserKey).toHaveBeenCalledTimes(1); + }); + }); + + describe('DELETE /', () => { + it('should delete all keys when all=true', async () => { + deleteUserKey.mockResolvedValue({}); + + const response = await 
request(app).delete('/api/keys?all=true'); + + expect(response.status).toBe(204); + expect(deleteUserKey).toHaveBeenCalledWith({ + userId: 'test-user-123', + all: true, + }); + }); + + it('should return 400 when all query param is not true', async () => { + const response = await request(app).delete('/api/keys'); + + expect(response.status).toBe(400); + expect(response.body).toEqual({ error: 'Specify either all=true to delete.' }); + expect(deleteUserKey).not.toHaveBeenCalled(); + }); + }); + + describe('GET /', () => { + it('should return key expiry for a given key name', async () => { + const mockExpiry = { expiresAt: '2026-12-31' }; + getUserKeyExpiry.mockResolvedValue(mockExpiry); + + const response = await request(app).get('/api/keys?name=openAI'); + + expect(response.status).toBe(200); + expect(response.body).toEqual(mockExpiry); + expect(getUserKeyExpiry).toHaveBeenCalledWith({ + userId: 'test-user-123', + name: 'openAI', + }); + }); + }); +}); diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 26d7988f0a..e87fcf8f15 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -1,8 +1,18 @@ +const crypto = require('crypto'); const express = require('express'); const request = require('supertest'); const mongoose = require('mongoose'); -const { MongoMemoryServer } = require('mongodb-memory-server'); +const cookieParser = require('cookie-parser'); const { getBasePath } = require('@librechat/api'); +const { MongoMemoryServer } = require('mongodb-memory-server'); + +function generateTestCsrfToken(flowId) { + return crypto + .createHmac('sha256', process.env.JWT_SECRET) + .update(flowId) + .digest('hex') + .slice(0, 32); +} const mockRegistryInstance = { getServerConfig: jest.fn(), @@ -130,6 +140,7 @@ describe('MCP Routes', () => { app = express(); app.use(express.json()); + app.use(cookieParser()); app.use((req, res, next) => { req.user = { id: 'test-user-id' }; @@ -168,12 
+179,12 @@ describe('MCP Routes', () => { MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ authorizationUrl: 'https://oauth.example.com/auth', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(302); @@ -190,7 +201,7 @@ describe('MCP Routes', () => { it('should return 403 when userId does not match authenticated user', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'different-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(403); @@ -228,7 +239,7 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(400); @@ -245,7 +256,7 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(500); @@ -255,7 +266,7 @@ describe('MCP Routes', () => { it('should return 400 when flow state metadata is null', async () => { const mockFlowManager = { getFlowState: jest.fn().mockResolvedValue({ - id: 'test-flow-id', + id: 'test-user-id:test-server', metadata: null, }), }; @@ -265,7 +276,7 @@ describe('MCP Routes', () => { const response = await request(app).get('/api/mcp/test-server/oauth/initiate').query({ userId: 'test-user-id', - flowId: 'test-flow-id', + flowId: 'test-user-id:test-server', }); expect(response.status).toBe(400); @@ -280,7 +291,7 @@ describe('MCP Routes', () => { it('should redirect to error page when OAuth error is received', async () => { const response = 
await request(app).get('/api/mcp/test-server/oauth/callback').query({ error: 'access_denied', - state: 'test-flow-id', + state: 'test-user-id:test-server', }); const basePath = getBasePath(); @@ -290,7 +301,7 @@ describe('MCP Routes', () => { it('should redirect to error page when code is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - state: 'test-flow-id', + state: 'test-user-id:test-server', }); const basePath = getBasePath(); @@ -308,15 +319,50 @@ describe('MCP Routes', () => { expect(response.headers.location).toBe(`${basePath}/oauth/error?error=missing_state`); }); - it('should redirect to error page when flow state is not found', async () => { - MCPOAuthHandler.getFlowState.mockResolvedValue(null); - + it('should redirect to error page when CSRF cookie is missing', async () => { const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ code: 'test-auth-code', - state: 'invalid-flow-id', + state: 'test-user-id:test-server', }); const basePath = getBasePath(); + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should redirect to error page when CSRF cookie does not match state', async () => { + const csrfToken = generateTestCsrfToken('different-flow-id'); + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: 'test-user-id:test-server', + }); + const basePath = getBasePath(); + + expect(response.status).toBe(302); + expect(response.headers.location).toBe( + `${basePath}/oauth/error?error=csrf_validation_failed`, + ); + }); + + it('should redirect to error page when flow state is not found', async () => { + MCPOAuthHandler.getFlowState.mockResolvedValue(null); + const flowId = 'invalid-flow:id'; + const csrfToken = generateTestCsrfToken(flowId); + 
+ const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); + const basePath = getBasePath(); + expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/error?error=invalid_state`); }); @@ -369,16 +415,22 @@ describe('MCP Routes', () => { }); setCachedTools.mockResolvedValue(); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); expect(MCPOAuthHandler.completeOAuthFlow).toHaveBeenCalledWith( - 'test-flow-id', + flowId, 'test-auth-code', mockFlowManager, {}, @@ -400,16 +452,24 @@ describe('MCP Routes', () => { 'mcp_oauth', mockTokens, ); - expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith( + 'test-user-id:test-server', + 'mcp_get_tokens', + ); }); it('should redirect to error page when callback processing fails', async () => { MCPOAuthHandler.getFlowState.mockRejectedValue(new Error('Callback error')); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + 
.query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); @@ -442,15 +502,21 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); - expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); + expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens'); }); it('should handle reconnection failure after OAuth', async () => { @@ -488,16 +554,22 @@ describe('MCP Routes', () => { getCachedTools.mockResolvedValue({}); setCachedTools.mockResolvedValue(); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); expect(MCPTokenStorage.storeTokens).toHaveBeenCalled(); - expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('test-flow-id', 'mcp_get_tokens'); + 
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith(flowId, 'mcp_get_tokens'); }); it('should redirect to error page if token storage fails', async () => { @@ -530,10 +602,16 @@ describe('MCP Routes', () => { }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); @@ -589,22 +667,27 @@ describe('MCP Routes', () => { clearReconnection: jest.fn(), }); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); - // Verify storeTokens was called with ORIGINAL flow state credentials expect(MCPTokenStorage.storeTokens).toHaveBeenCalledWith( expect.objectContaining({ userId: 'test-user-id', serverName: 'test-server', tokens: mockTokens, - clientInfo: clientInfo, // Uses original flow state, not any "updated" credentials + clientInfo: clientInfo, metadata: flowState.metadata, }), ); @@ -631,16 +714,21 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); 
- const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ - code: 'test-auth-code', - state: 'test-flow-id', - }); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + + const response = await request(app) + .get('/api/mcp/test-server/oauth/callback') + .set('Cookie', [`oauth_csrf=${csrfToken}`]) + .query({ + code: 'test-auth-code', + state: flowId, + }); const basePath = getBasePath(); expect(response.status).toBe(302); expect(response.headers.location).toBe(`${basePath}/oauth/success?serverName=test-server`); - // Verify completeOAuthFlow was NOT called (prevented duplicate) expect(MCPOAuthHandler.completeOAuthFlow).not.toHaveBeenCalled(); expect(MCPTokenStorage.storeTokens).not.toHaveBeenCalled(); }); @@ -755,7 +843,7 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/oauth/status/test-flow-id'); + const response = await request(app).get('/api/mcp/oauth/status/test-user-id:test-server'); expect(response.status).toBe(200); expect(response.body).toEqual({ @@ -766,6 +854,13 @@ describe('MCP Routes', () => { }); }); + it('should return 403 when flowId does not match authenticated user', async () => { + const response = await request(app).get('/api/mcp/oauth/status/other-user-id:test-server'); + + expect(response.status).toBe(403); + expect(response.body).toEqual({ error: 'Access denied' }); + }); + it('should return 404 when flow is not found', async () => { const mockFlowManager = { getFlowState: jest.fn().mockResolvedValue(null), @@ -774,7 +869,7 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/oauth/status/non-existent-flow'); + const response = await 
request(app).get('/api/mcp/oauth/status/test-user-id:non-existent'); expect(response.status).toBe(404); expect(response.body).toEqual({ error: 'Flow not found' }); @@ -788,7 +883,7 @@ describe('MCP Routes', () => { getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const response = await request(app).get('/api/mcp/oauth/status/error-flow-id'); + const response = await request(app).get('/api/mcp/oauth/status/test-user-id:error-server'); expect(response.status).toBe(500); expect(response.body).toEqual({ error: 'Failed to get flow status' }); @@ -1375,7 +1470,7 @@ describe('MCP Routes', () => { refresh_token: 'edge-refresh-token', }; MCPOAuthHandler.getFlowState = jest.fn().mockResolvedValue({ - id: 'test-flow-id', + id: 'test-user-id:test-server', userId: 'test-user-id', metadata: { serverUrl: 'https://example.com', @@ -1403,8 +1498,12 @@ describe('MCP Routes', () => { }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + const response = await request(app) - .get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id') + .get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`) + .set('Cookie', [`oauth_csrf=${csrfToken}`]) .expect(302); const basePath = getBasePath(); @@ -1424,7 +1523,7 @@ describe('MCP Routes', () => { const mockFlowManager = { getFlowState: jest.fn().mockResolvedValue({ - id: 'test-flow-id', + id: 'test-user-id:test-server', userId: 'test-user-id', metadata: { serverUrl: 'https://example.com', oauth: {} }, clientInfo: {}, @@ -1453,8 +1552,12 @@ describe('MCP Routes', () => { }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + const flowId = 'test-user-id:test-server'; + const csrfToken = generateTestCsrfToken(flowId); + const response = await request(app) - 
.get('/api/mcp/test-server/oauth/callback?code=test-code&state=test-flow-id') + .get(`/api/mcp/test-server/oauth/callback?code=test-code&state=${flowId}`) + .set('Cookie', [`oauth_csrf=${csrfToken}`]) .expect(302); const basePath = getBasePath(); diff --git a/api/server/routes/accessPermissions.js b/api/server/routes/accessPermissions.js index 79e7f3ddca..45afec133b 100644 --- a/api/server/routes/accessPermissions.js +++ b/api/server/routes/accessPermissions.js @@ -53,6 +53,12 @@ const checkResourcePermissionAccess = (requiredPermission) => (req, res, next) = requiredPermission, resourceIdParam: 'resourceId', }); + } else if (resourceType === ResourceType.REMOTE_AGENT) { + middleware = canAccessResource({ + resourceType: ResourceType.REMOTE_AGENT, + requiredPermission, + resourceIdParam: 'resourceId', + }); } else if (resourceType === ResourceType.PROMPTGROUP) { middleware = canAccessResource({ resourceType: ResourceType.PROMPTGROUP, diff --git a/api/server/routes/actions.js b/api/server/routes/actions.js index 14474a53d3..806edc66cc 100644 --- a/api/server/routes/actions.js +++ b/api/server/routes/actions.js @@ -1,14 +1,47 @@ const express = require('express'); const jwt = require('jsonwebtoken'); -const { getAccessToken, getBasePath } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { + getBasePath, + getAccessToken, + setOAuthSession, + validateOAuthCsrf, + OAUTH_CSRF_COOKIE, + setOAuthCsrfCookie, + validateOAuthSession, + OAUTH_SESSION_COOKIE, +} = require('@librechat/api'); const { findToken, updateToken, createToken } = require('~/models'); +const { requireJwtAuth } = require('~/server/middleware'); const { getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); const router = express.Router(); const JWT_SECRET = process.env.JWT_SECRET; +const OAUTH_CSRF_COOKIE_PATH = '/api/actions'; + +/** + * Sets a CSRF cookie binding the 
action OAuth flow to the current browser session. + * Must be called before the user opens the IdP authorization URL. + * + * @route POST /actions/:action_id/oauth/bind + */ +router.post('/:action_id/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => { + try { + const { action_id } = req.params; + const user = req.user; + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + const flowId = `${user.id}:${action_id}`; + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + res.json({ success: true }); + } catch (error) { + logger.error('[Action OAuth] Failed to set CSRF binding cookie', error); + res.status(500).json({ error: 'Failed to bind OAuth flow' }); + } +}); /** * Handles the OAuth callback and exchanges the authorization code for tokens. @@ -45,7 +78,22 @@ router.get('/:action_id/oauth/callback', async (req, res) => { await flowManager.failFlow(identifier, 'oauth', 'Invalid user ID in state parameter'); return res.redirect(`${basePath}/oauth/error?error=invalid_state`); } + identifier = `${decodedState.user}:${action_id}`; + + if ( + !validateOAuthCsrf(req, res, identifier, OAUTH_CSRF_COOKIE_PATH) && + !validateOAuthSession(req, decodedState.user) + ) { + logger.error('[Action OAuth] CSRF validation failed: no valid CSRF or session cookie', { + identifier, + hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], + hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], + }); + await flowManager.failFlow(identifier, 'oauth', 'CSRF validation failed'); + return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + } + const flowState = await flowManager.getFlowState(identifier, 'oauth'); if (!flowState) { throw new Error('OAuth flow not found'); @@ -71,7 +119,6 @@ router.get('/:action_id/oauth/callback', async (req, res) => { ); await flowManager.completeFlow(identifier, 'oauth', tokenData); - /** Redirect to React success page */ const serverName = flowState.metadata?.action_name || 
`Action ${action_id}`; const redirectUrl = `${basePath}/oauth/success?serverName=${encodeURIComponent(serverName)}`; res.redirect(redirectUrl); diff --git a/api/server/routes/admin/auth.js b/api/server/routes/admin/auth.js new file mode 100644 index 0000000000..291b5eaaf8 --- /dev/null +++ b/api/server/routes/admin/auth.js @@ -0,0 +1,127 @@ +const express = require('express'); +const passport = require('passport'); +const { randomState } = require('openid-client'); +const { logger } = require('@librechat/data-schemas'); +const { CacheKeys } = require('librechat-data-provider'); +const { + requireAdmin, + getAdminPanelUrl, + exchangeAdminCode, + createSetBalanceConfig, +} = require('@librechat/api'); +const { loginController } = require('~/server/controllers/auth/LoginController'); +const { createOAuthHandler } = require('~/server/controllers/auth/oauth'); +const { getAppConfig } = require('~/server/services/Config'); +const getLogStores = require('~/cache/getLogStores'); +const { getOpenIdConfig } = require('~/strategies'); +const middleware = require('~/server/middleware'); +const { Balance } = require('~/db/models'); + +const setBalanceConfig = createSetBalanceConfig({ + getAppConfig, + Balance, +}); + +const router = express.Router(); + +router.post( + '/login/local', + middleware.logHeaders, + middleware.loginLimiter, + middleware.checkBan, + middleware.requireLocalAuth, + requireAdmin, + setBalanceConfig, + loginController, +); + +router.get('/verify', middleware.requireJwtAuth, requireAdmin, (req, res) => { + const { password: _p, totpSecret: _t, __v, ...user } = req.user; + user.id = user._id.toString(); + res.status(200).json({ user }); +}); + +router.get('/oauth/openid/check', (req, res) => { + const openidConfig = getOpenIdConfig(); + if (!openidConfig) { + return res.status(404).json({ + error: 'OpenID configuration not found', + error_code: 'OPENID_NOT_CONFIGURED', + }); + } + res.status(200).json({ message: 'OpenID check successful' }); +}); + 
+router.get('/oauth/openid', (req, res, next) => { + return passport.authenticate('openidAdmin', { + session: false, + state: randomState(), + })(req, res, next); +}); + +router.get( + '/oauth/openid/callback', + passport.authenticate('openidAdmin', { + failureRedirect: `${getAdminPanelUrl()}/auth/openid/callback?error=auth_failed&error_description=Authentication+failed`, + failureMessage: true, + session: false, + }), + requireAdmin, + setBalanceConfig, + middleware.checkDomainAllowed, + createOAuthHandler(`${getAdminPanelUrl()}/auth/openid/callback`), +); + +/** Regex pattern for valid exchange codes: 64 hex characters */ +const EXCHANGE_CODE_PATTERN = /^[a-f0-9]{64}$/i; + +/** + * Exchange OAuth authorization code for tokens. + * This endpoint is called server-to-server by the admin panel. + * The code is one-time-use and expires in 30 seconds. + * + * POST /api/admin/oauth/exchange + * Body: { code: string } + * Response: { token: string, refreshToken: string, user: object } + */ +router.post('/oauth/exchange', middleware.loginLimiter, async (req, res) => { + try { + const { code } = req.body; + + if (!code) { + logger.warn('[admin/oauth/exchange] Missing authorization code'); + return res.status(400).json({ + error: 'Missing authorization code', + error_code: 'MISSING_CODE', + }); + } + + if (typeof code !== 'string' || !EXCHANGE_CODE_PATTERN.test(code)) { + logger.warn('[admin/oauth/exchange] Invalid authorization code format'); + return res.status(400).json({ + error: 'Invalid authorization code format', + error_code: 'INVALID_CODE_FORMAT', + }); + } + + const cache = getLogStores(CacheKeys.ADMIN_OAUTH_EXCHANGE); + const result = await exchangeAdminCode(cache, code); + + if (!result) { + return res.status(401).json({ + error: 'Invalid or expired authorization code', + error_code: 'INVALID_OR_EXPIRED_CODE', + }); + } + + res.json(result); + } catch (error) { + logger.error('[admin/oauth/exchange] Error:', error); + res.status(500).json({ + error: 'Internal 
server error', + error_code: 'INTERNAL_ERROR', + }); + } +}); + +module.exports = router; diff --git a/api/server/routes/agents/__tests__/abort.spec.js b/api/server/routes/agents/__tests__/abort.spec.js new file mode 100644 index 0000000000..442665d973 --- /dev/null +++ b/api/server/routes/agents/__tests__/abort.spec.js @@ -0,0 +1,303 @@ +/** + * Tests for the agent abort endpoint + * + * Tests the following fixes from PR #11462: + * 1. Authorization check - only job owner can abort + * 2. Early abort handling - skip save when no responseMessageId + * 3. Partial response saving - save message before returning + */ + +const express = require('express'); +const request = require('supertest'); + +const mockLogger = { + debug: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + info: jest.fn(), +}; + +const mockGenerationJobManager = { + getJob: jest.fn(), + abortJob: jest.fn(), + getActiveJobIdsForUser: jest.fn(), +}; + +const mockSaveMessage = jest.fn(); + +jest.mock('@librechat/data-schemas', () => ({ + ...jest.requireActual('@librechat/data-schemas'), + logger: mockLogger, +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn().mockReturnValue(false), + GenerationJobManager: mockGenerationJobManager, +})); + +jest.mock('~/models', () => ({ + saveMessage: (...args) => mockSaveMessage(...args), +})); + +jest.mock('~/server/middleware', () => ({ + uaParser: (req, res, next) => next(), + checkBan: (req, res, next) => next(), + requireJwtAuth: (req, res, next) => { + req.user = { id: 'test-user-123' }; + next(); + }, + messageIpLimiter: (req, res, next) => next(), + configMiddleware: (req, res, next) => next(), + messageUserLimiter: (req, res, next) => next(), +})); + +// Mock the chat module - needs to be a router +jest.mock('~/server/routes/agents/chat', () => require('express').Router()); + +// Mock the v1 module - v1 is directly used as middleware +jest.mock('~/server/routes/agents/v1', () => ({ + v1: 
require('express').Router(), +})); + +// Import after mocks +const agentRoutes = require('~/server/routes/agents/index'); + +describe('Agent Abort Endpoint', () => { + let app; + + beforeAll(() => { + app = express(); + app.use(express.json()); + app.use('/api/agents', agentRoutes); + }); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('POST /chat/abort', () => { + describe('Authorization', () => { + it("should return 403 when user tries to abort another user's job", async () => { + const jobStreamId = 'test-stream-123'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'other-user-456' }, + }); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: jobStreamId }); + + expect(response.status).toBe(403); + expect(response.body).toEqual({ error: 'Unauthorized' }); + expect(mockLogger.warn).toHaveBeenCalledWith( + expect.stringContaining('Unauthorized abort attempt'), + ); + expect(mockGenerationJobManager.abortJob).not.toHaveBeenCalled(); + }); + + it('should allow abort when user owns the job', async () => { + const jobStreamId = 'test-stream-123'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'test-user-123' }, + }); + + mockGenerationJobManager.abortJob.mockResolvedValue({ + success: true, + jobData: null, + content: [], + text: '', + }); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: jobStreamId }); + + expect(response.status).toBe(200); + expect(response.body).toEqual({ success: true, aborted: jobStreamId }); + expect(mockGenerationJobManager.abortJob).toHaveBeenCalledWith(jobStreamId); + }); + + it('should allow abort when job has no userId metadata (backwards compatibility)', async () => { + const jobStreamId = 'test-stream-123'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: {}, + }); + + mockGenerationJobManager.abortJob.mockResolvedValue({ + success: true, + 
jobData: null, + content: [], + text: '', + }); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: jobStreamId }); + + expect(response.status).toBe(200); + expect(response.body).toEqual({ success: true, aborted: jobStreamId }); + }); + }); + + describe('Early Abort Handling', () => { + it('should skip message saving when responseMessageId is missing (early abort)', async () => { + const jobStreamId = 'test-stream-123'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'test-user-123' }, + }); + + mockGenerationJobManager.abortJob.mockResolvedValue({ + success: true, + jobData: { + userMessage: { messageId: 'user-msg-123' }, + // No responseMessageId - early abort before generation started + conversationId: jobStreamId, + }, + content: [], + text: '', + }); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: jobStreamId }); + + expect(response.status).toBe(200); + expect(mockSaveMessage).not.toHaveBeenCalled(); + }); + + it('should skip message saving when userMessage is missing', async () => { + const jobStreamId = 'test-stream-123'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'test-user-123' }, + }); + + mockGenerationJobManager.abortJob.mockResolvedValue({ + success: true, + jobData: { + // No userMessage + responseMessageId: 'response-msg-123', + conversationId: jobStreamId, + }, + content: [], + text: '', + }); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: jobStreamId }); + + expect(response.status).toBe(200); + expect(mockSaveMessage).not.toHaveBeenCalled(); + }); + }); + + describe('Partial Response Saving', () => { + it('should save partial response when both userMessage and responseMessageId exist', async () => { + const jobStreamId = 'test-stream-123'; + const userMessageId = 'user-msg-123'; + const responseMessageId = 'response-msg-456'; + + 
mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'test-user-123' }, + }); + + mockGenerationJobManager.abortJob.mockResolvedValue({ + success: true, + jobData: { + userMessage: { messageId: userMessageId }, + responseMessageId, + conversationId: jobStreamId, + sender: 'TestAgent', + endpoint: 'anthropic', + model: 'claude-3', + }, + content: [{ type: 'text', text: 'Partial response...' }], + text: 'Partial response...', + }); + + mockSaveMessage.mockResolvedValue(); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: jobStreamId }); + + expect(response.status).toBe(200); + expect(mockSaveMessage).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + messageId: responseMessageId, + parentMessageId: userMessageId, + conversationId: jobStreamId, + content: [{ type: 'text', text: 'Partial response...' }], + text: 'Partial response...', + sender: 'TestAgent', + endpoint: 'anthropic', + model: 'claude-3', + unfinished: true, + error: false, + isCreatedByUser: false, + user: 'test-user-123', + }), + expect.objectContaining({ + context: 'api/server/routes/agents/index.js - abort endpoint', + }), + ); + }); + + it('should handle saveMessage errors gracefully', async () => { + const jobStreamId = 'test-stream-123'; + + mockGenerationJobManager.getJob.mockResolvedValue({ + metadata: { userId: 'test-user-123' }, + }); + + mockGenerationJobManager.abortJob.mockResolvedValue({ + success: true, + jobData: { + userMessage: { messageId: 'user-msg-123' }, + responseMessageId: 'response-msg-456', + conversationId: jobStreamId, + }, + content: [], + text: '', + }); + + mockSaveMessage.mockRejectedValue(new Error('Database error')); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: jobStreamId }); + + // Should still return success even if save fails + expect(response.status).toBe(200); + expect(response.body).toEqual({ success: true, aborted: 
jobStreamId }); + expect(mockLogger.error).toHaveBeenCalledWith( + expect.stringContaining('Failed to save partial response'), + ); + }); + }); + + describe('Job Not Found', () => { + it('should return 404 when job is not found', async () => { + mockGenerationJobManager.getJob.mockResolvedValue(null); + mockGenerationJobManager.getActiveJobIdsForUser.mockResolvedValue([]); + + const response = await request(app) + .post('/api/agents/chat/abort') + .send({ conversationId: 'non-existent-job' }); + + expect(response.status).toBe(404); + expect(response.body).toEqual({ + error: 'Job not found', + streamId: 'non-existent-job', + }); + }); + }); + }); +}); diff --git a/api/server/routes/agents/__tests__/responses.spec.js b/api/server/routes/agents/__tests__/responses.spec.js new file mode 100644 index 0000000000..4d83219b84 --- /dev/null +++ b/api/server/routes/agents/__tests__/responses.spec.js @@ -0,0 +1,1125 @@ +/** + * Open Responses API Integration Tests + * + * Tests the /v1/responses endpoint against the Open Responses specification + * compliance tests. Uses real Anthropic API for LLM calls. 
+ * + * @see https://openresponses.org/specification + * @see https://github.com/openresponses/openresponses/blob/main/src/lib/compliance-tests.ts + */ + +// Load environment variables from root .env file for API keys +require('dotenv').config({ path: require('path').resolve(__dirname, '../../../../../.env') }); + +const originalEnv = { + CREDS_KEY: process.env.CREDS_KEY, + CREDS_IV: process.env.CREDS_IV, +}; + +process.env.CREDS_KEY = '0123456789abcdef0123456789abcdef'; +process.env.CREDS_IV = '0123456789abcdef'; + +/** Skip tests if ANTHROPIC_API_KEY is not available */ +const SKIP_INTEGRATION_TESTS = !process.env.ANTHROPIC_API_KEY; +if (SKIP_INTEGRATION_TESTS) { + console.warn('ANTHROPIC_API_KEY not found - skipping integration tests'); +} + +jest.mock('meilisearch', () => ({ + MeiliSearch: jest.fn().mockImplementation(() => ({ + getIndex: jest.fn().mockRejectedValue(new Error('mocked')), + index: jest.fn().mockReturnValue({ + getRawInfo: jest.fn().mockResolvedValue({ primaryKey: 'id' }), + updateSettings: jest.fn().mockResolvedValue({}), + addDocuments: jest.fn().mockResolvedValue({}), + updateDocuments: jest.fn().mockResolvedValue({}), + deleteDocument: jest.fn().mockResolvedValue({}), + }), + })), +})); + +jest.mock('~/server/services/Config', () => ({ + loadCustomConfig: jest.fn(() => Promise.resolve({})), + getAppConfig: jest.fn().mockResolvedValue({ + paths: { + uploads: '/tmp', + dist: '/tmp/dist', + fonts: '/tmp/fonts', + assets: '/tmp/assets', + }, + fileStrategy: 'local', + imageOutputType: 'PNG', + endpoints: { + agents: { + allowedProviders: ['anthropic', 'openAI'], + }, + }, + }), + setCachedTools: jest.fn(), + getCachedTools: jest.fn(), + getMCPServerTools: jest.fn().mockReturnValue([]), +})); + +jest.mock('~/app/clients/tools', () => ({ + createOpenAIImageTools: jest.fn(() => []), + createYouTubeTools: jest.fn(() => []), + manifestToolMap: {}, + toolkits: [], +})); + +jest.mock('~/config', () => ({ + createMCPServersRegistry: jest.fn(), + 
createMCPManager: jest.fn().mockResolvedValue({ + getAppToolFunctions: jest.fn().mockResolvedValue({}), + }), +})); + +const express = require('express'); +const request = require('supertest'); +const mongoose = require('mongoose'); +const { v4: uuidv4 } = require('uuid'); +const { MongoMemoryServer } = require('mongodb-memory-server'); +const { hashToken, getRandomValues, createModels } = require('@librechat/data-schemas'); +const { + SystemRoles, + ResourceType, + AccessRoleIds, + PrincipalType, + PrincipalModel, + PermissionBits, + EModelEndpoint, +} = require('librechat-data-provider'); + +/** @type {import('mongoose').Model} */ +let Agent; +/** @type {import('mongoose').Model} */ +let AgentApiKey; +/** @type {import('mongoose').Model} */ +let User; +/** @type {import('mongoose').Model} */ +let AclEntry; +/** @type {import('mongoose').Model} */ +let AccessRole; + +/** + * Parse SSE stream into events + * @param {string} text - Raw SSE text + * @returns {Array<{event: string, data: unknown}>} + */ +function parseSSEEvents(text) { + const events = []; + const lines = text.split('\n'); + + let currentEvent = ''; + let currentData = ''; + + for (const line of lines) { + if (line.startsWith('event:')) { + currentEvent = line.slice(6).trim(); + } else if (line.startsWith('data:')) { + currentData = line.slice(5).trim(); + } else if (line === '' && currentData) { + if (currentData === '[DONE]') { + events.push({ event: 'done', data: '[DONE]' }); + } else { + try { + const parsed = JSON.parse(currentData); + events.push({ + event: currentEvent || parsed.type || 'unknown', + data: parsed, + }); + } catch { + // Skip unparseable data + } + } + currentEvent = ''; + currentData = ''; + } + } + + return events; +} + +/** + * Valid streaming event types per Open Responses specification + * @see https://github.com/openresponses/openresponses/blob/main/src/lib/sse-parser.ts + */ +const VALID_STREAMING_EVENT_TYPES = new Set([ + // Standard Open Responses events + 
'response.created', + 'response.queued', + 'response.in_progress', + 'response.completed', + 'response.failed', + 'response.incomplete', + 'response.output_item.added', + 'response.output_item.done', + 'response.content_part.added', + 'response.content_part.done', + 'response.output_text.delta', + 'response.output_text.done', + 'response.refusal.delta', + 'response.refusal.done', + 'response.function_call_arguments.delta', + 'response.function_call_arguments.done', + 'response.reasoning_summary_part.added', + 'response.reasoning_summary_part.done', + 'response.reasoning.delta', + 'response.reasoning.done', + 'response.reasoning_summary_text.delta', + 'response.reasoning_summary_text.done', + 'response.output_text.annotation.added', + 'error', + // LibreChat extension events (prefixed per Open Responses spec) + // @see https://openresponses.org/specification#extending-streaming-events + 'librechat:attachment', +]); + +/** + * Validate a streaming event against Open Responses spec + * @param {Object} event - Parsed event with data + * @returns {string[]} Array of validation errors + */ +function validateStreamingEvent(event) { + const errors = []; + const data = event.data; + + if (!data || typeof data !== 'object') { + return errors; // Skip non-object data (e.g., [DONE]) + } + + const eventType = data.type; + + // Check event type is valid + if (!VALID_STREAMING_EVENT_TYPES.has(eventType)) { + errors.push(`Invalid event type: ${eventType}`); + return errors; + } + + // Validate required fields based on event type + switch (eventType) { + case 'response.output_text.delta': + if (typeof data.sequence_number !== 'number') { + errors.push('response.output_text.delta: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.output_text.delta: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.output_text.delta: missing output_index'); + } + if (typeof data.content_index !== 'number') { + 
errors.push('response.output_text.delta: missing content_index'); + } + if (typeof data.delta !== 'string') { + errors.push('response.output_text.delta: missing delta'); + } + if (!Array.isArray(data.logprobs)) { + errors.push('response.output_text.delta: missing logprobs array'); + } + break; + + case 'response.output_text.done': + if (typeof data.sequence_number !== 'number') { + errors.push('response.output_text.done: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.output_text.done: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.output_text.done: missing output_index'); + } + if (typeof data.content_index !== 'number') { + errors.push('response.output_text.done: missing content_index'); + } + if (typeof data.text !== 'string') { + errors.push('response.output_text.done: missing text'); + } + if (!Array.isArray(data.logprobs)) { + errors.push('response.output_text.done: missing logprobs array'); + } + break; + + case 'response.reasoning.delta': + if (typeof data.sequence_number !== 'number') { + errors.push('response.reasoning.delta: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.reasoning.delta: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.reasoning.delta: missing output_index'); + } + if (typeof data.content_index !== 'number') { + errors.push('response.reasoning.delta: missing content_index'); + } + if (typeof data.delta !== 'string') { + errors.push('response.reasoning.delta: missing delta'); + } + break; + + case 'response.reasoning.done': + if (typeof data.sequence_number !== 'number') { + errors.push('response.reasoning.done: missing sequence_number'); + } + if (typeof data.item_id !== 'string') { + errors.push('response.reasoning.done: missing item_id'); + } + if (typeof data.output_index !== 'number') { + errors.push('response.reasoning.done: missing 
output_index'); + } + if (typeof data.content_index !== 'number') { + errors.push('response.reasoning.done: missing content_index'); + } + if (typeof data.text !== 'string') { + errors.push('response.reasoning.done: missing text'); + } + break; + + case 'response.in_progress': + case 'response.completed': + case 'response.failed': + if (!data.response || typeof data.response !== 'object') { + errors.push(`${eventType}: missing response object`); + } + break; + + case 'response.output_item.added': + case 'response.output_item.done': + if (typeof data.output_index !== 'number') { + errors.push(`${eventType}: missing output_index`); + } + if (!data.item || typeof data.item !== 'object') { + errors.push(`${eventType}: missing item object`); + } + break; + } + + return errors; +} + +/** + * Validate all streaming events and return errors + * @param {Array} events - Array of parsed events + * @returns {string[]} Array of all validation errors + */ +function validateAllStreamingEvents(events) { + const allErrors = []; + for (const event of events) { + const errors = validateStreamingEvent(event); + allErrors.push(...errors); + } + return allErrors; +} + +/** + * Create a test agent with Anthropic provider + * @param {Object} overrides + * @returns {Promise} + */ +async function createTestAgent(overrides = {}) { + const timestamp = new Date(); + const agentData = { + id: `agent_${uuidv4().replace(/-/g, '').substring(0, 21)}`, + name: 'Test Anthropic Agent', + description: 'An agent for testing Open Responses API', + instructions: 'You are a helpful assistant. 
Be concise.', + provider: EModelEndpoint.anthropic, + model: 'claude-sonnet-4-5-20250929', + author: new mongoose.Types.ObjectId(), + tools: [], + model_parameters: {}, + ...overrides, + }; + + const versionData = { ...agentData }; + delete versionData.author; + + const initialAgentData = { + ...agentData, + versions: [ + { + ...versionData, + createdAt: timestamp, + updatedAt: timestamp, + }, + ], + category: 'general', + }; + + return (await Agent.create(initialAgentData)).toObject(); +} + +/** + * Create an agent with extended thinking enabled + * @param {Object} overrides + * @returns {Promise} + */ +async function createThinkingAgent(overrides = {}) { + return createTestAgent({ + name: 'Test Thinking Agent', + description: 'An agent with extended thinking enabled', + model_parameters: { + thinking: { + type: 'enabled', + budget_tokens: 5000, + }, + }, + ...overrides, + }); +} + +const describeWithApiKey = SKIP_INTEGRATION_TESTS ? describe.skip : describe; + +describeWithApiKey('Open Responses API Integration Tests', () => { + // Increase timeout for real API calls + jest.setTimeout(120000); + + let mongoServer; + let app; + let testAgent; + let thinkingAgent; + let testUser; + let testApiKey; // The raw API key for Authorization header + + afterAll(() => { + process.env.CREDS_KEY = originalEnv.CREDS_KEY; + process.env.CREDS_IV = originalEnv.CREDS_IV; + }); + + beforeAll(async () => { + // Start MongoDB Memory Server + mongoServer = await MongoMemoryServer.create(); + const mongoUri = mongoServer.getUri(); + + // Connect to MongoDB + await mongoose.connect(mongoUri); + + // Register all models + const models = createModels(mongoose); + + // Get models + Agent = models.Agent; + AgentApiKey = models.AgentApiKey; + User = models.User; + AclEntry = models.AclEntry; + AccessRole = models.AccessRole; + + // Create minimal Express app with just the responses routes + app = express(); + app.use(express.json()); + + // Mount the responses routes + const responsesRoutes 
= require('~/server/routes/agents/responses'); + app.use('/api/agents/v1/responses', responsesRoutes); + + // Create test user + testUser = await User.create({ + name: 'Test API User', + username: 'testapiuser', + email: 'testapiuser@test.com', + emailVerified: true, + provider: 'local', + role: SystemRoles.ADMIN, + }); + + // Create REMOTE_AGENT access roles (if they don't exist) + const existingRoles = await AccessRole.find({ + accessRoleId: { + $in: [ + AccessRoleIds.REMOTE_AGENT_VIEWER, + AccessRoleIds.REMOTE_AGENT_EDITOR, + AccessRoleIds.REMOTE_AGENT_OWNER, + ], + }, + }); + + if (existingRoles.length === 0) { + await AccessRole.create([ + { + accessRoleId: AccessRoleIds.REMOTE_AGENT_VIEWER, + name: 'API Viewer', + description: 'Can query the agent via API', + resourceType: ResourceType.REMOTE_AGENT, + permBits: PermissionBits.VIEW, + }, + { + accessRoleId: AccessRoleIds.REMOTE_AGENT_EDITOR, + name: 'API Editor', + description: 'Can view and modify the agent via API', + resourceType: ResourceType.REMOTE_AGENT, + permBits: PermissionBits.VIEW | PermissionBits.EDIT, + }, + { + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + name: 'API Owner', + description: 'Full API access + can grant remote access to others', + resourceType: ResourceType.REMOTE_AGENT, + permBits: + PermissionBits.VIEW | + PermissionBits.EDIT | + PermissionBits.DELETE | + PermissionBits.SHARE, + }, + ]); + } + + // Generate and create an API key for the test user + const rawKey = `sk-${await getRandomValues(32)}`; + const keyHash = await hashToken(rawKey); + const keyPrefix = rawKey.substring(0, 8); + + await AgentApiKey.create({ + userId: testUser._id, + name: 'Test API Key', + keyHash, + keyPrefix, + }); + + testApiKey = rawKey; + + // Create test agents with the test user as author + testAgent = await createTestAgent({ author: testUser._id }); + thinkingAgent = await createThinkingAgent({ author: testUser._id }); + + // Grant REMOTE_AGENT permissions for the test agents + await 
AclEntry.create([ + { + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: testUser._id, + resourceType: ResourceType.REMOTE_AGENT, + resourceId: testAgent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + permBits: + PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE, + }, + { + principalType: PrincipalType.USER, + principalModel: PrincipalModel.USER, + principalId: testUser._id, + resourceType: ResourceType.REMOTE_AGENT, + resourceId: thinkingAgent._id, + accessRoleId: AccessRoleIds.REMOTE_AGENT_OWNER, + permBits: + PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE, + }, + ]); + }, 60000); + + afterAll(async () => { + await mongoose.disconnect(); + await mongoServer.stop(); + }); + + beforeEach(async () => { + // Clean up any test data between tests if needed + }); + + /* =========================================================================== + * COMPLIANCE TESTS + * Based on: https://github.com/openresponses/openresponses/blob/main/src/lib/compliance-tests.ts + * =========================================================================== */ + + /** Helper to add auth header to requests */ + const authRequest = () => ({ + post: (url) => request(app).post(url).set('Authorization', `Bearer ${testApiKey}`), + get: (url) => request(app).get(url).set('Authorization', `Bearer ${testApiKey}`), + }); + + describe('Compliance Tests', () => { + describe('basic-response', () => { + it('should return a valid ResponseResource for a simple text request', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Say hello in exactly 3 words.', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body).toBeDefined(); + + // Validate ResponseResource schema + const body = response.body; + 
expect(body.id).toMatch(/^resp_/); + expect(body.object).toBe('response'); + expect(typeof body.created_at).toBe('number'); + expect(body.status).toBe('completed'); + expect(body.model).toBe(testAgent.id); + + // Validate output + expect(Array.isArray(body.output)).toBe(true); + expect(body.output.length).toBeGreaterThan(0); + + // Should have at least one message item + const messageItem = body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + expect(messageItem.role).toBe('assistant'); + expect(messageItem.status).toBe('completed'); + expect(Array.isArray(messageItem.content)).toBe(true); + }); + }); + + describe('streaming-response', () => { + it('should return valid SSE streaming events', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Count from 1 to 5.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + expect(response.headers['content-type']).toMatch(/text\/event-stream/); + + const events = parseSSEEvents(response.body); + expect(events.length).toBeGreaterThan(0); + + // Validate all streaming events against Open Responses spec + // This catches issues like: + // - Invalid event types (e.g., response.reasoning_text.delta instead of response.reasoning.delta) + // - Missing required fields (e.g., logprobs on output_text events) + const validationErrors = validateAllStreamingEvents(events); + if (validationErrors.length > 0) { + console.error('Streaming event validation errors:', validationErrors); + } + expect(validationErrors).toEqual([]); + + // Validate streaming event types + const eventTypes = events.map((e) => e.event); + + // Should have response.created first (per Open 
Responses spec) + expect(eventTypes).toContain('response.created'); + + // Should have response.in_progress + expect(eventTypes).toContain('response.in_progress'); + + // response.created should come before response.in_progress + const createdIdx = eventTypes.indexOf('response.created'); + const inProgressIdx = eventTypes.indexOf('response.in_progress'); + expect(createdIdx).toBeLessThan(inProgressIdx); + + // Should have response.completed or response.failed + expect(eventTypes.some((t) => t === 'response.completed' || t === 'response.failed')).toBe( + true, + ); + + // Should have [DONE] + expect(eventTypes).toContain('done'); + + // Validate response.completed has full response + const completedEvent = events.find((e) => e.event === 'response.completed'); + if (completedEvent) { + expect(completedEvent.data.response).toBeDefined(); + expect(completedEvent.data.response.status).toBe('completed'); + expect(completedEvent.data.response.output.length).toBeGreaterThan(0); + } + }); + + it('should emit valid event types per Open Responses spec', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Say hi.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + + const events = parseSSEEvents(response.body); + + // Check all event types are valid + for (const event of events) { + if (event.data && typeof event.data === 'object' && event.data.type) { + expect(VALID_STREAMING_EVENT_TYPES.has(event.data.type)).toBe(true); + } + } + }); + + it('should include logprobs array in output_text events', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 
'message', + role: 'user', + content: 'Say one word.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + + const events = parseSSEEvents(response.body); + + // Find output_text delta/done events and verify logprobs + const textDeltaEvents = events.filter( + (e) => e.data && e.data.type === 'response.output_text.delta', + ); + const textDoneEvents = events.filter( + (e) => e.data && e.data.type === 'response.output_text.done', + ); + + // Should have at least one output_text event + expect(textDeltaEvents.length + textDoneEvents.length).toBeGreaterThan(0); + + // All output_text.delta events must have logprobs array + for (const event of textDeltaEvents) { + expect(Array.isArray(event.data.logprobs)).toBe(true); + } + + // All output_text.done events must have logprobs array + for (const event of textDoneEvents) { + expect(Array.isArray(event.data.logprobs)).toBe(true); + } + }); + }); + + describe('system-prompt', () => { + it('should handle developer role messages in input (as system)', async () => { + // Note: For Anthropic, system messages must be first and there can only be one. + // Since the agent already has instructions, we use 'developer' role which + // gets merged into the system prompt, or we test with a simple user message + // that instructs the behavior. 
+ const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'Pretend you are a pirate and say hello in pirate speak.', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + expect(response.body.output.length).toBeGreaterThan(0); + + // The response should reflect the pirate persona + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + expect(messageItem.content.length).toBeGreaterThan(0); + }); + }); + + describe('multi-turn', () => { + it('should handle multi-turn conversation history', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: testAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'My name is Alice.', + }, + { + type: 'message', + role: 'assistant', + content: 'Hello Alice! Nice to meet you. 
How can I help you today?', + }, + { + type: 'message', + role: 'user', + content: 'What is my name?', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + + // The response should reference "Alice" + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + + const textContent = messageItem.content.find((c) => c.type === 'output_text'); + expect(textContent).toBeDefined(); + expect(textContent.text.toLowerCase()).toContain('alice'); + }); + }); + + // Note: tool-calling test requires tool setup which may need additional configuration + // Note: image-input test requires vision-capable model + + describe('string-input', () => { + it('should accept simple string input', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Hello!', + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + expect(response.body.output.length).toBeGreaterThan(0); + }); + }); + }); + + /* =========================================================================== + * EXTENDED THINKING TESTS + * Tests reasoning output from Claude models with extended thinking enabled + * =========================================================================== */ + + describe('Extended Thinking', () => { + it('should return reasoning output when thinking is enabled', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: thinkingAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'What is 15 * 7? 
Think step by step.', + }, + ], + }); + + expect(response.status).toBe(200); + expect(response.body.status).toBe('completed'); + + // Check for reasoning item in output + const reasoningItem = response.body.output.find((item) => item.type === 'reasoning'); + // If reasoning is present, validate its structure per Open Responses spec + // Note: reasoning items do NOT have a 'status' field per the spec + // @see https://github.com/openresponses/openresponses/blob/main/src/generated/kubb/zod/reasoningBodySchema.ts + if (reasoningItem) { + expect(reasoningItem).toHaveProperty('id'); + expect(reasoningItem).toHaveProperty('type', 'reasoning'); + // Note: 'status' is NOT a field on reasoning items per the spec + expect(reasoningItem).toHaveProperty('summary'); + expect(Array.isArray(reasoningItem.summary)).toBe(true); + + // Validate content items + if (reasoningItem.content && reasoningItem.content.length > 0) { + const reasoningContent = reasoningItem.content[0]; + expect(reasoningContent).toHaveProperty('type', 'reasoning_text'); + expect(reasoningContent).toHaveProperty('text'); + } + } + + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + }); + + it('should stream reasoning events when thinking is enabled', async () => { + const response = await authRequest() + .post('/api/agents/v1/responses') + .send({ + model: thinkingAgent.id, + input: [ + { + type: 'message', + role: 'user', + content: 'What is 12 + 8? 
Think step by step.', + }, + ], + stream: true, + }) + .buffer(true) + .parse((res, callback) => { + let data = ''; + res.on('data', (chunk) => { + data += chunk.toString(); + }); + res.on('end', () => { + callback(null, data); + }); + }); + + expect(response.status).toBe(200); + + const events = parseSSEEvents(response.body); + + // Validate all events against Open Responses spec + const validationErrors = validateAllStreamingEvents(events); + if (validationErrors.length > 0) { + console.error('Reasoning streaming event validation errors:', validationErrors); + } + expect(validationErrors).toEqual([]); + + // Check for reasoning-related events using correct event types per Open Responses spec + // Note: The spec uses response.reasoning.delta NOT response.reasoning_text.delta + const reasoningDeltaEvents = events.filter( + (e) => e.data && e.data.type === 'response.reasoning.delta', + ); + const reasoningDoneEvents = events.filter( + (e) => e.data && e.data.type === 'response.reasoning.done', + ); + + // If reasoning events are present, validate their structure + if (reasoningDeltaEvents.length > 0) { + const deltaEvent = reasoningDeltaEvents[0]; + expect(deltaEvent.data).toHaveProperty('item_id'); + expect(deltaEvent.data).toHaveProperty('delta'); + expect(deltaEvent.data).toHaveProperty('output_index'); + expect(deltaEvent.data).toHaveProperty('content_index'); + expect(deltaEvent.data).toHaveProperty('sequence_number'); + } + + if (reasoningDoneEvents.length > 0) { + const doneEvent = reasoningDoneEvents[0]; + expect(doneEvent.data).toHaveProperty('item_id'); + expect(doneEvent.data).toHaveProperty('text'); + expect(doneEvent.data).toHaveProperty('output_index'); + expect(doneEvent.data).toHaveProperty('content_index'); + expect(doneEvent.data).toHaveProperty('sequence_number'); + } + + // Verify stream completed properly + const eventTypes = events.map((e) => e.event); + expect(eventTypes).toContain('response.completed'); + }); + }); + + /* 
=========================================================================== + * SCHEMA VALIDATION TESTS + * Verify response schema compliance + * =========================================================================== */ + + describe('Schema Validation', () => { + it('should include all required fields in response', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Test', + }); + + expect(response.status).toBe(200); + const body = response.body; + + // Required fields per Open Responses spec + expect(body).toHaveProperty('id'); + expect(body).toHaveProperty('object', 'response'); + expect(body).toHaveProperty('created_at'); + expect(body).toHaveProperty('completed_at'); + expect(body).toHaveProperty('status'); + expect(body).toHaveProperty('model'); + expect(body).toHaveProperty('output'); + expect(body).toHaveProperty('tools'); + expect(body).toHaveProperty('tool_choice'); + expect(body).toHaveProperty('truncation'); + expect(body).toHaveProperty('parallel_tool_calls'); + expect(body).toHaveProperty('text'); + expect(body).toHaveProperty('temperature'); + expect(body).toHaveProperty('top_p'); + expect(body).toHaveProperty('presence_penalty'); + expect(body).toHaveProperty('frequency_penalty'); + expect(body).toHaveProperty('top_logprobs'); + expect(body).toHaveProperty('store'); + expect(body).toHaveProperty('background'); + expect(body).toHaveProperty('service_tier'); + expect(body).toHaveProperty('metadata'); + + // top_logprobs must be a number (not null) + expect(typeof body.top_logprobs).toBe('number'); + + // Usage must have required detail fields + expect(body).toHaveProperty('usage'); + expect(body.usage).toHaveProperty('input_tokens'); + expect(body.usage).toHaveProperty('output_tokens'); + expect(body.usage).toHaveProperty('total_tokens'); + expect(body.usage).toHaveProperty('input_tokens_details'); + expect(body.usage).toHaveProperty('output_tokens_details'); + 
expect(body.usage.input_tokens_details).toHaveProperty('cached_tokens'); + expect(body.usage.output_tokens_details).toHaveProperty('reasoning_tokens'); + }); + + it('should have valid message item structure', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Hello', + }); + + expect(response.status).toBe(200); + + const messageItem = response.body.output.find((item) => item.type === 'message'); + expect(messageItem).toBeDefined(); + + // Message item required fields + expect(messageItem).toHaveProperty('type', 'message'); + expect(messageItem).toHaveProperty('id'); + expect(messageItem).toHaveProperty('status'); + expect(messageItem).toHaveProperty('role', 'assistant'); + expect(messageItem).toHaveProperty('content'); + expect(Array.isArray(messageItem.content)).toBe(true); + + // Content part structure - verify all required fields + if (messageItem.content.length > 0) { + const textContent = messageItem.content.find((c) => c.type === 'output_text'); + if (textContent) { + expect(textContent).toHaveProperty('type', 'output_text'); + expect(textContent).toHaveProperty('text'); + expect(textContent).toHaveProperty('annotations'); + expect(textContent).toHaveProperty('logprobs'); + expect(Array.isArray(textContent.annotations)).toBe(true); + expect(Array.isArray(textContent.logprobs)).toBe(true); + } + } + + // Verify reasoning item has required summary field + const reasoningItem = response.body.output.find((item) => item.type === 'reasoning'); + if (reasoningItem) { + expect(reasoningItem).toHaveProperty('type', 'reasoning'); + expect(reasoningItem).toHaveProperty('id'); + expect(reasoningItem).toHaveProperty('summary'); + expect(Array.isArray(reasoningItem.summary)).toBe(true); + } + }); + }); + + /* =========================================================================== + * RESPONSE STORAGE TESTS + * Tests for store: true and GET /v1/responses/:id + * 
=========================================================================== */ + + describe('Response Storage', () => { + it('should store response when store: true and retrieve it', async () => { + // Create a stored response + const createResponse = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + input: 'Remember this: The answer is 42.', + store: true, + }); + + expect(createResponse.status).toBe(200); + expect(createResponse.body.status).toBe('completed'); + + const responseId = createResponse.body.id; + expect(responseId).toMatch(/^resp_/); + + // Small delay to ensure database write completes + await new Promise((resolve) => setTimeout(resolve, 500)); + + // Retrieve the stored response + const getResponseResult = await authRequest().get(`/api/agents/v1/responses/${responseId}`); + + // Note: The response might be stored under conversationId, not responseId + // If we get 404, that's expected behavior for now since we store by conversationId + if (getResponseResult.status === 200) { + expect(getResponseResult.body.object).toBe('response'); + expect(getResponseResult.body.status).toBe('completed'); + expect(getResponseResult.body.output.length).toBeGreaterThan(0); + } + }); + + it('should return 404 for non-existent response', async () => { + const response = await authRequest().get('/api/agents/v1/responses/resp_nonexistent123'); + + expect(response.status).toBe(404); + expect(response.body.error).toBeDefined(); + }); + }); + + /* =========================================================================== + * ERROR HANDLING TESTS + * =========================================================================== */ + + describe('Error Handling', () => { + it('should return error for missing model', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + input: 'Hello', + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBeDefined(); + }); + + it('should return 
error for missing input', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: testAgent.id, + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBeDefined(); + }); + + it('should return error for non-existent agent', async () => { + const response = await authRequest().post('/api/agents/v1/responses').send({ + model: 'agent_nonexistent123456789', + input: 'Hello', + }); + + expect(response.status).toBe(404); + expect(response.body.error).toBeDefined(); + }); + }); + + /* =========================================================================== + * MODELS ENDPOINT TESTS + * =========================================================================== */ + + describe('GET /v1/responses/models', () => { + it('should list available agents as models', async () => { + const response = await authRequest().get('/api/agents/v1/responses/models'); + + expect(response.status).toBe(200); + expect(response.body.object).toBe('list'); + expect(Array.isArray(response.body.data)).toBe(true); + + // Should include our test agent + const foundAgent = response.body.data.find((m) => m.id === testAgent.id); + expect(foundAgent).toBeDefined(); + expect(foundAgent.object).toBe('model'); + expect(foundAgent.name).toBe(testAgent.name); + }); + }); +}); diff --git a/api/server/routes/agents/index.js b/api/server/routes/agents/index.js index 6933a11534..f8d39cb4d8 100644 --- a/api/server/routes/agents/index.js +++ b/api/server/routes/agents/index.js @@ -9,6 +9,9 @@ const { configMiddleware, messageUserLimiter, } = require('~/server/middleware'); +const { saveMessage } = require('~/models'); +const openai = require('./openai'); +const responses = require('./responses'); const { v1 } = require('./v1'); const chat = require('./chat'); @@ -16,6 +19,20 @@ const { LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? 
{}; const router = express.Router(); +/** + * Open Responses API routes (API key authentication handled in route file) + * Mounted at /agents/v1/responses (full path: /api/agents/v1/responses) + * NOTE: Must be mounted BEFORE /v1 to avoid being caught by the less specific route + * @see https://openresponses.org/specification + */ +router.use('/v1/responses', responses); + +/** + * OpenAI-compatible API routes (API key authentication handled in route file) + * Mounted at /agents/v1 (full path: /api/agents/v1/chat/completions) + */ +router.use('/v1', openai); + router.use(requireJwtAuth); router.use(checkBan); router.use(uaParser); @@ -46,6 +63,10 @@ router.get('/chat/stream/:streamId', async (req, res) => { }); } + if (job.metadata?.userId && job.metadata.userId !== req.user.id) { + return res.status(403).json({ error: 'Unauthorized' }); + } + res.setHeader('Content-Encoding', 'identity'); res.setHeader('Content-Type', 'text/event-stream'); res.setHeader('Cache-Control', 'no-cache, no-transform'); @@ -194,9 +215,53 @@ router.post('/chat/abort', async (req, res) => { logger.debug(`[AgentStream] Computed jobStreamId: ${jobStreamId}`); if (job && jobStreamId) { + if (job.metadata?.userId && job.metadata.userId !== userId) { + logger.warn(`[AgentStream] Unauthorized abort attempt for ${jobStreamId} by user ${userId}`); + return res.status(403).json({ error: 'Unauthorized' }); + } + logger.debug(`[AgentStream] Job found, aborting: ${jobStreamId}`); - await GenerationJobManager.abortJob(jobStreamId); - logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`); + const abortResult = await GenerationJobManager.abortJob(jobStreamId); + logger.debug(`[AgentStream] Job aborted successfully: ${jobStreamId}`, { + abortResultSuccess: abortResult.success, + abortResultUserMessageId: abortResult.jobData?.userMessage?.messageId, + abortResultResponseMessageId: abortResult.jobData?.responseMessageId, + }); + + // CRITICAL: Save partial response BEFORE returning to 
prevent race condition. + // If user sends a follow-up immediately after abort, the parentMessageId must exist in DB. + // Only save if we have a valid responseMessageId (skip early aborts before generation started) + if ( + abortResult.success && + abortResult.jobData?.userMessage?.messageId && + abortResult.jobData?.responseMessageId + ) { + const { jobData, content, text } = abortResult; + const responseMessage = { + messageId: jobData.responseMessageId, + parentMessageId: jobData.userMessage.messageId, + conversationId: jobData.conversationId, + content: content || [], + text: text || '', + sender: jobData.sender || 'AI', + endpoint: jobData.endpoint, + model: jobData.model, + unfinished: true, + error: false, + isCreatedByUser: false, + user: userId, + }; + + try { + await saveMessage(req, responseMessage, { + context: 'api/server/routes/agents/index.js - abort endpoint', + }); + logger.debug(`[AgentStream] Saved partial response for: ${jobStreamId}`); + } catch (saveError) { + logger.error(`[AgentStream] Failed to save partial response: ${saveError.message}`); + } + } + return res.json({ success: true, aborted: jobStreamId }); } diff --git a/api/server/routes/agents/openai.js b/api/server/routes/agents/openai.js new file mode 100644 index 0000000000..9a0d9a3564 --- /dev/null +++ b/api/server/routes/agents/openai.js @@ -0,0 +1,110 @@ +/** + * OpenAI-compatible API routes for LibreChat agents. + * + * Provides a /v1/chat/completions compatible interface for + * interacting with LibreChat agents remotely via API. 
+ * + * Usage: + * POST /v1/chat/completions - Chat with an agent + * GET /v1/models - List available agents + * GET /v1/models/:model - Get agent details + * + * Request format: + * { + * "model": "agent_id_here", + * "messages": [{"role": "user", "content": "Hello!"}], + * "stream": true + * } + */ +const express = require('express'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + generateCheckAccess, + createRequireApiKeyAuth, + createCheckRemoteAgentAccess, +} = require('@librechat/api'); +const { + OpenAIChatCompletionController, + ListModelsController, + GetModelController, +} = require('~/server/controllers/agents/openai'); +const { getEffectivePermissions } = require('~/server/services/PermissionService'); +const { validateAgentApiKey, findUser } = require('~/models'); +const { configMiddleware } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); +const { getAgent } = require('~/models/Agent'); + +const router = express.Router(); + +const requireApiKeyAuth = createRequireApiKeyAuth({ + validateAgentApiKey, + findUser, +}); + +const checkRemoteAgentsFeature = generateCheckAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); + +const checkAgentPermission = createCheckRemoteAgentAccess({ + getAgent, + getEffectivePermissions, +}); + +router.use(requireApiKeyAuth); +router.use(configMiddleware); +router.use(checkRemoteAgentsFeature); + +/** + * @route POST /v1/chat/completions + * @desc OpenAI-compatible chat completions with agents + * @access Private (API key auth required) + * + * Request body: + * { + * "model": "agent_id", // Required: The agent ID to use + * "messages": [...], // Required: Array of chat messages + * "stream": true, // Optional: Whether to stream (default: false) + * "conversation_id": "...", // Optional: Conversation ID for context + * "parent_message_id": "..." 
// Optional: Parent message for threading + * } + * + * Response (streaming): + * - SSE stream with OpenAI chat.completion.chunk format + * - Includes delta.reasoning for thinking/reasoning content + * + * Response (non-streaming): + * - Standard OpenAI chat.completion format + */ +router.post('/chat/completions', checkAgentPermission, OpenAIChatCompletionController); + +/** + * @route GET /v1/models + * @desc List available agents as models + * @access Private (API key auth required) + * + * Response: + * { + * "object": "list", + * "data": [ + * { + * "id": "agent_id", + * "object": "model", + * "name": "Agent Name", + * "provider": "openai", + * ... + * } + * ] + * } + */ +router.get('/models', ListModelsController); + +/** + * @route GET /v1/models/:model + * @desc Get details for a specific agent/model + * @access Private (API key auth required) + */ +router.get('/models/:model', GetModelController); + +module.exports = router; diff --git a/api/server/routes/agents/responses.js b/api/server/routes/agents/responses.js new file mode 100644 index 0000000000..431942e921 --- /dev/null +++ b/api/server/routes/agents/responses.js @@ -0,0 +1,144 @@ +/** + * Open Responses API routes for LibreChat agents. + * + * Implements the Open Responses specification for a forward-looking, + * agentic API that uses items as the fundamental unit and semantic + * streaming events. + * + * Usage: + * POST /v1/responses - Create a response + * GET /v1/models - List available agents + * + * Request format: + * { + * "model": "agent_id_here", + * "input": "Hello!" or [{ type: "message", role: "user", content: "Hello!" 
}], + * "stream": true, + * "previous_response_id": "optional_conversation_id" + * } + * + * @see https://openresponses.org/specification + */ +const express = require('express'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + generateCheckAccess, + createRequireApiKeyAuth, + createCheckRemoteAgentAccess, +} = require('@librechat/api'); +const { + createResponse, + getResponse, + listModels, +} = require('~/server/controllers/agents/responses'); +const { getEffectivePermissions } = require('~/server/services/PermissionService'); +const { validateAgentApiKey, findUser } = require('~/models'); +const { configMiddleware } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); +const { getAgent } = require('~/models/Agent'); + +const router = express.Router(); + +const requireApiKeyAuth = createRequireApiKeyAuth({ + validateAgentApiKey, + findUser, +}); + +const checkRemoteAgentsFeature = generateCheckAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); + +const checkAgentPermission = createCheckRemoteAgentAccess({ + getAgent, + getEffectivePermissions, +}); + +router.use(requireApiKeyAuth); +router.use(configMiddleware); +router.use(checkRemoteAgentsFeature); + +/** + * @route POST /v1/responses + * @desc Create a model response following Open Responses specification + * @access Private (API key auth required) + * + * Request body: + * { + * "model": "agent_id", // Required: The agent ID to use + * "input": "..." 
| [...], // Required: String or array of input items + * "stream": true, // Optional: Whether to stream (default: false) + * "previous_response_id": "...", // Optional: Previous response for continuation + * "instructions": "...", // Optional: Additional instructions + * "tools": [...], // Optional: Additional tools + * "tool_choice": "auto", // Optional: Tool choice mode + * "max_output_tokens": 4096, // Optional: Max tokens + * "temperature": 0.7 // Optional: Temperature + * } + * + * Response (streaming): + * - SSE stream with semantic events: + * - response.in_progress + * - response.output_item.added + * - response.content_part.added + * - response.output_text.delta + * - response.output_text.done + * - response.function_call_arguments.delta + * - response.output_item.done + * - response.completed + * - [DONE] + * + * Response (non-streaming): + * { + * "id": "resp_xxx", + * "object": "response", + * "created_at": 1234567890, + * "status": "completed", + * "model": "agent_id", + * "output": [...], // Array of output items + * "usage": { ... } + * } + */ +router.post('/', checkAgentPermission, createResponse); + +/** + * @route GET /v1/responses/models + * @desc List available agents as models + * @access Private (API key auth required) + * + * Response: + * { + * "object": "list", + * "data": [ + * { + * "id": "agent_id", + * "object": "model", + * "name": "Agent Name", + * "provider": "openai", + * ... + * } + * ] + * } + */ +router.get('/models', listModels); + +/** + * @route GET /v1/responses/:id + * @desc Retrieve a stored response by ID + * @access Private (API key auth required) + * + * Response: + * { + * "id": "resp_xxx", + * "object": "response", + * "created_at": 1234567890, + * "status": "completed", + * "model": "agent_id", + * "output": [...], + * "usage": { ... 
} + * } + */ +router.get('/:id', getResponse); + +module.exports = router; diff --git a/api/server/routes/apiKeys.js b/api/server/routes/apiKeys.js new file mode 100644 index 0000000000..29dcc326f5 --- /dev/null +++ b/api/server/routes/apiKeys.js @@ -0,0 +1,36 @@ +const express = require('express'); +const { generateCheckAccess, createApiKeyHandlers } = require('@librechat/api'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { + getAgentApiKeyById, + createAgentApiKey, + deleteAgentApiKey, + listAgentApiKeys, +} = require('~/models'); +const { requireJwtAuth } = require('~/server/middleware'); +const { getRoleByName } = require('~/models/Role'); + +const router = express.Router(); + +const handlers = createApiKeyHandlers({ + createAgentApiKey, + listAgentApiKeys, + deleteAgentApiKey, + getAgentApiKeyById, +}); + +const checkRemoteAgentsUse = generateCheckAccess({ + permissionType: PermissionTypes.REMOTE_AGENTS, + permissions: [Permissions.USE], + getRoleByName, +}); + +router.post('/', requireJwtAuth, checkRemoteAgentsUse, handlers.createApiKey); + +router.get('/', requireJwtAuth, checkRemoteAgentsUse, handlers.listApiKeys); + +router.get('/:id', requireJwtAuth, checkRemoteAgentsUse, handlers.getApiKey); + +router.delete('/:id', requireJwtAuth, checkRemoteAgentsUse, handlers.deleteApiKey); + +module.exports = router; diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index 75b3656f59..bb9c4ebea9 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -98,7 +98,7 @@ router.get('/gen_title/:conversationId', async (req, res) => { router.delete('/', async (req, res) => { let filter = {}; - const { conversationId, source, thread_id, endpoint } = req.body.arg; + const { conversationId, source, thread_id, endpoint } = req.body?.arg ?? 
{}; // Prevent deletion of all conversations if (!conversationId && !source && !thread_id && !endpoint) { @@ -160,7 +160,7 @@ router.delete('/all', async (req, res) => { * @returns {object} 200 - The updated conversation object. */ router.post('/archive', validateConvoAccess, async (req, res) => { - const { conversationId, isArchived } = req.body.arg ?? {}; + const { conversationId, isArchived } = req.body?.arg ?? {}; if (!conversationId) { return res.status(400).json({ error: 'conversationId is required' }); @@ -194,7 +194,7 @@ const MAX_CONVO_TITLE_LENGTH = 1024; * @returns {object} 201 - The updated conversation object. */ router.post('/update', validateConvoAccess, async (req, res) => { - const { conversationId, title } = req.body.arg ?? {}; + const { conversationId, title } = req.body?.arg ?? {}; if (!conversationId) { return res.status(400).json({ error: 'conversationId is required' }); diff --git a/api/server/routes/files/images.js b/api/server/routes/files/images.js index b8be413f4f..8072612a69 100644 --- a/api/server/routes/files/images.js +++ b/api/server/routes/files/images.js @@ -2,11 +2,11 @@ const path = require('path'); const fs = require('fs').promises; const express = require('express'); const { logger } = require('@librechat/data-schemas'); -const { isAgentsEndpoint } = require('librechat-data-provider'); +const { isAssistantsEndpoint } = require('librechat-data-provider'); const { - filterFile, - processImageFile, processAgentFileUpload, + processImageFile, + filterFile, } = require('~/server/services/Files/process'); const router = express.Router(); @@ -21,7 +21,7 @@ router.post('/', async (req, res) => { metadata.temp_file_id = metadata.file_id; metadata.file_id = req.file_id; - if (isAgentsEndpoint(metadata.endpoint) && metadata.tool_resource != null) { + if (!isAssistantsEndpoint(metadata.endpoint) && metadata.tool_resource != null) { return await processAgentFileUpload({ req, res, metadata }); } diff --git a/api/server/routes/index.js 
b/api/server/routes/index.js index f3571099cb..6a48919db3 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -1,6 +1,7 @@ const accessPermissions = require('./accessPermissions'); const assistants = require('./assistants'); const categories = require('./categories'); +const adminAuth = require('./admin/auth'); const endpoints = require('./endpoints'); const staticRoute = require('./static'); const messages = require('./messages'); @@ -9,6 +10,7 @@ const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); const actions = require('./actions'); +const apiKeys = require('./apiKeys'); const banner = require('./banner'); const search = require('./search'); const models = require('./models'); @@ -28,7 +30,9 @@ const mcp = require('./mcp'); module.exports = { mcp, auth, + adminAuth, keys, + apiKeys, user, tags, roles, diff --git a/api/server/routes/keys.js b/api/server/routes/keys.js index 620e4d234b..dfd68f69c4 100644 --- a/api/server/routes/keys.js +++ b/api/server/routes/keys.js @@ -5,7 +5,11 @@ const { requireJwtAuth } = require('~/server/middleware'); const router = express.Router(); router.put('/', requireJwtAuth, async (req, res) => { - await updateUserKey({ userId: req.user.id, ...req.body }); + if (req.body == null || typeof req.body !== 'object') { + return res.status(400).send({ error: 'Invalid request body.' 
}); + } + const { name, value, expiresAt } = req.body; + await updateUserKey({ userId: req.user.id, name, value, expiresAt }); res.status(201).send(); }); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index f01c7ff71c..2db8c2c462 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -8,18 +8,32 @@ const { Permissions, } = require('librechat-data-provider'); const { + getBasePath, createSafeUser, MCPOAuthHandler, MCPTokenStorage, - getBasePath, + setOAuthSession, getUserMCPAuthMap, + validateOAuthCsrf, + OAUTH_CSRF_COOKIE, + setOAuthCsrfCookie, generateCheckAccess, + validateOAuthSession, + OAUTH_SESSION_COOKIE, } = require('@librechat/api'); const { - getMCPManager, - getFlowStateManager, + createMCPServerController, + updateMCPServerController, + deleteMCPServerController, + getMCPServersList, + getMCPServerById, + getMCPTools, +} = require('~/server/controllers/mcp'); +const { getOAuthReconnectionManager, getMCPServersRegistry, + getFlowStateManager, + getMCPManager, } = require('~/config'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); const { requireJwtAuth, canAccessMCPServerResource } = require('~/server/middleware'); @@ -27,20 +41,14 @@ const { findToken, updateToken, createToken, deleteTokens } = require('~/models' const { getUserPluginAuthValue } = require('~/server/services/PluginService'); const { updateMCPServerTools } = require('~/server/services/Config/mcp'); const { reinitMCPServer } = require('~/server/services/Tools/mcp'); -const { getMCPTools } = require('~/server/controllers/mcp'); const { findPluginAuthsByKeys } = require('~/models'); const { getRoleByName } = require('~/models/Role'); const { getLogStores } = require('~/cache'); -const { - createMCPServerController, - getMCPServerById, - getMCPServersList, - updateMCPServerController, - deleteMCPServerController, -} = require('~/server/controllers/mcp'); const router = Router(); +const OAUTH_CSRF_COOKIE_PATH = 
'/api/mcp'; + /** * Get all MCP tools available to the user * Returns only MCP tools, completely decoupled from regular LibreChat tools @@ -53,7 +61,7 @@ router.get('/tools', requireJwtAuth, async (req, res) => { * Initiate OAuth flow * This endpoint is called when the user clicks the auth link in the UI */ -router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { +router.get('/:serverName/oauth/initiate', requireJwtAuth, setOAuthSession, async (req, res) => { try { const { serverName } = req.params; const { userId, flowId } = req.query; @@ -93,7 +101,7 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { logger.debug('[MCP OAuth] OAuth flow initiated', { oauthFlowId, authorizationUrl }); - // Redirect user to the authorization URL + setOAuthCsrfCookie(res, oauthFlowId, OAUTH_CSRF_COOKIE_PATH); res.redirect(authorizationUrl); } catch (error) { logger.error('[MCP OAuth] Failed to initiate OAuth', error); @@ -138,6 +146,25 @@ router.get('/:serverName/oauth/callback', async (req, res) => { const flowId = state; logger.debug('[MCP OAuth] Using flow ID from state', { flowId }); + const flowParts = flowId.split(':'); + if (flowParts.length < 2 || !flowParts[0] || !flowParts[1]) { + logger.error('[MCP OAuth] Invalid flow ID format in state', { flowId }); + return res.redirect(`${basePath}/oauth/error?error=invalid_state`); + } + + const [flowUserId] = flowParts; + if ( + !validateOAuthCsrf(req, res, flowId, OAUTH_CSRF_COOKIE_PATH) && + !validateOAuthSession(req, flowUserId) + ) { + logger.error('[MCP OAuth] CSRF validation failed: no valid CSRF or session cookie', { + flowId, + hasCsrfCookie: !!req.cookies?.[OAUTH_CSRF_COOKIE], + hasSessionCookie: !!req.cookies?.[OAUTH_SESSION_COOKIE], + }); + return res.redirect(`${basePath}/oauth/error?error=csrf_validation_failed`); + } + const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); @@ -302,13 +329,47 @@ 
router.get('/oauth/tokens/:flowId', requireJwtAuth, async (req, res) => { } }); +/** + * Set CSRF binding cookie for OAuth flows initiated outside of HTTP request/response + * (e.g. during chat via SSE). The frontend should call this before opening the OAuth URL + * so the callback can verify the browser matches the flow initiator. + */ +router.post('/:serverName/oauth/bind', requireJwtAuth, setOAuthSession, async (req, res) => { + try { + const { serverName } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + + res.json({ success: true }); + } catch (error) { + logger.error('[MCP OAuth] Failed to set CSRF binding cookie', error); + res.status(500).json({ error: 'Failed to bind OAuth flow' }); + } +}); + /** * Check OAuth flow status * This endpoint can be used to poll the status of an OAuth flow */ -router.get('/oauth/status/:flowId', async (req, res) => { +router.get('/oauth/status/:flowId', requireJwtAuth, async (req, res) => { try { const { flowId } = req.params; + const user = req.user; + + if (!user?.id) { + return res.status(401).json({ error: 'User not authenticated' }); + } + + if (!flowId.startsWith(`${user.id}:`) && !flowId.startsWith('system:')) { + return res.status(403).json({ error: 'Access denied' }); + } + const flowsCache = getLogStores(CacheKeys.FLOWS); const flowManager = getFlowStateManager(flowsCache); @@ -375,7 +436,7 @@ router.post('/oauth/cancel/:serverName', requireJwtAuth, async (req, res) => { * Reinitialize MCP server * This endpoint allows reinitializing a specific MCP server */ -router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { +router.post('/:serverName/reinitialize', requireJwtAuth, setOAuthSession, async (req, res) => { try { const { serverName } = req.params; const user = 
createSafeUser(req.user); @@ -421,6 +482,11 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { const { success, message, oauthRequired, oauthUrl } = result; + if (oauthRequired) { + const flowId = MCPOAuthHandler.generateFlowId(user.id, serverName); + setOAuthCsrfCookie(res, flowId, OAUTH_CSRF_COOKIE_PATH); + } + res.json({ success, message, diff --git a/api/server/routes/oauth.js b/api/server/routes/oauth.js index 64d29210ac..f4bb5b6026 100644 --- a/api/server/routes/oauth.js +++ b/api/server/routes/oauth.js @@ -4,10 +4,9 @@ const passport = require('passport'); const { randomState } = require('openid-client'); const { logger } = require('@librechat/data-schemas'); const { ErrorTypes } = require('librechat-data-provider'); -const { isEnabled, createSetBalanceConfig } = require('@librechat/api'); -const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware'); -const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService'); -const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService'); +const { createSetBalanceConfig } = require('@librechat/api'); +const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware'); +const { createOAuthHandler } = require('~/server/controllers/auth/oauth'); const { getAppConfig } = require('~/server/services/Config'); const { Balance } = require('~/db/models'); @@ -26,36 +25,11 @@ const domains = { router.use(logHeaders); router.use(loginLimiter); -const oauthHandler = async (req, res, next) => { - try { - if (res.headersSent) { - return; - } - - await checkBan(req, res); - if (req.banned) { - return; - } - if ( - req.user && - req.user.provider == 'openid' && - isEnabled(process.env.OPENID_REUSE_TOKENS) === true - ) { - await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token); - setOpenIDAuthTokens(req.user.tokenset, req, res, req.user._id.toString()); - } else { - await 
setAuthTokens(req.user._id, res); - } - res.redirect(domains.client); - } catch (err) { - logger.error('Error in setting authentication tokens:', err); - next(err); - } -}; +const oauthHandler = createOAuthHandler(); router.get('/error', (req, res) => { /** A single error message is pushed by passport when authentication fails. */ - const errorMessage = req.session?.messages?.pop() || 'Unknown error'; + const errorMessage = req.session?.messages?.pop() || 'Unknown OAuth error'; logger.error('Error in OAuth authentication:', { message: errorMessage, }); diff --git a/api/server/routes/roles.js b/api/server/routes/roles.js index abb53141bd..12e18c7624 100644 --- a/api/server/routes/roles.js +++ b/api/server/routes/roles.js @@ -6,9 +6,10 @@ const { agentPermissionsSchema, promptPermissionsSchema, memoryPermissionsSchema, + mcpServersPermissionsSchema, marketplacePermissionsSchema, peoplePickerPermissionsSchema, - mcpServersPermissionsSchema, + remoteAgentsPermissionsSchema, } = require('librechat-data-provider'); const { checkAdmin, requireJwtAuth } = require('~/server/middleware'); const { updateRoleByName, getRoleByName } = require('~/models/Role'); @@ -51,6 +52,11 @@ const permissionConfigs = { permissionType: PermissionTypes.MARKETPLACE, errorMessage: 'Invalid marketplace permissions.', }, + 'remote-agents': { + schema: remoteAgentsPermissionsSchema, + permissionType: PermissionTypes.REMOTE_AGENTS, + errorMessage: 'Invalid remote agents permissions.', + }, }; /** @@ -160,4 +166,10 @@ router.put('/:roleName/mcp-servers', checkAdmin, createPermissionUpdateHandler(' */ router.put('/:roleName/marketplace', checkAdmin, createPermissionUpdateHandler('marketplace')); +/** + * PUT /api/roles/:roleName/remote-agents + * Update remote agents (API) permissions for a specific role + */ +router.put('/:roleName/remote-agents', checkAdmin, createPermissionUpdateHandler('remote-agents')); + module.exports = router; diff --git a/api/server/services/ActionService.js 
b/api/server/services/ActionService.js index a2a515d14a..5e96726a46 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -8,6 +8,7 @@ const { logAxiosError, refreshAccessToken, GenerationJobManager, + createSSRFSafeAgents, } = require('@librechat/api'); const { Time, @@ -133,6 +134,7 @@ async function loadActionSets(searchParams) { * @param {import('zod').ZodTypeAny | undefined} [params.zodSchema] - The Zod schema for tool input validation/definition * @param {{ oauth_client_id?: string; oauth_client_secret?: string; }} params.encrypted - The encrypted values for the action. * @param {string | null} [params.streamId] - The stream ID for resumable streams. + * @param {boolean} [params.useSSRFProtection] - When true, uses SSRF-safe HTTP agents that validate resolved IPs at connect time. * @returns { Promise unknown}> } An object with `_call` method to execute the tool input. */ async function createActionTool({ @@ -145,7 +147,9 @@ async function createActionTool({ description, encrypted, streamId = null, + useSSRFProtection = false, }) { + const ssrfAgents = useSSRFProtection ? 
createSSRFSafeAgents() : undefined; /** @type {(toolInput: Object | string, config: GraphRunnableConfig) => Promise} */ const _call = async (toolInput, config) => { try { @@ -201,7 +205,7 @@ async function createActionTool({ async () => { const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data }; if (streamId) { - GenerationJobManager.emitChunk(streamId, eventData); + await GenerationJobManager.emitChunk(streamId, eventData); } else { sendEvent(res, eventData); } @@ -231,7 +235,7 @@ async function createActionTool({ data.delta.expires_at = undefined; const successEventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data }; if (streamId) { - GenerationJobManager.emitChunk(streamId, successEventData); + await GenerationJobManager.emitChunk(streamId, successEventData); } else { sendEvent(res, successEventData); } @@ -324,7 +328,7 @@ async function createActionTool({ } } - const response = await preparedExecutor.execute(); + const response = await preparedExecutor.execute(ssrfAgents); if (typeof response.data === 'object') { return JSON.stringify(response.data); diff --git a/api/server/services/Artifacts/update.js b/api/server/services/Artifacts/update.js index d068593f8c..be1644b11c 100644 --- a/api/server/services/Artifacts/update.js +++ b/api/server/services/Artifacts/update.js @@ -73,15 +73,25 @@ const replaceArtifactContent = (originalText, artifact, original, updated) => { return null; } - // Check if there are code blocks - const codeBlockStart = artifactContent.indexOf('```\n', contentStart); + // Check if there are code blocks - handle both ```\n and ```lang\n formats + let codeBlockStart = artifactContent.indexOf('```', contentStart); const codeBlockEnd = artifactContent.lastIndexOf('\n```', contentEnd); + // If we found opening backticks, find the actual newline (skipping any language identifier) + if (codeBlockStart !== -1) { + const newlineAfterBackticks = artifactContent.indexOf('\n', codeBlockStart); + if (newlineAfterBackticks !== -1 && 
newlineAfterBackticks < contentEnd) { + codeBlockStart = newlineAfterBackticks; + } else { + codeBlockStart = -1; + } + } + // Determine where to look for the original content let searchStart, searchEnd; if (codeBlockStart !== -1) { - // Code block starts - searchStart = codeBlockStart + 4; // after ```\n + // Code block starts - searchStart is right after the newline following ```[lang] + searchStart = codeBlockStart + 1; // after the newline if (codeBlockEnd !== -1 && codeBlockEnd > codeBlockStart) { // Code block has proper ending diff --git a/api/server/services/Artifacts/update.spec.js b/api/server/services/Artifacts/update.spec.js index 2a3e0bbe39..39a4f02863 100644 --- a/api/server/services/Artifacts/update.spec.js +++ b/api/server/services/Artifacts/update.spec.js @@ -494,5 +494,268 @@ ${original}`; /```\n {2}function test\(\) \{\n {4}return \{\n {6}value: 100\n {4}\};\n {2}\}\n```/, ); }); + + test('should handle code blocks with language identifiers (```svg, ```html, etc.)', () => { + const svgContent = ` + + +`; + + /** Artifact with language identifier in code block */ + const artifactText = `${ARTIFACT_START}{identifier="test-svg" type="image/svg+xml" title="Test SVG"} +\`\`\`svg +${svgContent} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + expect(artifacts).toHaveLength(1); + + const updatedSvg = svgContent.replace('#FFFFFF', '#131313'); + const result = replaceArtifactContent(artifactText, artifacts[0], svgContent, updatedSvg); + + expect(result).not.toBeNull(); + expect(result).toContain('#131313'); + expect(result).not.toContain('#FFFFFF'); + expect(result).toMatch(/```svg\n/); + }); + + test('should handle code blocks with complex language identifiers', () => { + const htmlContent = ` + +Test +Hello +`; + + const artifactText = `${ARTIFACT_START}{identifier="test-html" type="text/html" title="Test HTML"} +\`\`\`html +${htmlContent} +\`\`\` +${ARTIFACT_END}`; + + const message 
= { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const updatedHtml = htmlContent.replace('Hello', 'Updated'); + const result = replaceArtifactContent(artifactText, artifacts[0], htmlContent, updatedHtml); + + expect(result).not.toBeNull(); + expect(result).toContain('Updated'); + expect(result).toMatch(/```html\n/); + }); + }); + + describe('code block edge cases', () => { + test('should handle code block without language identifier (```\\n)', () => { + const content = 'const x = 1;\nconst y = 2;'; + const artifactText = `${ARTIFACT_START}{identifier="test" type="text/plain" title="Test"} +\`\`\` +${content} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const result = replaceArtifactContent(artifactText, artifacts[0], content, 'updated'); + + expect(result).not.toBeNull(); + expect(result).toContain('updated'); + expect(result).toMatch(/```\nupdated\n```/); + }); + + test('should handle various language identifiers', () => { + const languages = [ + 'javascript', + 'typescript', + 'python', + 'jsx', + 'tsx', + 'css', + 'json', + 'xml', + 'markdown', + 'md', + ]; + + for (const lang of languages) { + const content = `test content for ${lang}`; + const artifactText = `${ARTIFACT_START}{identifier="test-${lang}" type="text/plain" title="Test"} +\`\`\`${lang} +${content} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + expect(artifacts).toHaveLength(1); + + const result = replaceArtifactContent(artifactText, artifacts[0], content, 'updated'); + + expect(result).not.toBeNull(); + expect(result).toContain('updated'); + expect(result).toMatch(new RegExp(`\`\`\`${lang}\\n`)); + } + }); + + test('should handle single character language identifier', () => { + const content = 'single char lang'; + const artifactText = `${ARTIFACT_START}{identifier="test" type="text/plain" title="Test"} +\`\`\`r 
+${content} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const result = replaceArtifactContent(artifactText, artifacts[0], content, 'updated'); + + expect(result).not.toBeNull(); + expect(result).toContain('updated'); + expect(result).toMatch(/```r\n/); + }); + + test('should handle code block with content that looks like code fence', () => { + const content = 'Line 1\nSome text with ``` backticks in middle\nLine 3'; + const artifactText = `${ARTIFACT_START}{identifier="test" type="text/plain" title="Test"} +\`\`\`text +${content} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const result = replaceArtifactContent(artifactText, artifacts[0], content, 'updated'); + + expect(result).not.toBeNull(); + expect(result).toContain('updated'); + }); + + test('should handle code block with trailing whitespace in language line', () => { + const content = 'whitespace test'; + /** Note: trailing spaces after 'python' */ + const artifactText = `${ARTIFACT_START}{identifier="test" type="text/plain" title="Test"} +\`\`\`python +${content} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const result = replaceArtifactContent(artifactText, artifacts[0], content, 'updated'); + + expect(result).not.toBeNull(); + expect(result).toContain('updated'); + }); + + test('should handle react/jsx content with complex syntax', () => { + const jsxContent = `function App() { + const [count, setCount] = useState(0); + return ( +
+

Count: {count}

+ +
+ ); +}`; + + const artifactText = `${ARTIFACT_START}{identifier="react-app" type="application/vnd.react" title="React App"} +\`\`\`jsx +${jsxContent} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const updatedJsx = jsxContent.replace('Increment', 'Click me'); + const result = replaceArtifactContent(artifactText, artifacts[0], jsxContent, updatedJsx); + + expect(result).not.toBeNull(); + expect(result).toContain('Click me'); + expect(result).not.toContain('Increment'); + expect(result).toMatch(/```jsx\n/); + }); + + test('should handle mermaid diagram content', () => { + const mermaidContent = `graph TD + A[Start] --> B{Is it?} + B -->|Yes| C[OK] + B -->|No| D[End]`; + + const artifactText = `${ARTIFACT_START}{identifier="diagram" type="application/vnd.mermaid" title="Flow"} +\`\`\`mermaid +${mermaidContent} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const updatedMermaid = mermaidContent.replace('Start', 'Begin'); + const result = replaceArtifactContent( + artifactText, + artifacts[0], + mermaidContent, + updatedMermaid, + ); + + expect(result).not.toBeNull(); + expect(result).toContain('Begin'); + expect(result).toMatch(/```mermaid\n/); + }); + + test('should handle artifact without code block (plain text)', () => { + const content = 'Just plain text without code fences'; + const artifactText = `${ARTIFACT_START}{identifier="plain" type="text/plain" title="Plain"} +${content} +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const result = replaceArtifactContent( + artifactText, + artifacts[0], + content, + 'updated plain text', + ); + + expect(result).not.toBeNull(); + expect(result).toContain('updated plain text'); + expect(result).not.toContain('```'); + }); + + test('should handle multiline content with various newline patterns', () => { + const 
content = `Line 1 +Line 2 + +Line 4 after empty line + Indented line + Double indented`; + + const artifactText = `${ARTIFACT_START}{identifier="test" type="text/plain" title="Test"} +\`\`\` +${content} +\`\`\` +${ARTIFACT_END}`; + + const message = { text: artifactText }; + const artifacts = findAllArtifacts(message); + + const updated = content.replace('Line 1', 'First Line'); + const result = replaceArtifactContent(artifactText, artifacts[0], content, updated); + + expect(result).not.toBeNull(); + expect(result).toContain('First Line'); + expect(result).toContain(' Indented line'); + expect(result).toContain(' Double indented'); + }); }); }); diff --git a/api/server/services/AuthService.js b/api/server/services/AuthService.js index a400bce8b7..ef50a365b9 100644 --- a/api/server/services/AuthService.js +++ b/api/server/services/AuthService.js @@ -7,7 +7,13 @@ const { DEFAULT_REFRESH_TOKEN_EXPIRY, } = require('@librechat/data-schemas'); const { ErrorTypes, SystemRoles, errorsToString } = require('librechat-data-provider'); -const { isEnabled, checkEmailConfig, isEmailDomainAllowed, math } = require('@librechat/api'); +const { + math, + isEnabled, + checkEmailConfig, + isEmailDomainAllowed, + shouldUseSecureCookie, +} = require('@librechat/api'); const { findUser, findToken, @@ -33,7 +39,6 @@ const domains = { server: process.env.DOMAIN_SERVER, }; -const isProduction = process.env.NODE_ENV === 'production'; const genericVerificationMessage = 'Please check your email to verify your email address.'; /** @@ -392,13 +397,13 @@ const setAuthTokens = async (userId, res, _session = null) => { res.cookie('refreshToken', refreshToken, { expires: new Date(refreshTokenExpires), httpOnly: true, - secure: isProduction, + secure: shouldUseSecureCookie(), sameSite: 'strict', }); res.cookie('token_provider', 'librechat', { expires: new Date(refreshTokenExpires), httpOnly: true, - secure: isProduction, + secure: shouldUseSecureCookie(), sameSite: 'strict', }); return token; @@ 
-419,7 +424,7 @@ const setAuthTokens = async (userId, res, _session = null) => { * @param {Object} req - request object (for session access) * @param {Object} res - response object * @param {string} [userId] - Optional MongoDB user ID for image path validation - * @returns {String} - access token + * @returns {String} - id_token (preferred) or access_token as the app auth token */ const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) => { try { @@ -448,34 +453,62 @@ const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) = return; } + /** + * Use id_token as the app authentication token (Bearer token for JWKS validation). + * The id_token is always a standard JWT signed by the IdP's JWKS keys with the app's + * client_id as audience. The access_token may be opaque or intended for a different + * audience (e.g., Microsoft Graph API), which fails JWKS validation. + * Falls back to access_token for providers where id_token is not available. + */ + const appAuthToken = tokenset.id_token || tokenset.access_token; + + /** + * Always set refresh token cookie so it survives express session expiry. + * The session cookie maxAge (SESSION_EXPIRY, default 15 min) is typically shorter + * than the OIDC token lifetime (~1 hour). Without this cookie fallback, the refresh + * token stored only in the session is lost when the session expires, causing the user + * to be signed out on the next token refresh attempt. + * The refresh token is small (opaque string) so it doesn't hit the HTTP/2 header + * size limits that motivated session storage for the larger access_token/id_token. 
+ */ + res.cookie('refreshToken', refreshToken, { + expires: expirationDate, + httpOnly: true, + secure: shouldUseSecureCookie(), + sameSite: 'strict', + }); + /** Store tokens server-side in session to avoid large cookies */ if (req.session) { req.session.openidTokens = { accessToken: tokenset.access_token, + idToken: tokenset.id_token, refreshToken: refreshToken, expiresAt: expirationDate.getTime(), }; } else { logger.warn('[setOpenIDAuthTokens] No session available, falling back to cookies'); - res.cookie('refreshToken', refreshToken, { - expires: expirationDate, - httpOnly: true, - secure: isProduction, - sameSite: 'strict', - }); res.cookie('openid_access_token', tokenset.access_token, { expires: expirationDate, httpOnly: true, - secure: isProduction, + secure: shouldUseSecureCookie(), sameSite: 'strict', }); + if (tokenset.id_token) { + res.cookie('openid_id_token', tokenset.id_token, { + expires: expirationDate, + httpOnly: true, + secure: shouldUseSecureCookie(), + sameSite: 'strict', + }); + } } /** Small cookie to indicate token provider (required for auth middleware) */ res.cookie('token_provider', 'openid', { expires: expirationDate, httpOnly: true, - secure: isProduction, + secure: shouldUseSecureCookie(), sameSite: 'strict', }); if (userId && isEnabled(process.env.OPENID_REUSE_TOKENS)) { @@ -486,11 +519,11 @@ const setOpenIDAuthTokens = (tokenset, req, res, userId, existingRefreshToken) = res.cookie('openid_user_id', signedUserId, { expires: expirationDate, httpOnly: true, - secure: isProduction, + secure: shouldUseSecureCookie(), sameSite: 'strict', }); } - return tokenset.access_token; + return appAuthToken; } catch (error) { logger.error('[setOpenIDAuthTokens] Error in setting authentication tokens:', error); throw error; diff --git a/api/server/services/AuthService.spec.js b/api/server/services/AuthService.spec.js new file mode 100644 index 0000000000..da78f8d775 --- /dev/null +++ b/api/server/services/AuthService.spec.js @@ -0,0 +1,269 @@ 
+jest.mock('@librechat/data-schemas', () => ({ + logger: { info: jest.fn(), warn: jest.fn(), debug: jest.fn(), error: jest.fn() }, + DEFAULT_SESSION_EXPIRY: 900000, + DEFAULT_REFRESH_TOKEN_EXPIRY: 604800000, +})); +jest.mock('librechat-data-provider', () => ({ + ErrorTypes: {}, + SystemRoles: { USER: 'USER', ADMIN: 'ADMIN' }, + errorsToString: jest.fn(), +})); +jest.mock('@librechat/api', () => ({ + isEnabled: jest.fn((val) => val === 'true' || val === true), + checkEmailConfig: jest.fn(), + isEmailDomainAllowed: jest.fn(), + math: jest.fn((val, fallback) => (val ? Number(val) : fallback)), + shouldUseSecureCookie: jest.fn(() => false), +})); +jest.mock('~/models', () => ({ + findUser: jest.fn(), + findToken: jest.fn(), + createUser: jest.fn(), + updateUser: jest.fn(), + countUsers: jest.fn(), + getUserById: jest.fn(), + findSession: jest.fn(), + createToken: jest.fn(), + deleteTokens: jest.fn(), + deleteSession: jest.fn(), + createSession: jest.fn(), + generateToken: jest.fn(), + deleteUserById: jest.fn(), + generateRefreshToken: jest.fn(), +})); +jest.mock('~/strategies/validators', () => ({ registerSchema: { parse: jest.fn() } })); +jest.mock('~/server/services/Config', () => ({ getAppConfig: jest.fn() })); +jest.mock('~/server/utils', () => ({ sendEmail: jest.fn() })); + +const { shouldUseSecureCookie } = require('@librechat/api'); +const { setOpenIDAuthTokens } = require('./AuthService'); + +/** Helper to build a mock Express response */ +function mockResponse() { + const cookies = {}; + const res = { + cookie: jest.fn((name, value, options) => { + cookies[name] = { value, options }; + }), + _cookies: cookies, + }; + return res; +} + +/** Helper to build a mock Express request with session */ +function mockRequest(sessionData = {}) { + return { + session: { openidTokens: null, ...sessionData }, + }; +} + +describe('setOpenIDAuthTokens', () => { + const env = process.env; + + beforeEach(() => { + jest.clearAllMocks(); + process.env = { + ...env, + 
JWT_REFRESH_SECRET: 'test-refresh-secret', + OPENID_REUSE_TOKENS: 'true', + }; + }); + + afterAll(() => { + process.env = env; + }); + + describe('token selection (id_token vs access_token)', () => { + it('should return id_token when both id_token and access_token are present', () => { + const tokenset = { + id_token: 'the-id-token', + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + expect(result).toBe('the-id-token'); + }); + + it('should return access_token when id_token is not available', () => { + const tokenset = { + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + expect(result).toBe('the-access-token'); + }); + + it('should return access_token when id_token is undefined', () => { + const tokenset = { + id_token: undefined, + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + expect(result).toBe('the-access-token'); + }); + + it('should return access_token when id_token is null', () => { + const tokenset = { + id_token: null, + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + expect(result).toBe('the-access-token'); + }); + + it('should return id_token even when id_token and access_token differ', () => { + const tokenset = { + id_token: 'id-token-jwt-signed-by-idp', + access_token: 'opaque-graph-api-token', + refresh_token: 'refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + const result = 
setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + expect(result).toBe('id-token-jwt-signed-by-idp'); + expect(result).not.toBe('opaque-graph-api-token'); + }); + }); + + describe('session token storage', () => { + it('should store the original access_token in session (not id_token)', () => { + const tokenset = { + id_token: 'the-id-token', + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + + expect(req.session.openidTokens.accessToken).toBe('the-access-token'); + expect(req.session.openidTokens.refreshToken).toBe('the-refresh-token'); + }); + }); + + describe('cookie secure flag', () => { + it('should call shouldUseSecureCookie for every cookie set', () => { + const tokenset = { + id_token: 'the-id-token', + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + + // token_provider + openid_user_id (session path, so no refreshToken/openid_access_token cookies) + const secureCalls = shouldUseSecureCookie.mock.calls.length; + expect(secureCalls).toBeGreaterThanOrEqual(2); + + // Verify all cookies use the result of shouldUseSecureCookie + for (const [, cookie] of Object.entries(res._cookies)) { + expect(cookie.options.secure).toBe(false); + } + }); + + it('should set secure: true when shouldUseSecureCookie returns true', () => { + shouldUseSecureCookie.mockReturnValue(true); + + const tokenset = { + id_token: 'the-id-token', + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + + for (const [, cookie] of Object.entries(res._cookies)) { + expect(cookie.options.secure).toBe(true); + } + }); + + it('should use shouldUseSecureCookie for cookie 
fallback path (no session)', () => { + shouldUseSecureCookie.mockReturnValue(false); + + const tokenset = { + id_token: 'the-id-token', + access_token: 'the-access-token', + refresh_token: 'the-refresh-token', + }; + const req = { session: null }; + const res = mockResponse(); + + setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + + // In the cookie fallback path, we get: refreshToken, openid_access_token, token_provider, openid_user_id + expect(res.cookie).toHaveBeenCalledWith( + 'refreshToken', + expect.any(String), + expect.objectContaining({ secure: false }), + ); + expect(res.cookie).toHaveBeenCalledWith( + 'openid_access_token', + expect.any(String), + expect.objectContaining({ secure: false }), + ); + expect(res.cookie).toHaveBeenCalledWith( + 'token_provider', + 'openid', + expect.objectContaining({ secure: false }), + ); + }); + }); + + describe('edge cases', () => { + it('should return undefined when tokenset is null', () => { + const req = mockRequest(); + const res = mockResponse(); + const result = setOpenIDAuthTokens(null, req, res, 'user-123'); + expect(result).toBeUndefined(); + }); + + it('should return undefined when access_token is missing', () => { + const tokenset = { refresh_token: 'refresh' }; + const req = mockRequest(); + const res = mockResponse(); + const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + expect(result).toBeUndefined(); + }); + + it('should return undefined when no refresh token is available', () => { + const tokenset = { access_token: 'access', id_token: 'id' }; + const req = mockRequest(); + const res = mockResponse(); + const result = setOpenIDAuthTokens(tokenset, req, res, 'user-123'); + expect(result).toBeUndefined(); + }); + + it('should use existingRefreshToken when tokenset has no refresh_token', () => { + const tokenset = { + id_token: 'the-id-token', + access_token: 'the-access-token', + }; + const req = mockRequest(); + const res = mockResponse(); + + const result = 
setOpenIDAuthTokens(tokenset, req, res, 'user-123', 'existing-refresh'); + expect(result).toBe('the-id-token'); + expect(req.session.openidTokens.refreshToken).toBe('existing-refresh'); + }); + }); +}); diff --git a/api/server/services/Config/__tests__/getCachedTools.spec.js b/api/server/services/Config/__tests__/getCachedTools.spec.js index 48ab6e0737..38d488ed38 100644 --- a/api/server/services/Config/__tests__/getCachedTools.spec.js +++ b/api/server/services/Config/__tests__/getCachedTools.spec.js @@ -1,10 +1,92 @@ -const { ToolCacheKeys } = require('../getCachedTools'); +const { CacheKeys } = require('librechat-data-provider'); + +jest.mock('~/cache/getLogStores'); +const getLogStores = require('~/cache/getLogStores'); + +const mockCache = { get: jest.fn(), set: jest.fn(), delete: jest.fn() }; +getLogStores.mockReturnValue(mockCache); + +const { + ToolCacheKeys, + getCachedTools, + setCachedTools, + getMCPServerTools, + invalidateCachedTools, +} = require('../getCachedTools'); + +describe('getCachedTools', () => { + beforeEach(() => { + jest.clearAllMocks(); + getLogStores.mockReturnValue(mockCache); + }); -describe('getCachedTools - Cache Isolation Security', () => { describe('ToolCacheKeys.MCP_SERVER', () => { it('should generate cache keys that include userId', () => { const key = ToolCacheKeys.MCP_SERVER('user123', 'github'); expect(key).toBe('tools:mcp:user123:github'); }); }); + + describe('TOOL_CACHE namespace usage', () => { + it('getCachedTools should use TOOL_CACHE namespace', async () => { + mockCache.get.mockResolvedValue(null); + await getCachedTools(); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + }); + + it('getCachedTools with MCP server options should use TOOL_CACHE namespace', async () => { + mockCache.get.mockResolvedValue({ tool1: {} }); + await getCachedTools({ userId: 'user1', serverName: 'github' }); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + 
expect(mockCache.get).toHaveBeenCalledWith(ToolCacheKeys.MCP_SERVER('user1', 'github')); + }); + + it('setCachedTools should use TOOL_CACHE namespace', async () => { + mockCache.set.mockResolvedValue(true); + const tools = { tool1: { type: 'function' } }; + await setCachedTools(tools); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + expect(mockCache.set).toHaveBeenCalledWith(ToolCacheKeys.GLOBAL, tools, expect.any(Number)); + }); + + it('setCachedTools with MCP server options should use TOOL_CACHE namespace', async () => { + mockCache.set.mockResolvedValue(true); + const tools = { tool1: { type: 'function' } }; + await setCachedTools(tools, { userId: 'user1', serverName: 'github' }); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + expect(mockCache.set).toHaveBeenCalledWith( + ToolCacheKeys.MCP_SERVER('user1', 'github'), + tools, + expect.any(Number), + ); + }); + + it('invalidateCachedTools should use TOOL_CACHE namespace', async () => { + mockCache.delete.mockResolvedValue(true); + await invalidateCachedTools({ invalidateGlobal: true }); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + expect(mockCache.delete).toHaveBeenCalledWith(ToolCacheKeys.GLOBAL); + }); + + it('getMCPServerTools should use TOOL_CACHE namespace', async () => { + mockCache.get.mockResolvedValue(null); + await getMCPServerTools('user1', 'github'); + expect(getLogStores).toHaveBeenCalledWith(CacheKeys.TOOL_CACHE); + expect(mockCache.get).toHaveBeenCalledWith(ToolCacheKeys.MCP_SERVER('user1', 'github')); + }); + + it('should NOT use CONFIG_STORE namespace', async () => { + mockCache.get.mockResolvedValue(null); + await getCachedTools(); + await getMCPServerTools('user1', 'github'); + mockCache.set.mockResolvedValue(true); + await setCachedTools({ tool1: {} }); + mockCache.delete.mockResolvedValue(true); + await invalidateCachedTools({ invalidateGlobal: true }); + + const allCalls = getLogStores.mock.calls.flat(); + 
expect(allCalls).not.toContain(CacheKeys.CONFIG_STORE); + expect(allCalls.every((key) => key === CacheKeys.TOOL_CACHE)).toBe(true); + }); + }); }); diff --git a/api/server/services/Config/getCachedTools.js b/api/server/services/Config/getCachedTools.js index cf1618a646..eb7a08305a 100644 --- a/api/server/services/Config/getCachedTools.js +++ b/api/server/services/Config/getCachedTools.js @@ -20,7 +20,7 @@ const ToolCacheKeys = { * @returns {Promise} The available tools object or null if not cached */ async function getCachedTools(options = {}) { - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); const { userId, serverName } = options; // Return MCP server-specific tools if requested @@ -43,7 +43,7 @@ async function getCachedTools(options = {}) { * @returns {Promise} Whether the operation was successful */ async function setCachedTools(tools, options = {}) { - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); const { userId, serverName, ttl = Time.TWELVE_HOURS } = options; // Cache by MCP server if specified (requires userId) @@ -65,7 +65,7 @@ async function setCachedTools(tools, options = {}) { * @returns {Promise} */ async function invalidateCachedTools(options = {}) { - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); const { userId, serverName, invalidateGlobal = false } = options; const keysToDelete = []; @@ -89,7 +89,7 @@ async function invalidateCachedTools(options = {}) { * @returns {Promise} The available tools for the server */ async function getMCPServerTools(userId, serverName) { - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName)); if (serverTools) { diff --git a/api/server/services/Config/loadConfigModels.js b/api/server/services/Config/loadConfigModels.js 
index 6354d10331..2bc83ecc3a 100644 --- a/api/server/services/Config/loadConfigModels.js +++ b/api/server/services/Config/loadConfigModels.js @@ -28,6 +28,11 @@ async function loadConfigModels(req) { modelsConfig[EModelEndpoint.azureAssistants] = azureConfig.assistantModels; } + const bedrockConfig = appConfig.endpoints?.[EModelEndpoint.bedrock]; + if (bedrockConfig?.models && Array.isArray(bedrockConfig.models)) { + modelsConfig[EModelEndpoint.bedrock] = bedrockConfig.models; + } + if (!Array.isArray(appConfig.endpoints?.[EModelEndpoint.custom])) { return modelsConfig; } diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index 15ea62a028..cc4e98b59e 100644 --- a/api/server/services/Config/mcp.js +++ b/api/server/services/Config/mcp.js @@ -35,7 +35,7 @@ async function updateMCPServerTools({ userId, serverName, tools }) { await setCachedTools(serverTools, { userId, serverName }); - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); await cache.delete(CacheKeys.TOOLS); logger.debug( `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, @@ -61,7 +61,7 @@ async function mergeAppTools(appTools) { const cachedTools = await getCachedTools(); const mergedTools = { ...cachedTools, ...appTools }; await setCachedTools(mergedTools); - const cache = getLogStores(CacheKeys.CONFIG_STORE); + const cache = getLogStores(CacheKeys.TOOL_CACHE); await cache.delete(CacheKeys.TOOLS); logger.debug(`Merged ${count} app-level tools`); } catch (error) { diff --git a/api/server/services/Endpoints/agents/addedConvo.js b/api/server/services/Endpoints/agents/addedConvo.js index 240622ed9f..7e9385267a 100644 --- a/api/server/services/Endpoints/agents/addedConvo.js +++ b/api/server/services/Endpoints/agents/addedConvo.js @@ -31,6 +31,7 @@ setGetAgent(getAgent); * @param {Function} params.loadTools - Function to load agent tools * @param {Array} params.requestFiles - Request files 
* @param {string} params.conversationId - The conversation ID + * @param {string} [params.parentMessageId] - The parent message ID for thread filtering * @param {Set} params.allowedProviders - Set of allowed providers * @param {Map} params.agentConfigs - Map of agent configs to add to * @param {string} params.primaryAgentId - The primary agent ID @@ -46,6 +47,7 @@ const processAddedConvo = async ({ loadTools, requestFiles, conversationId, + parentMessageId, allowedProviders, agentConfigs, primaryAgentId, @@ -91,6 +93,7 @@ const processAddedConvo = async ({ loadTools, requestFiles, conversationId, + parentMessageId, agent: addedAgent, endpointOption, allowedProviders, @@ -99,9 +102,12 @@ const processAddedConvo = async ({ getConvoFiles, getFiles: db.getFiles, getUserKey: db.getUserKey, + getMessages: db.getMessages, updateFilesUsage: db.updateFilesUsage, + getUserCodeFiles: db.getUserCodeFiles, getUserKeyValues: db.getUserKeyValues, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 626beed153..0888f23cd5 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -3,10 +3,11 @@ const { createContentAggregator } = require('@librechat/agents'); const { initializeAgent, validateAgentModel, - getCustomEndpointConfig, - createSequentialChainEdges, createEdgeCollector, filterOrphanedEdges, + GenerationJobManager, + getCustomEndpointConfig, + createSequentialChainEdges, } = require('@librechat/api'); const { EModelEndpoint, @@ -18,8 +19,8 @@ const { createToolEndCallback, getDefaultHandlers, } = require('~/server/controllers/agents/callbacks'); +const { loadAgentTools, loadToolsForExecution } = require('~/server/services/ToolService'); const { getModelsConfig } = require('~/server/controllers/ModelController'); -const { loadAgentTools } = 
require('~/server/services/ToolService'); const AgentClient = require('~/server/controllers/agents/client'); const { getConvoFiles } = require('~/models/Conversation'); const { processAddedConvo } = require('./addedConvo'); @@ -31,8 +32,10 @@ const db = require('~/models'); * Creates a tool loader function for the agent. * @param {AbortSignal} signal - The abort signal * @param {string | null} [streamId] - The stream ID for resumable mode + * @param {boolean} [definitionsOnly=false] - When true, returns only serializable + * tool definitions without creating full tool instances (for event-driven mode) */ -function createToolLoader(signal, streamId = null) { +function createToolLoader(signal, streamId = null, definitionsOnly = false) { /** * @param {object} params * @param {ServerRequest} params.req @@ -43,21 +46,33 @@ function createToolLoader(signal, streamId = null) { * @param {string} params.model * @param {AgentToolResources} params.tool_resources * @returns {Promise<{ - * tools: StructuredTool[], - * toolContextMap: Record, - * userMCPAuthMap?: Record> + * tools?: StructuredTool[], + * toolContextMap: Record, + * toolDefinitions?: import('@librechat/agents').LCTool[], + * userMCPAuthMap?: Record>, + * toolRegistry?: import('@librechat/agents').LCToolRegistry * } | undefined>} */ - return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) { - const agent = { id: agentId, tools, provider, model }; + return async function loadTools({ + req, + res, + tools, + model, + agentId, + provider, + tool_options, + tool_resources, + }) { + const agent = { id: agentId, tools, provider, model, tool_options }; try { return await loadAgentTools({ req, res, agent, signal, - tool_resources, streamId, + tool_resources, + definitionsOnly, }); } catch (error) { logger.error('Error loading tools for agent ' + agentId, error); @@ -80,8 +95,47 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const artifactPromises = []; 
const { contentParts, aggregateContent } = createContentAggregator(); const toolEndCallback = createToolEndCallback({ req, res, artifactPromises, streamId }); + + /** + * Agent context store - populated after initialization, accessed by callback via closure. + * Maps agentId -> { userMCPAuthMap, agent, tool_resources, toolRegistry, openAIApiKey } + * @type {Map>, + * agent?: object, + * tool_resources?: object, + * toolRegistry?: import('@librechat/agents').LCToolRegistry, + * openAIApiKey?: string + * }>} + */ + const agentToolContexts = new Map(); + + const toolExecuteOptions = { + loadTools: async (toolNames, agentId) => { + const ctx = agentToolContexts.get(agentId) ?? {}; + logger.debug(`[ON_TOOL_EXECUTE] ctx found: ${!!ctx.userMCPAuthMap}, agent: ${ctx.agent?.id}`); + logger.debug(`[ON_TOOL_EXECUTE] toolRegistry size: ${ctx.toolRegistry?.size ?? 'undefined'}`); + + const result = await loadToolsForExecution({ + req, + res, + signal, + streamId, + toolNames, + agent: ctx.agent, + toolRegistry: ctx.toolRegistry, + userMCPAuthMap: ctx.userMCPAuthMap, + tool_resources: ctx.tool_resources, + }); + + logger.debug(`[ON_TOOL_EXECUTE] loaded ${result.loadedTools?.length ?? 0} tools`); + return result; + }, + toolEndCallback, + }; + const eventHandlers = getDefaultHandlers({ res, + toolExecuteOptions, aggregateContent, toolEndCallback, collectedUsage, @@ -114,11 +168,14 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const agentConfigs = new Map(); const allowedProviders = new Set(appConfig?.endpoints?.[EModelEndpoint.agents]?.allowedProviders); - const loadTools = createToolLoader(signal, streamId); + /** Event-driven mode: only load tool definitions, not full instances */ + const loadTools = createToolLoader(signal, streamId, true); /** @type {Array} */ const requestFiles = req.body.files ?? 
[]; /** @type {string} */ const conversationId = req.body.conversationId; + /** @type {string | undefined} */ + const parentMessageId = req.body.parentMessageId; const primaryConfig = await initializeAgent( { @@ -127,6 +184,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { loadTools, requestFiles, conversationId, + parentMessageId, agent: primaryAgent, endpointOption, allowedProviders, @@ -136,12 +194,31 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getConvoFiles, getFiles: db.getFiles, getUserKey: db.getUserKey, + getMessages: db.getMessages, updateFilesUsage: db.updateFilesUsage, getUserKeyValues: db.getUserKeyValues, + getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); + logger.debug( + `[initializeClient] Tool definitions for primary agent: ${primaryConfig.toolDefinitions?.length ?? 0}`, + ); + + /** Store primary agent's tool context for ON_TOOL_EXECUTE callback */ + logger.debug(`[initializeClient] Storing tool context for agentId: ${primaryConfig.id}`); + logger.debug( + `[initializeClient] toolRegistry size: ${primaryConfig.toolRegistry?.size ?? 
'undefined'}`, + ); + agentToolContexts.set(primaryConfig.id, { + agent: primaryAgent, + toolRegistry: primaryConfig.toolRegistry, + userMCPAuthMap: primaryConfig.userMCPAuthMap, + tool_resources: primaryConfig.tool_resources, + }); + const agent_ids = primaryConfig.agent_ids; let userMCPAuthMap = primaryConfig.userMCPAuthMap; @@ -178,6 +255,7 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { loadTools, requestFiles, conversationId, + parentMessageId, endpointOption, allowedProviders, }, @@ -185,16 +263,29 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { getConvoFiles, getFiles: db.getFiles, getUserKey: db.getUserKey, + getMessages: db.getMessages, updateFilesUsage: db.updateFilesUsage, getUserKeyValues: db.getUserKeyValues, + getUserCodeFiles: db.getUserCodeFiles, getToolFilesByIds: db.getToolFilesByIds, + getCodeGeneratedFiles: db.getCodeGeneratedFiles, }, ); + if (userMCPAuthMap != null) { Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? 
{}); } else { userMCPAuthMap = config.userMCPAuthMap; } + + /** Store handoff agent's tool context for ON_TOOL_EXECUTE callback */ + agentToolContexts.set(agentId, { + agent, + toolRegistry: config.toolRegistry, + userMCPAuthMap: config.userMCPAuthMap, + tool_resources: config.tool_resources, + }); + agentConfigs.set(agentId, config); return agent; } @@ -242,17 +333,18 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const { userMCPAuthMap: updatedMCPAuthMap } = await processAddedConvo({ req, res, - endpointOption, - modelsConfig, - logViolation, loadTools, + logViolation, + modelsConfig, requestFiles, - conversationId, - allowedProviders, agentConfigs, - primaryAgentId: primaryConfig.id, primaryAgent, + endpointOption, userMCPAuthMap, + conversationId, + parentMessageId, + allowedProviders, + primaryAgentId: primaryConfig.id, }); if (updatedMCPAuthMap) { @@ -314,6 +406,10 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { endpoint: isEphemeralAgentId(primaryConfig.id) ? 
primaryConfig.endpoint : EModelEndpoint.agents, }); + if (streamId) { + GenerationJobManager.setCollectedUsage(streamId, collectedUsage); + } + return { client, userMCPAuthMap }; }; diff --git a/api/server/services/Endpoints/agents/title.js b/api/server/services/Endpoints/agents/title.js index 1d6d359bd6..e31cdeea11 100644 --- a/api/server/services/Endpoints/agents/title.js +++ b/api/server/services/Endpoints/agents/title.js @@ -71,7 +71,7 @@ const addTitle = async (req, { text, response, client }) => { conversationId: response.conversationId, title, }, - { context: 'api/server/services/Endpoints/agents/title.js' }, + { context: 'api/server/services/Endpoints/agents/title.js', noUpsert: true }, ); } catch (error) { logger.error('Error generating title:', error); diff --git a/api/server/services/Endpoints/assistants/title.js b/api/server/services/Endpoints/assistants/title.js index a34de4d1af..1fae68cf54 100644 --- a/api/server/services/Endpoints/assistants/title.js +++ b/api/server/services/Endpoints/assistants/title.js @@ -69,7 +69,7 @@ const addTitle = async (req, { text, responseText, conversationId }) => { conversationId, title, }, - { context: 'api/server/services/Endpoints/assistants/addTitle.js' }, + { context: 'api/server/services/Endpoints/assistants/addTitle.js', noUpsert: true }, ); } catch (error) { logger.error('[addTitle] Error generating title:', error); @@ -81,7 +81,7 @@ const addTitle = async (req, { text, responseText, conversationId }) => { conversationId, title: fallbackTitle, }, - { context: 'api/server/services/Endpoints/assistants/addTitle.js' }, + { context: 'api/server/services/Endpoints/assistants/addTitle.js', noUpsert: true }, ); } }; diff --git a/api/server/services/Endpoints/azureAssistants/initialize.js b/api/server/services/Endpoints/azureAssistants/initialize.js index 6a9118ea8a..e81f0bcd8a 100644 --- a/api/server/services/Endpoints/azureAssistants/initialize.js +++ b/api/server/services/Endpoints/azureAssistants/initialize.js @@ 
-128,7 +128,6 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie const groupName = modelGroupMap[modelName].group; clientOptions.addParams = azureConfig.groupMap[groupName].addParams; clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams; - clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt; clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl; clientOptions.headers = opts.defaultHeaders; diff --git a/api/server/services/Endpoints/index.js b/api/server/services/Endpoints/index.js index 034162702d..3cabfe1c58 100644 --- a/api/server/services/Endpoints/index.js +++ b/api/server/services/Endpoints/index.js @@ -12,7 +12,7 @@ const initGoogle = require('~/server/services/Endpoints/google/initialize'); * @returns {boolean} - True if the provider is a known custom provider, false otherwise */ function isKnownCustomProvider(provider) { - return [Providers.XAI, Providers.DEEPSEEK, Providers.OPENROUTER].includes( + return [Providers.XAI, Providers.DEEPSEEK, Providers.OPENROUTER, Providers.MOONSHOT].includes( provider?.toLowerCase() || '', ); } @@ -20,6 +20,7 @@ function isKnownCustomProvider(provider) { const providerConfigMap = { [Providers.XAI]: initCustom, [Providers.DEEPSEEK]: initCustom, + [Providers.MOONSHOT]: initCustom, [Providers.OPENROUTER]: initCustom, [EModelEndpoint.openAI]: initOpenAI, [EModelEndpoint.google]: initGoogle, diff --git a/api/server/services/Files/Azure/crud.js b/api/server/services/Files/Azure/crud.js index 25bd749276..8f681bd06c 100644 --- a/api/server/services/Files/Azure/crud.js +++ b/api/server/services/Files/Azure/crud.js @@ -4,7 +4,7 @@ const mime = require('mime'); const axios = require('axios'); const fetch = require('node-fetch'); const { logger } = require('@librechat/data-schemas'); -const { getAzureContainerClient } = require('@librechat/api'); +const { getAzureContainerClient, deleteRagFile } = require('@librechat/api'); const defaultBasePath = 
'images'; const { AZURE_STORAGE_PUBLIC_ACCESS = 'true', AZURE_CONTAINER_NAME = 'files' } = process.env; @@ -102,6 +102,8 @@ async function getAzureURL({ fileName, basePath = defaultBasePath, userId, conta * @param {MongoFile} params.file - The file object. */ async function deleteFileFromAzure(req, file) { + await deleteRagFile({ userId: req.user.id, file }); + try { const containerClient = await getAzureContainerClient(AZURE_CONTAINER_NAME); const blobPath = file.filepath.split(`${AZURE_CONTAINER_NAME}/`)[1]; diff --git a/api/server/services/Files/Code/process.js b/api/server/services/Files/Code/process.js index 15df6de0d6..3f0bfcfc87 100644 --- a/api/server/services/Files/Code/process.js +++ b/api/server/services/Files/Code/process.js @@ -6,27 +6,68 @@ const { getCodeBaseURL } = require('@librechat/agents'); const { logAxiosError, getBasePath } = require('@librechat/api'); const { Tools, + megabyte, + fileConfig, FileContext, FileSources, imageExtRegex, + inferMimeType, EToolResources, + EModelEndpoint, + mergeFileConfig, + getEndpointFileConfig, } = require('librechat-data-provider'); const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions'); +const { createFile, getFiles, updateFile, claimCodeFile } = require('~/models'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { convertImage } = require('~/server/services/Files/images/convert'); -const { createFile, getFiles, updateFile } = require('~/models'); +const { determineFileType } = require('~/server/utils'); /** - * Process OpenAI image files, convert to target format, save and return file metadata. + * Creates a fallback download URL response when file cannot be processed locally. + * Used when: file exceeds size limit, storage strategy unavailable, or download error occurs. + * @param {Object} params - The parameters. + * @param {string} params.name - The filename. + * @param {string} params.session_id - The code execution session ID. 
+ * @param {string} params.id - The file ID from the code environment. + * @param {string} params.conversationId - The current conversation ID. + * @param {string} params.toolCallId - The tool call ID that generated the file. + * @param {string} params.messageId - The current message ID. + * @param {number} params.expiresAt - Expiration timestamp (24 hours from creation). + * @returns {Object} Fallback response with download URL. + */ +const createDownloadFallback = ({ + id, + name, + messageId, + expiresAt, + session_id, + toolCallId, + conversationId, +}) => { + const basePath = getBasePath(); + return { + filename: name, + filepath: `${basePath}/api/files/code/download/${session_id}/${id}`, + expiresAt, + conversationId, + toolCallId, + messageId, + }; +}; + +/** + * Process code execution output files - downloads and saves both images and non-image files. + * All files are saved to local storage with fileIdentifier metadata for code env re-upload. * @param {ServerRequest} params.req - The Express request object. - * @param {string} params.id - The file ID. + * @param {string} params.id - The file ID from the code environment. * @param {string} params.name - The filename. * @param {string} params.apiKey - The code execution API key. * @param {string} params.toolCallId - The tool call ID that generated the file. * @param {string} params.session_id - The code execution session ID. * @param {string} params.conversationId - The current conversation ID. * @param {string} params.messageId - The current message ID. - * @returns {Promise} The file metadata or undefined if an error occurs. + * @returns {Promise} The file metadata or undefined if an error occurs. 
*/ const processCodeOutput = async ({ req, @@ -41,19 +82,15 @@ const processCodeOutput = async ({ const appConfig = req.config; const currentDate = new Date(); const baseURL = getCodeBaseURL(); - const basePath = getBasePath(); - const fileExt = path.extname(name); - if (!fileExt || !imageExtRegex.test(name)) { - return { - filename: name, - filepath: `${basePath}/api/files/code/download/${session_id}/${id}`, - /** Note: expires 24 hours after creation */ - expiresAt: currentDate.getTime() + 86400000, - conversationId, - toolCallId, - messageId, - }; - } + const fileExt = path.extname(name).toLowerCase(); + const isImage = fileExt && imageExtRegex.test(name); + + const mergedFileConfig = mergeFileConfig(appConfig.fileConfig); + const endpointFileConfig = getEndpointFileConfig({ + fileConfig: mergedFileConfig, + endpoint: EModelEndpoint.agents, + }); + const fileSizeLimit = endpointFileConfig.fileSizeLimit ?? mergedFileConfig.serverFileSizeLimit; try { const formattedDate = currentDate.toISOString(); @@ -70,29 +107,143 @@ const processCodeOutput = async ({ const buffer = Buffer.from(response.data, 'binary'); - const file_id = v4(); - const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`); - const file = { - ..._file, - file_id, - usage: 1, + // Enforce file size limit + if (buffer.length > fileSizeLimit) { + logger.warn( + `[processCodeOutput] File "${name}" (${(buffer.length / megabyte).toFixed(2)} MB) exceeds size limit of ${(fileSizeLimit / megabyte).toFixed(2)} MB, falling back to download URL`, + ); + return createDownloadFallback({ + id, + name, + messageId, + toolCallId, + session_id, + conversationId, + expiresAt: currentDate.getTime() + 86400000, + }); + } + + const fileIdentifier = `${session_id}/${id}`; + + /** + * Atomically claim a file_id for this (filename, conversationId, context) tuple. 
+ * Uses $setOnInsert so concurrent calls for the same filename converge on + * a single record instead of creating duplicates (TOCTOU race fix). + */ + const newFileId = v4(); + const claimed = await claimCodeFile({ filename: name, conversationId, + file_id: newFileId, user: req.user.id, - type: `image/${appConfig.imageOutputType}`, - createdAt: formattedDate, + }); + const file_id = claimed.file_id; + const isUpdate = file_id !== newFileId; + + if (isUpdate) { + logger.debug( + `[processCodeOutput] Updating existing file "${name}" (${file_id}) instead of creating duplicate`, + ); + } + + if (isImage) { + const usage = isUpdate ? (claimed.usage ?? 0) + 1 : 1; + const _file = await convertImage(req, buffer, 'high', `${file_id}${fileExt}`); + const filepath = usage > 1 ? `${_file.filepath}?v=${Date.now()}` : _file.filepath; + const file = { + ..._file, + filepath, + file_id, + messageId, + usage, + filename: name, + conversationId, + user: req.user.id, + type: `image/${appConfig.imageOutputType}`, + createdAt: isUpdate ? 
claimed.createdAt : formattedDate, + updatedAt: formattedDate, + source: appConfig.fileStrategy, + context: FileContext.execute_code, + metadata: { fileIdentifier }, + }; + await createFile(file, true); + return Object.assign(file, { messageId, toolCallId }); + } + + const { saveBuffer } = getStrategyFunctions(appConfig.fileStrategy); + if (!saveBuffer) { + logger.warn( + `[processCodeOutput] saveBuffer not available for strategy ${appConfig.fileStrategy}, falling back to download URL`, + ); + return createDownloadFallback({ + id, + name, + messageId, + toolCallId, + session_id, + conversationId, + expiresAt: currentDate.getTime() + 86400000, + }); + } + + const detectedType = await determineFileType(buffer, true); + const mimeType = detectedType?.mime || inferMimeType(name, '') || 'application/octet-stream'; + + /** Check MIME type support - for code-generated files, we're lenient but log unsupported types */ + const isSupportedMimeType = fileConfig.checkType( + mimeType, + endpointFileConfig.supportedMimeTypes, + ); + if (!isSupportedMimeType) { + logger.warn( + `[processCodeOutput] File "${name}" has unsupported MIME type "${mimeType}", proceeding with storage but may not be usable as tool resource`, + ); + } + + const fileName = `${file_id}__${name}`; + const filepath = await saveBuffer({ + userId: req.user.id, + buffer, + fileName, + basePath: 'uploads', + }); + + const file = { + file_id, + filepath, + messageId, + object: 'file', + filename: name, + type: mimeType, + conversationId, + user: req.user.id, + bytes: buffer.length, updatedAt: formattedDate, + metadata: { fileIdentifier }, source: appConfig.fileStrategy, context: FileContext.execute_code, + usage: isUpdate ? (claimed.usage ?? 0) + 1 : 1, + createdAt: isUpdate ? 
claimed.createdAt : formattedDate, }; - createFile(file, true); - /** Note: `messageId` & `toolCallId` are not part of file DB schema; message object records associated file ID */ + + await createFile(file, true); return Object.assign(file, { messageId, toolCallId }); } catch (error) { logAxiosError({ - message: 'Error downloading code environment file', + message: 'Error downloading/processing code environment file', error, }); + + // Fallback for download errors - return download URL so user can still manually download + return createDownloadFallback({ + id, + name, + messageId, + toolCallId, + session_id, + conversationId, + expiresAt: currentDate.getTime() + 86400000, + }); } }; @@ -204,9 +355,16 @@ const primeFiles = async (options, apiKey) => { if (!toolContext) { toolContext = `- Note: The following files are available in the "${Tools.execute_code}" tool environment:`; } - toolContext += `\n\t- /mnt/data/${file.filename}${ - agentResourceIds.has(file.file_id) ? '' : ' (just attached by user)' - }`; + + let fileSuffix = ''; + if (!agentResourceIds.has(file.file_id)) { + fileSuffix = + file.context === FileContext.execute_code + ? 
' (from previous code execution)' + : ' (attached by user)'; + } + + toolContext += `\n\t- /mnt/data/${file.filename}${fileSuffix}`; files.push({ id, session_id, diff --git a/api/server/services/Files/Code/process.spec.js b/api/server/services/Files/Code/process.spec.js new file mode 100644 index 0000000000..f01a623f90 --- /dev/null +++ b/api/server/services/Files/Code/process.spec.js @@ -0,0 +1,411 @@ +// Configurable file size limit for tests - use a getter so it can be changed per test +const fileSizeLimitConfig = { value: 20 * 1024 * 1024 }; // Default 20MB + +// Mock librechat-data-provider with configurable file size limit +jest.mock('librechat-data-provider', () => { + const actual = jest.requireActual('librechat-data-provider'); + return { + ...actual, + mergeFileConfig: jest.fn((config) => { + const merged = actual.mergeFileConfig(config); + // Override the serverFileSizeLimit with our test value + return { + ...merged, + get serverFileSizeLimit() { + return fileSizeLimitConfig.value; + }, + }; + }), + getEndpointFileConfig: jest.fn((options) => { + const config = actual.getEndpointFileConfig(options); + // Override fileSizeLimit with our test value + return { + ...config, + get fileSizeLimit() { + return fileSizeLimitConfig.value; + }, + }; + }), + }; +}); + +const { FileContext } = require('librechat-data-provider'); + +// Mock uuid +jest.mock('uuid', () => ({ + v4: jest.fn(() => 'mock-uuid-1234'), +})); + +// Mock axios +jest.mock('axios'); +const axios = require('axios'); + +// Mock logger +jest.mock('@librechat/data-schemas', () => ({ + logger: { + warn: jest.fn(), + debug: jest.fn(), + error: jest.fn(), + }, +})); + +// Mock getCodeBaseURL +jest.mock('@librechat/agents', () => ({ + getCodeBaseURL: jest.fn(() => 'https://code-api.example.com'), +})); + +// Mock logAxiosError and getBasePath +jest.mock('@librechat/api', () => ({ + logAxiosError: jest.fn(), + getBasePath: jest.fn(() => ''), +})); + +// Mock models +const mockClaimCodeFile = jest.fn(); 
+jest.mock('~/models', () => ({ + createFile: jest.fn().mockResolvedValue({}), + getFiles: jest.fn(), + updateFile: jest.fn(), + claimCodeFile: (...args) => mockClaimCodeFile(...args), +})); + +// Mock permissions (must be before process.js import) +jest.mock('~/server/services/Files/permissions', () => ({ + filterFilesByAgentAccess: jest.fn((options) => Promise.resolve(options.files)), +})); + +// Mock strategy functions +jest.mock('~/server/services/Files/strategies', () => ({ + getStrategyFunctions: jest.fn(), +})); + +// Mock convertImage +jest.mock('~/server/services/Files/images/convert', () => ({ + convertImage: jest.fn(), +})); + +// Mock determineFileType +jest.mock('~/server/utils', () => ({ + determineFileType: jest.fn(), +})); + +const { createFile, getFiles } = require('~/models'); +const { getStrategyFunctions } = require('~/server/services/Files/strategies'); +const { convertImage } = require('~/server/services/Files/images/convert'); +const { determineFileType } = require('~/server/utils'); +const { logger } = require('@librechat/data-schemas'); + +// Import after mocks +const { processCodeOutput } = require('./process'); + +describe('Code Process', () => { + const mockReq = { + user: { id: 'user-123' }, + config: { + fileConfig: {}, + fileStrategy: 'local', + imageOutputType: 'webp', + }, + }; + + const baseParams = { + req: mockReq, + id: 'file-id-123', + name: 'test-file.txt', + apiKey: 'test-api-key', + toolCallId: 'tool-call-123', + conversationId: 'conv-123', + messageId: 'msg-123', + session_id: 'session-123', + }; + + beforeEach(() => { + jest.clearAllMocks(); + // Default mock: atomic claim returns a new file record (no existing file) + mockClaimCodeFile.mockResolvedValue({ + file_id: 'mock-uuid-1234', + user: 'user-123', + }); + getFiles.mockResolvedValue(null); + createFile.mockResolvedValue({}); + getStrategyFunctions.mockReturnValue({ + saveBuffer: jest.fn().mockResolvedValue('/uploads/mock-file-path.txt'), + }); + 
determineFileType.mockResolvedValue({ mime: 'text/plain' }); + }); + + describe('atomic file claim (via processCodeOutput)', () => { + it('should reuse file_id from existing record via atomic claim', async () => { + mockClaimCodeFile.mockResolvedValue({ + file_id: 'existing-file-id', + filename: 'test-file.txt', + usage: 2, + createdAt: '2024-01-01T00:00:00.000Z', + }); + + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(mockClaimCodeFile).toHaveBeenCalledWith({ + filename: 'test-file.txt', + conversationId: 'conv-123', + file_id: 'mock-uuid-1234', + user: 'user-123', + }); + + expect(result.file_id).toBe('existing-file-id'); + expect(result.usage).toBe(3); + expect(result.createdAt).toBe('2024-01-01T00:00:00.000Z'); + }); + + it('should create new file when no existing file found', async () => { + mockClaimCodeFile.mockResolvedValue({ + file_id: 'mock-uuid-1234', + user: 'user-123', + }); + + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.file_id).toBe('mock-uuid-1234'); + expect(result.usage).toBe(1); + }); + }); + + describe('processCodeOutput', () => { + describe('image file processing', () => { + it('should process image files using convertImage', async () => { + const imageParams = { ...baseParams, name: 'chart.png' }; + const imageBuffer = Buffer.alloc(500); + axios.mockResolvedValue({ data: imageBuffer }); + + const convertedFile = { + filepath: '/uploads/converted-image.webp', + bytes: 400, + }; + convertImage.mockResolvedValue(convertedFile); + + const result = await processCodeOutput(imageParams); + + expect(convertImage).toHaveBeenCalledWith( + mockReq, + imageBuffer, + 'high', + 'mock-uuid-1234.png', + ); + expect(result.type).toBe('image/webp'); + expect(result.context).toBe(FileContext.execute_code); + 
expect(result.filename).toBe('chart.png'); + }); + + it('should update existing image file with cache-busted filepath', async () => { + const imageParams = { ...baseParams, name: 'chart.png' }; + mockClaimCodeFile.mockResolvedValue({ + file_id: 'existing-img-id', + usage: 1, + createdAt: '2024-01-01T00:00:00.000Z', + }); + + const imageBuffer = Buffer.alloc(500); + axios.mockResolvedValue({ data: imageBuffer }); + convertImage.mockResolvedValue({ filepath: '/images/user-123/existing-img-id.webp' }); + + const result = await processCodeOutput(imageParams); + + expect(convertImage).toHaveBeenCalledWith( + mockReq, + imageBuffer, + 'high', + 'existing-img-id.png', + ); + expect(result.file_id).toBe('existing-img-id'); + expect(result.usage).toBe(2); + expect(result.filepath).toMatch(/^\/images\/user-123\/existing-img-id\.webp\?v=\d+$/); + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining('Updating existing file'), + ); + }); + }); + + describe('non-image file processing', () => { + it('should process non-image files using saveBuffer', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const mockSaveBuffer = jest.fn().mockResolvedValue('/uploads/saved-file.txt'); + getStrategyFunctions.mockReturnValue({ saveBuffer: mockSaveBuffer }); + determineFileType.mockResolvedValue({ mime: 'text/plain' }); + + const result = await processCodeOutput(baseParams); + + expect(mockSaveBuffer).toHaveBeenCalledWith({ + userId: 'user-123', + buffer: smallBuffer, + fileName: 'mock-uuid-1234__test-file.txt', + basePath: 'uploads', + }); + expect(result.type).toBe('text/plain'); + expect(result.filepath).toBe('/uploads/saved-file.txt'); + expect(result.bytes).toBe(100); + }); + + it('should detect MIME type from buffer', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + determineFileType.mockResolvedValue({ mime: 'application/pdf' }); + + const result = await 
processCodeOutput({ ...baseParams, name: 'document.pdf' }); + + expect(determineFileType).toHaveBeenCalledWith(smallBuffer, true); + expect(result.type).toBe('application/pdf'); + }); + + it('should fallback to application/octet-stream for unknown types', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + determineFileType.mockResolvedValue(null); + + const result = await processCodeOutput({ ...baseParams, name: 'unknown.xyz' }); + + expect(result.type).toBe('application/octet-stream'); + }); + }); + + describe('file size limit enforcement', () => { + it('should fallback to download URL when file exceeds size limit', async () => { + // Set a small file size limit for this test + fileSizeLimitConfig.value = 1000; // 1KB limit + + const largeBuffer = Buffer.alloc(5000); // 5KB - exceeds 1KB limit + axios.mockResolvedValue({ data: largeBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('exceeds size limit')); + expect(result.filepath).toContain('/api/files/code/download/session-123/file-id-123'); + expect(result.expiresAt).toBeDefined(); + // Should not call createFile for oversized files (fallback path) + expect(createFile).not.toHaveBeenCalled(); + + // Reset to default for other tests + fileSizeLimitConfig.value = 20 * 1024 * 1024; + }); + }); + + describe('fallback behavior', () => { + it('should fallback to download URL when saveBuffer is not available', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + getStrategyFunctions.mockReturnValue({ saveBuffer: null }); + + const result = await processCodeOutput(baseParams); + + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('saveBuffer not available'), + ); + expect(result.filepath).toContain('/api/files/code/download/'); + expect(result.filename).toBe('test-file.txt'); + }); + + it('should fallback to 
download URL on axios error', async () => { + axios.mockRejectedValue(new Error('Network error')); + + const result = await processCodeOutput(baseParams); + + expect(result.filepath).toContain('/api/files/code/download/session-123/file-id-123'); + expect(result.conversationId).toBe('conv-123'); + expect(result.messageId).toBe('msg-123'); + expect(result.toolCallId).toBe('tool-call-123'); + }); + }); + + describe('usage counter increment', () => { + it('should set usage to 1 for new files', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.usage).toBe(1); + }); + + it('should increment usage for existing files', async () => { + mockClaimCodeFile.mockResolvedValue({ + file_id: 'existing-id', + usage: 5, + createdAt: '2024-01-01', + }); + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.usage).toBe(6); + }); + + it('should handle existing file with undefined usage', async () => { + mockClaimCodeFile.mockResolvedValue({ + file_id: 'existing-id', + createdAt: '2024-01-01', + }); + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.usage).toBe(1); + }); + }); + + describe('metadata and file properties', () => { + it('should include fileIdentifier in metadata', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.metadata).toEqual({ + fileIdentifier: 'session-123/file-id-123', + }); + }); + + it('should set correct context for code-generated files', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await 
processCodeOutput(baseParams); + + expect(result.context).toBe(FileContext.execute_code); + }); + + it('should include toolCallId and messageId in result', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + const result = await processCodeOutput(baseParams); + + expect(result.toolCallId).toBe('tool-call-123'); + expect(result.messageId).toBe('msg-123'); + }); + + it('should call createFile with upsert enabled', async () => { + const smallBuffer = Buffer.alloc(100); + axios.mockResolvedValue({ data: smallBuffer }); + + await processCodeOutput(baseParams); + + expect(createFile).toHaveBeenCalledWith( + expect.objectContaining({ + file_id: 'mock-uuid-1234', + context: FileContext.execute_code, + }), + true, // upsert flag + ); + }); + }); + }); +}); diff --git a/api/server/services/Files/Firebase/crud.js b/api/server/services/Files/Firebase/crud.js index 170df45677..d5e5a409bf 100644 --- a/api/server/services/Files/Firebase/crud.js +++ b/api/server/services/Files/Firebase/crud.js @@ -3,7 +3,7 @@ const path = require('path'); const axios = require('axios'); const fetch = require('node-fetch'); const { logger } = require('@librechat/data-schemas'); -const { getFirebaseStorage } = require('@librechat/api'); +const { getFirebaseStorage, deleteRagFile } = require('@librechat/api'); const { ref, uploadBytes, getDownloadURL, deleteObject } = require('firebase/storage'); const { getBufferMetadata } = require('~/server/utils'); @@ -167,27 +167,7 @@ function extractFirebaseFilePath(urlString) { * Throws an error if there is an issue with deletion. 
*/ const deleteFirebaseFile = async (req, file) => { - if (file.embedded && process.env.RAG_API_URL) { - const jwtToken = req.headers.authorization.split(' ')[1]; - try { - await axios.delete(`${process.env.RAG_API_URL}/documents`, { - headers: { - Authorization: `Bearer ${jwtToken}`, - 'Content-Type': 'application/json', - accept: 'application/json', - }, - data: [file.file_id], - }); - } catch (error) { - if (error.response?.status === 404) { - logger.warn( - `[deleteFirebaseFile] Document ${file.file_id} not found in RAG API, may have been deleted already`, - ); - } else { - logger.error('[deleteFirebaseFile] Error deleting document from RAG API:', error); - } - } - } + await deleteRagFile({ userId: req.user.id, file }); const fileName = extractFirebaseFilePath(file.filepath); if (!fileName.includes(req.user.id)) { diff --git a/api/server/services/Files/Local/crud.js b/api/server/services/Files/Local/crud.js index db553f57dd..1f38a01f83 100644 --- a/api/server/services/Files/Local/crud.js +++ b/api/server/services/Files/Local/crud.js @@ -1,9 +1,9 @@ const fs = require('fs'); const path = require('path'); const axios = require('axios'); +const { deleteRagFile } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { EModelEndpoint } = require('librechat-data-provider'); -const { generateShortLivedToken } = require('@librechat/api'); const { resizeImageBuffer } = require('~/server/services/Files/images/resize'); const { getBufferMetadata } = require('~/server/utils'); const paths = require('~/config/paths'); @@ -67,7 +67,12 @@ async function saveLocalBuffer({ userId, buffer, fileName, basePath = 'images' } try { const { publicPath, uploads } = paths; - const directoryPath = path.join(basePath === 'images' ? 
publicPath : uploads, basePath, userId); + /** + * For 'images': save to publicPath/images/userId (images are served statically) + * For 'uploads': save to uploads/userId (files downloaded via API) + * */ + const directoryPath = + basePath === 'images' ? path.join(publicPath, basePath, userId) : path.join(uploads, userId); if (!fs.existsSync(directoryPath)) { fs.mkdirSync(directoryPath, { recursive: true }); @@ -208,27 +213,7 @@ const deleteLocalFile = async (req, file) => { /** Filepath stripped of query parameters (e.g., ?manual=true) */ const cleanFilepath = file.filepath.split('?')[0]; - if (file.embedded && process.env.RAG_API_URL) { - const jwtToken = generateShortLivedToken(req.user.id); - try { - await axios.delete(`${process.env.RAG_API_URL}/documents`, { - headers: { - Authorization: `Bearer ${jwtToken}`, - 'Content-Type': 'application/json', - accept: 'application/json', - }, - data: [file.file_id], - }); - } catch (error) { - if (error.response?.status === 404) { - logger.warn( - `[deleteLocalFile] Document ${file.file_id} not found in RAG API, may have been deleted already`, - ); - } else { - logger.error('[deleteLocalFile] Error deleting document from RAG API:', error); - } - } - } + await deleteRagFile({ userId: req.user.id, file }); if (cleanFilepath.startsWith(`/uploads/${req.user.id}`)) { const userUploadDir = path.join(uploads, req.user.id); diff --git a/api/server/services/Files/S3/crud.js b/api/server/services/Files/S3/crud.js index 8dac767aa2..0721e33b29 100644 --- a/api/server/services/Files/S3/crud.js +++ b/api/server/services/Files/S3/crud.js @@ -1,9 +1,9 @@ const fs = require('fs'); const fetch = require('node-fetch'); -const { initializeS3 } = require('@librechat/api'); const { logger } = require('@librechat/data-schemas'); const { FileSources } = require('librechat-data-provider'); const { getSignedUrl } = require('@aws-sdk/s3-request-presigner'); +const { initializeS3, deleteRagFile } = require('@librechat/api'); const { 
PutObjectCommand, GetObjectCommand, @@ -142,6 +142,8 @@ async function saveURLToS3({ userId, URL, fileName, basePath = defaultBasePath } * @returns {Promise} */ async function deleteFileFromS3(req, file) { + await deleteRagFile({ userId: req.user.id, file }); + const key = extractKeyFromS3Url(file.filepath); const params = { Bucket: bucketName, Key: key }; if (!key.includes(req.user.id)) { diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index 81d7107de4..ad1f9f5cc3 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -1,4 +1,3 @@ -const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); const { logger } = require('@librechat/data-schemas'); const { @@ -12,8 +11,9 @@ const { MCPOAuthHandler, isMCPDomainAllowed, normalizeServerName, - convertWithResolvedRefs, + normalizeJsonSchema, GenerationJobManager, + resolveJsonSchemaRefs, } = require('@librechat/api'); const { Time, @@ -29,10 +29,21 @@ const { getMCPManager, } = require('~/config'); const { findToken, createToken, updateToken } = require('~/models'); +const { getGraphApiToken } = require('./GraphTokenService'); const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); +function isEmptyObjectSchema(jsonSchema) { + return ( + jsonSchema != null && + typeof jsonSchema === 'object' && + jsonSchema.type === 'object' && + (jsonSchema.properties == null || Object.keys(jsonSchema.properties).length === 0) && + !jsonSchema.additionalProperties + ); +} + /** * @param {object} params * @param {ServerResponse} params.res - The Express response object for sending events. @@ -43,9 +54,9 @@ const { getLogStores } = require('~/cache'); function createRunStepDeltaEmitter({ res, stepId, toolCall, streamId = null }) { /** * @param {string} authURL - The URL to redirect the user for OAuth authentication. 
- * @returns {void} + * @returns {Promise} */ - return function (authURL) { + return async function (authURL) { /** @type {{ id: string; delta: AgentToolCallDelta }} */ const data = { id: stepId, @@ -58,7 +69,7 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall, streamId = null }) { }; const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data }; if (streamId) { - GenerationJobManager.emitChunk(streamId, eventData); + await GenerationJobManager.emitChunk(streamId, eventData); } else { sendEvent(res, eventData); } @@ -73,9 +84,10 @@ function createRunStepDeltaEmitter({ res, stepId, toolCall, streamId = null }) { * @param {ToolCallChunk} params.toolCall - The tool call object containing tool information. * @param {number} [params.index] * @param {string | null} [params.streamId] - The stream ID for resumable mode. + * @returns {() => Promise} */ function createRunStepEmitter({ res, runId, stepId, toolCall, index, streamId = null }) { - return function () { + return async function () { /** @type {import('@librechat/agents').RunStep} */ const data = { runId: runId ?? 
Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, @@ -89,7 +101,7 @@ function createRunStepEmitter({ res, runId, stepId, toolCall, index, streamId = }; const eventData = { event: GraphEvents.ON_RUN_STEP, data }; if (streamId) { - GenerationJobManager.emitChunk(streamId, eventData); + await GenerationJobManager.emitChunk(streamId, eventData); } else { sendEvent(res, eventData); } @@ -137,7 +149,7 @@ function createOAuthEnd({ res, stepId, toolCall, streamId = null }) { }; const eventData = { event: GraphEvents.ON_RUN_STEP_DELTA, data }; if (streamId) { - GenerationJobManager.emitChunk(streamId, eventData); + await GenerationJobManager.emitChunk(streamId, eventData); } else { sendEvent(res, eventData); } @@ -196,6 +208,9 @@ async function reconnectServer({ userMCPAuthMap, streamId = null, }) { + logger.debug( + `[MCP][reconnectServer] serverName: ${serverName}, user: ${user?.id}, hasUserMCPAuthMap: ${!!userMCPAuthMap}`, + ); const runId = Constants.USE_PRELIM_RESPONSE_MESSAGE_ID; const flowId = `${user.id}:${serverName}:${Date.now()}`; const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS)); @@ -428,13 +443,17 @@ function createToolInstance({ /** @type {LCTool} */ const { description, parameters } = toolDefinition; const isGoogle = _provider === Providers.VERTEXAI || _provider === Providers.GOOGLE; - let schema = convertWithResolvedRefs(parameters, { - allowEmptyObject: !isGoogle, - transformOneOfAnyOf: true, - }); - if (!schema) { - schema = z.object({ input: z.string().optional() }); + let schema = parameters ? 
normalizeJsonSchema(resolveJsonSchemaRefs(parameters)) : null; + + if (!schema || (isGoogle && isEmptyObjectSchema(schema))) { + schema = { + type: 'object', + properties: { + input: { type: 'string', description: 'Input for the tool' }, + }, + required: [], + }; } const normalizedToolKey = `${toolName}${Constants.mcp_delimiter}${normalizeServerName(serverName)}`; @@ -501,6 +520,7 @@ function createToolInstance({ }, oauthStart, oauthEnd, + graphTokenResolver: getGraphApiToken, }); if (isAssistantsEndpoint(provider) && Array.isArray(result)) { @@ -548,6 +568,7 @@ function createToolInstance({ }); toolInstance.mcp = true; toolInstance.mcpRawServerName = serverName; + toolInstance.mcpJsonSchema = parameters; return toolInstance; } diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index cb2f0081a3..b2caebc91e 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -9,30 +9,6 @@ jest.mock('@librechat/data-schemas', () => ({ }, })); -jest.mock('@langchain/core/tools', () => ({ - tool: jest.fn((fn, config) => { - const toolInstance = { _call: fn, ...config }; - return toolInstance; - }), -})); - -jest.mock('@librechat/agents', () => ({ - Providers: { - VERTEXAI: 'vertexai', - GOOGLE: 'google', - }, - StepTypes: { - TOOL_CALLS: 'tool_calls', - }, - GraphEvents: { - ON_RUN_STEP_DELTA: 'on_run_step_delta', - ON_RUN_STEP: 'on_run_step', - }, - Constants: { - CONTENT_AND_ARTIFACT: 'content_and_artifact', - }, -})); - // Create mock registry instance const mockRegistryInstance = { getOAuthServers: jest.fn(() => Promise.resolve(new Set())), @@ -46,26 +22,23 @@ const mockIsMCPDomainAllowed = jest.fn(() => Promise.resolve(true)); const mockGetAppConfig = jest.fn(() => Promise.resolve({})); jest.mock('@librechat/api', () => { - // Access mock via getter to avoid hoisting issues + const actual = jest.requireActual('@librechat/api'); return { - MCPOAuthHandler: { - generateFlowId: jest.fn(), - }, + ...actual, sendEvent: 
jest.fn(), - normalizeServerName: jest.fn((name) => name), - convertWithResolvedRefs: jest.fn((params) => params), get isMCPDomainAllowed() { return mockIsMCPDomainAllowed; }, - MCPServersRegistry: { - getInstance: () => mockRegistryInstance, + GenerationJobManager: { + emitChunk: jest.fn(), }, }; }); const { logger } = require('@librechat/data-schemas'); const { MCPOAuthHandler } = require('@librechat/api'); -const { CacheKeys } = require('librechat-data-provider'); +const { CacheKeys, Constants } = require('librechat-data-provider'); +const D = Constants.mcp_delimiter; const { createMCPTool, createMCPTools, @@ -74,24 +47,6 @@ const { getServerConnectionStatus, } = require('./MCP'); -jest.mock('librechat-data-provider', () => ({ - CacheKeys: { - FLOWS: 'flows', - }, - Constants: { - USE_PRELIM_RESPONSE_MESSAGE_ID: 'prelim_response_id', - mcp_delimiter: '::', - mcp_prefix: 'mcp_', - }, - ContentTypes: { - TEXT: 'text', - }, - isAssistantsEndpoint: jest.fn(() => false), - Time: { - TWO_MINUTES: 120000, - }, -})); - jest.mock('./Config', () => ({ loadCustomConfig: jest.fn(), get getAppConfig() { @@ -120,6 +75,10 @@ jest.mock('./Tools/mcp', () => ({ reinitMCPServer: jest.fn(), })); +jest.mock('./GraphTokenService', () => ({ + getGraphApiToken: jest.fn(), +})); + describe('tests for the new helper functions used by the MCP connection status endpoints', () => { let mockGetMCPManager; let mockGetFlowStateManager; @@ -128,6 +87,7 @@ describe('tests for the new helper functions used by the MCP connection status e beforeEach(() => { jest.clearAllMocks(); + jest.spyOn(MCPOAuthHandler, 'generateFlowId'); mockGetMCPManager = require('~/config').getMCPManager; mockGetFlowStateManager = require('~/config').getFlowStateManager; @@ -731,7 +691,7 @@ describe('User parameter passing tests', () => { mockReinitMCPServer.mockResolvedValue({ tools: [{ name: 'test-tool' }], availableTools: { - 'test-tool::test-server': { + [`test-tool${D}test-server`]: { function: { description: 'Test 
tool', parameters: { type: 'object', properties: {} }, @@ -791,7 +751,7 @@ describe('User parameter passing tests', () => { mockReinitMCPServer.mockResolvedValue({ availableTools: { - 'test-tool::test-server': { + [`test-tool${D}test-server`]: { function: { description: 'Test tool', parameters: { type: 'object', properties: {} }, @@ -804,7 +764,7 @@ describe('User parameter passing tests', () => { await createMCPTool({ res: mockRes, user: mockUser, - toolKey: 'test-tool::test-server', + toolKey: `test-tool${D}test-server`, provider: 'openai', signal: mockSignal, userMCPAuthMap: {}, @@ -826,7 +786,7 @@ describe('User parameter passing tests', () => { const mockRes = { write: jest.fn(), flush: jest.fn() }; const availableTools = { - 'test-tool::test-server': { + [`test-tool${D}test-server`]: { function: { description: 'Cached tool', parameters: { type: 'object', properties: {} }, @@ -837,7 +797,7 @@ describe('User parameter passing tests', () => { await createMCPTool({ res: mockRes, user: mockUser, - toolKey: 'test-tool::test-server', + toolKey: `test-tool${D}test-server`, provider: 'openai', userMCPAuthMap: {}, availableTools: availableTools, @@ -860,8 +820,8 @@ describe('User parameter passing tests', () => { return Promise.resolve({ tools: [{ name: 'tool1' }, { name: 'tool2' }], availableTools: { - 'tool1::server1': { function: { description: 'Tool 1', parameters: {} } }, - 'tool2::server1': { function: { description: 'Tool 2', parameters: {} } }, + [`tool1${D}server1`]: { function: { description: 'Tool 1', parameters: {} } }, + [`tool2${D}server1`]: { function: { description: 'Tool 2', parameters: {} } }, }, }); }); @@ -892,7 +852,7 @@ describe('User parameter passing tests', () => { reinitCalls.push(params); return Promise.resolve({ availableTools: { - 'my-tool::my-server': { + [`my-tool${D}my-server`]: { function: { description: 'My Tool', parameters: {} }, }, }, @@ -902,7 +862,7 @@ describe('User parameter passing tests', () => { await createMCPTool({ res: 
mockRes, user: mockUser, - toolKey: 'my-tool::my-server', + toolKey: `my-tool${D}my-server`, provider: 'google', userMCPAuthMap: {}, availableTools: undefined, // Force reinit @@ -936,11 +896,11 @@ describe('User parameter passing tests', () => { const result = await createMCPTool({ res: mockRes, user: mockUser, - toolKey: 'test-tool::test-server', + toolKey: `test-tool${D}test-server`, provider: 'openai', userMCPAuthMap: {}, availableTools: { - 'test-tool::test-server': { + [`test-tool${D}test-server`]: { function: { description: 'Test tool', parameters: { type: 'object', properties: {} }, @@ -983,7 +943,7 @@ describe('User parameter passing tests', () => { mockIsMCPDomainAllowed.mockResolvedValueOnce(true); const availableTools = { - 'test-tool::test-server': { + [`test-tool${D}test-server`]: { function: { description: 'Test tool', parameters: { type: 'object', properties: {} }, @@ -994,7 +954,7 @@ describe('User parameter passing tests', () => { const result = await createMCPTool({ res: mockRes, user: mockUser, - toolKey: 'test-tool::test-server', + toolKey: `test-tool${D}test-server`, provider: 'openai', userMCPAuthMap: {}, availableTools, @@ -1023,7 +983,7 @@ describe('User parameter passing tests', () => { }); const availableTools = { - 'test-tool::test-server': { + [`test-tool${D}test-server`]: { function: { description: 'Test tool', parameters: { type: 'object', properties: {} }, @@ -1034,7 +994,7 @@ describe('User parameter passing tests', () => { const result = await createMCPTool({ res: mockRes, user: mockUser, - toolKey: 'test-tool::test-server', + toolKey: `test-tool${D}test-server`, provider: 'openai', userMCPAuthMap: {}, availableTools, @@ -1100,7 +1060,7 @@ describe('User parameter passing tests', () => { mockIsMCPDomainAllowed.mockResolvedValue(true); const availableTools = { - 'test-tool::test-server': { + [`test-tool${D}test-server`]: { function: { description: 'Test tool', parameters: { type: 'object', properties: {} }, @@ -1112,7 +1072,7 @@ 
describe('User parameter passing tests', () => { await createMCPTool({ res: mockRes, user: adminUser, - toolKey: 'test-tool::test-server', + toolKey: `test-tool${D}test-server`, provider: 'openai', userMCPAuthMap: {}, availableTools, @@ -1126,7 +1086,7 @@ describe('User parameter passing tests', () => { await createMCPTool({ res: mockRes, user: regularUser, - toolKey: 'test-tool::test-server', + toolKey: `test-tool${D}test-server`, provider: 'openai', userMCPAuthMap: {}, availableTools, @@ -1154,7 +1114,7 @@ describe('User parameter passing tests', () => { return Promise.resolve({ tools: [{ name: 'test' }], availableTools: { - 'test::server': { function: { description: 'Test', parameters: {} } }, + [`test${D}server`]: { function: { description: 'Test', parameters: {} } }, }, }); }); diff --git a/api/server/services/PermissionService.js b/api/server/services/PermissionService.js index c35faf7c8d..a843f48f6f 100644 --- a/api/server/services/PermissionService.js +++ b/api/server/services/PermissionService.js @@ -141,7 +141,6 @@ const checkPermission = async ({ userId, role, resourceType, resourceId, require validateResourceType(resourceType); - // Get all principals for the user (user + groups + public) const principals = await getUserPrincipals({ userId, role }); if (principals.length === 0) { @@ -151,7 +150,6 @@ const checkPermission = async ({ userId, role, resourceType, resourceId, require return await hasPermission(principals, resourceType, resourceId, requiredPermission); } catch (error) { logger.error(`[PermissionService.checkPermission] Error: ${error.message}`); - // Re-throw validation errors if (error.message.includes('requiredPermission must be')) { throw error; } @@ -172,12 +170,12 @@ const getEffectivePermissions = async ({ userId, role, resourceType, resourceId try { validateResourceType(resourceType); - // Get all principals for the user (user + groups + public) const principals = await getUserPrincipals({ userId, role }); if (principals.length === 0) 
{ return 0; } + return await getEffectivePermissionsACL(principals, resourceType, resourceId); } catch (error) { logger.error(`[PermissionService.getEffectivePermissions] Error: ${error.message}`); diff --git a/api/server/services/ToolService.js b/api/server/services/ToolService.js index 62d25b23eb..eedb95bd4d 100644 --- a/api/server/services/ToolService.js +++ b/api/server/services/ToolService.js @@ -1,24 +1,43 @@ -const { sleep } = require('@librechat/agents'); const { logger } = require('@librechat/data-schemas'); const { tool: toolFn, DynamicStructuredTool } = require('@langchain/core/tools'); const { + sleep, + EnvVar, + StepTypes, + GraphEvents, + createToolSearch, + Constants: AgentConstants, + createProgrammaticToolCallingTool, +} = require('@librechat/agents'); +const { + sendEvent, getToolkitKey, hasCustomUserVars, getUserMCPAuthMap, + loadToolDefinitions, + GenerationJobManager, isActionDomainAllowed, + buildWebSearchContext, + buildImageToolContext, + buildToolClassification, } = require('@librechat/api'); const { + Time, Tools, + Constants, + CacheKeys, ErrorTypes, ContentTypes, imageGenTools, EModelEndpoint, + EToolResources, actionDelimiter, ImageVisionTool, openapiToFunction, AgentCapabilities, isEphemeralAgentId, validateActionDomain, + actionDomainSeparator, defaultAgentCapabilities, validateAndParseOpenAPISpec, } = require('librechat-data-provider'); @@ -28,14 +47,24 @@ const { loadActionSets, domainParser, } = require('./ActionService'); +const { + getEndpointsConfig, + getMCPServerTools, + getCachedTools, +} = require('~/server/services/Config'); const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process'); -const { getEndpointsConfig, getCachedTools } = require('~/server/services/Config'); +const { primeFiles: primeSearchFiles } = require('~/app/clients/tools/util/fileSearch'); +const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process'); const { manifestToolMap, toolkits } = 
require('~/app/clients/tools/manifest'); const { createOnSearchResults } = require('~/server/services/Tools/search'); +const { loadAuthValues } = require('~/server/services/Tools/credentials'); +const { reinitMCPServer } = require('~/server/services/Tools/mcp'); const { recordUsage } = require('~/server/services/Threads'); const { loadTools } = require('~/app/clients/tools/util'); const { redactMessage } = require('~/config/parsers'); const { findPluginAuthsByKeys } = require('~/models'); +const { getFlowStateManager } = require('~/config'); +const { getLogStores } = require('~/cache'); /** * Processes the required actions by calling the appropriate tools and returning the outputs. * @param {OpenAIClient} client - OpenAI or StreamRunManager Client. @@ -309,6 +338,7 @@ async function processRequiredActions(client, requiredActions) { } // We've already decrypted the metadata, so we can pass it directly + const _allowedDomains = appConfig?.actions?.allowedDomains; tool = await createActionTool({ userId: client.req.user.id, res: client.res, @@ -316,6 +346,7 @@ async function processRequiredActions(client, requiredActions) { requestBuilder, // Note: intentionally not passing zodSchema, name, and description for assistants API encrypted, // Pass the encrypted values for OAuth flow + useSSRFProtection: !Array.isArray(_allowedDomains) || _allowedDomains.length === 0, }); if (!tool) { logger.warn( @@ -367,7 +398,390 @@ async function processRequiredActions(client, requiredActions) { * @param {AbortSignal} params.signal * @param {Pick> }>} The agent tools. + * @returns {Promise<{ + * tools?: StructuredTool[]; + * toolContextMap?: Record; + * userMCPAuthMap?: Record>; + * toolRegistry?: Map; + * hasDeferredTools?: boolean; + * }>} The agent tools and registry. 
+ */ +/** Native LibreChat tools that are not in the manifest */ +const nativeTools = new Set([Tools.execute_code, Tools.file_search, Tools.web_search]); + +/** Checks if a tool name is a known built-in tool */ +const isBuiltInTool = (toolName) => + Boolean( + manifestToolMap[toolName] || + toolkits.some((t) => t.pluginKey === toolName) || + nativeTools.has(toolName), + ); + +/** + * Loads only tool definitions without creating tool instances. + * This is the efficient path for event-driven mode where tools are loaded on-demand. + * + * @param {Object} params + * @param {ServerRequest} params.req - The request object + * @param {ServerResponse} [params.res] - The response object for SSE events + * @param {Object} params.agent - The agent configuration + * @param {string|null} [params.streamId] - Stream ID for resumable mode + * @returns {Promise<{ + * toolDefinitions?: import('@librechat/api').LCTool[]; + * toolRegistry?: Map; + * userMCPAuthMap?: Record>; + * hasDeferredTools?: boolean; + * }>} + */ +async function loadToolDefinitionsWrapper({ req, res, agent, streamId = null, tool_resources }) { + if (!agent.tools || agent.tools.length === 0) { + return { toolDefinitions: [] }; + } + + if ( + agent.tools.length === 1 && + (agent.tools[0] === AgentCapabilities.context || agent.tools[0] === AgentCapabilities.ocr) + ) { + return { toolDefinitions: [] }; + } + + const appConfig = req.config; + const endpointsConfig = await getEndpointsConfig(req); + let enabledCapabilities = new Set(endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? []); + + if (enabledCapabilities.size === 0 && isEphemeralAgentId(agent.id)) { + enabledCapabilities = new Set( + appConfig.endpoints?.[EModelEndpoint.agents]?.capabilities ?? 
defaultAgentCapabilities, + ); + } + + const checkCapability = (capability) => enabledCapabilities.has(capability); + const areToolsEnabled = checkCapability(AgentCapabilities.tools); + const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools); + + const filteredTools = agent.tools?.filter((tool) => { + if (tool === Tools.file_search) { + return checkCapability(AgentCapabilities.file_search); + } + if (tool === Tools.execute_code) { + return checkCapability(AgentCapabilities.execute_code); + } + if (tool === Tools.web_search) { + return checkCapability(AgentCapabilities.web_search); + } + if (!areToolsEnabled && !tool.includes(actionDelimiter)) { + return false; + } + return true; + }); + + if (!filteredTools || filteredTools.length === 0) { + return { toolDefinitions: [] }; + } + + /** @type {Record>} */ + let userMCPAuthMap; + if (hasCustomUserVars(req.config)) { + userMCPAuthMap = await getUserMCPAuthMap({ + tools: agent.tools, + userId: req.user.id, + findPluginAuthsByKeys, + }); + } + + const flowsCache = getLogStores(CacheKeys.FLOWS); + const flowManager = getFlowStateManager(flowsCache); + const pendingOAuthServers = new Set(); + + const createOAuthEmitter = (serverName) => { + return async (authURL) => { + const flowId = `${req.user.id}:${serverName}:${Date.now()}`; + const stepId = 'step_oauth_login_' + serverName; + const toolCall = { + id: flowId, + name: serverName, + type: 'tool_call_chunk', + }; + + const runStepData = { + runId: Constants.USE_PRELIM_RESPONSE_MESSAGE_ID, + id: stepId, + type: StepTypes.TOOL_CALLS, + index: 0, + stepDetails: { + type: StepTypes.TOOL_CALLS, + tool_calls: [toolCall], + }, + }; + + const runStepDeltaData = { + id: stepId, + delta: { + type: StepTypes.TOOL_CALLS, + tool_calls: [{ ...toolCall, args: '' }], + auth: authURL, + expires_at: Date.now() + Time.TWO_MINUTES, + }, + }; + + const runStepEvent = { event: GraphEvents.ON_RUN_STEP, data: runStepData }; + const runStepDeltaEvent = { event: 
GraphEvents.ON_RUN_STEP_DELTA, data: runStepDeltaData }; + + if (streamId) { + await GenerationJobManager.emitChunk(streamId, runStepEvent); + await GenerationJobManager.emitChunk(streamId, runStepDeltaEvent); + } else if (res && !res.writableEnded) { + sendEvent(res, runStepEvent); + sendEvent(res, runStepDeltaEvent); + } else { + logger.warn( + `[Tool Definitions] Cannot emit OAuth event for ${serverName}: no streamId and res not available`, + ); + } + }; + }; + + const getOrFetchMCPServerTools = async (userId, serverName) => { + const cached = await getMCPServerTools(userId, serverName); + if (cached) { + return cached; + } + + const oauthStart = async () => { + pendingOAuthServers.add(serverName); + }; + + const result = await reinitMCPServer({ + user: req.user, + oauthStart, + flowManager, + serverName, + userMCPAuthMap, + }); + + return result?.availableTools || null; + }; + + const getActionToolDefinitions = async (agentId, actionToolNames) => { + const actionSets = (await loadActionSets({ agent_id: agentId })) ?? []; + if (actionSets.length === 0) { + return []; + } + + const definitions = []; + const allowedDomains = appConfig?.actions?.allowedDomains; + const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + + for (const action of actionSets) { + const domain = await domainParser(action.metadata.domain, true); + const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain, allowedDomains); + if (!isDomainAllowed) { + logger.warn( + `[Actions] Domain "${action.metadata.domain}" not in allowedDomains. 
` + + `Add it to librechat.yaml actions.allowedDomains to enable this action.`, + ); + continue; + } + + const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec); + if (!validationResult.spec || !validationResult.serverUrl) { + logger.warn(`[Actions] Invalid OpenAPI spec for domain: ${domain}`); + continue; + } + + const { functionSignatures } = openapiToFunction(validationResult.spec, true); + + for (const sig of functionSignatures) { + const toolName = `${sig.name}${actionDelimiter}${normalizedDomain}`; + if (!actionToolNames.some((name) => name.replace(domainSeparatorRegex, '_') === toolName)) { + continue; + } + + definitions.push({ + name: toolName, + description: sig.description, + parameters: sig.parameters, + }); + } + } + + return definitions; + }; + + let { toolDefinitions, toolRegistry, hasDeferredTools } = await loadToolDefinitions( + { + userId: req.user.id, + agentId: agent.id, + tools: filteredTools, + toolOptions: agent.tool_options, + deferredToolsEnabled, + }, + { + isBuiltInTool, + loadAuthValues, + getOrFetchMCPServerTools, + getActionToolDefinitions, + }, + ); + + if (pendingOAuthServers.size > 0 && (res || streamId)) { + const serverNames = Array.from(pendingOAuthServers); + logger.info( + `[Tool Definitions] OAuth required for ${serverNames.length} server(s): ${serverNames.join(', ')}. 
Emitting events and waiting.`, + ); + + const oauthWaitPromises = serverNames.map(async (serverName) => { + try { + const result = await reinitMCPServer({ + user: req.user, + serverName, + userMCPAuthMap, + flowManager, + returnOnOAuth: false, + oauthStart: createOAuthEmitter(serverName), + connectionTimeout: Time.TWO_MINUTES, + }); + + if (result?.availableTools) { + logger.info(`[Tool Definitions] OAuth completed for ${serverName}, tools available`); + return { serverName, success: true }; + } + return { serverName, success: false }; + } catch (error) { + logger.debug(`[Tool Definitions] OAuth wait failed for ${serverName}:`, error?.message); + return { serverName, success: false }; + } + }); + + const results = await Promise.allSettled(oauthWaitPromises); + const successfulServers = results + .filter((r) => r.status === 'fulfilled' && r.value.success) + .map((r) => r.value.serverName); + + if (successfulServers.length > 0) { + logger.info( + `[Tool Definitions] Reloading tools after OAuth for: ${successfulServers.join(', ')}`, + ); + const reloadResult = await loadToolDefinitions( + { + userId: req.user.id, + agentId: agent.id, + tools: filteredTools, + toolOptions: agent.tool_options, + deferredToolsEnabled, + }, + { + isBuiltInTool, + loadAuthValues, + getOrFetchMCPServerTools, + getActionToolDefinitions, + }, + ); + toolDefinitions = reloadResult.toolDefinitions; + toolRegistry = reloadResult.toolRegistry; + hasDeferredTools = reloadResult.hasDeferredTools; + } + } + + /** @type {Record} */ + const toolContextMap = {}; + const hasWebSearch = filteredTools.includes(Tools.web_search); + const hasFileSearch = filteredTools.includes(Tools.file_search); + const hasExecuteCode = filteredTools.includes(Tools.execute_code); + + if (hasWebSearch) { + toolContextMap[Tools.web_search] = buildWebSearchContext(); + } + + if (hasExecuteCode && tool_resources) { + try { + const authValues = await loadAuthValues({ + userId: req.user.id, + authFields: [EnvVar.CODE_API_KEY], + 
}); + const codeApiKey = authValues[EnvVar.CODE_API_KEY]; + + if (codeApiKey) { + const { toolContext } = await primeCodeFiles( + { req, tool_resources, agentId: agent.id }, + codeApiKey, + ); + if (toolContext) { + toolContextMap[Tools.execute_code] = toolContext; + } + } + } catch (error) { + logger.error('[loadToolDefinitionsWrapper] Error priming code files:', error); + } + } + + if (hasFileSearch && tool_resources) { + try { + const { toolContext } = await primeSearchFiles({ + req, + tool_resources, + agentId: agent.id, + }); + if (toolContext) { + toolContextMap[Tools.file_search] = toolContext; + } + } catch (error) { + logger.error('[loadToolDefinitionsWrapper] Error priming search files:', error); + } + } + + const imageFiles = tool_resources?.[EToolResources.image_edit]?.files ?? []; + if (imageFiles.length > 0) { + const hasOaiImageGen = filteredTools.includes('image_gen_oai'); + const hasGeminiImageGen = filteredTools.includes('gemini_image_gen'); + + if (hasOaiImageGen) { + const toolContext = buildImageToolContext({ + imageFiles, + toolName: `${EToolResources.image_edit}_oai`, + contextDescription: 'image editing', + }); + if (toolContext) { + toolContextMap.image_edit_oai = toolContext; + } + } + + if (hasGeminiImageGen) { + const toolContext = buildImageToolContext({ + imageFiles, + toolName: 'gemini_image_gen', + contextDescription: 'image context', + }); + if (toolContext) { + toolContextMap.gemini_image_gen = toolContext; + } + } + } + + return { + toolRegistry, + userMCPAuthMap, + toolContextMap, + toolDefinitions, + hasDeferredTools, + }; +} + +/** + * Loads agent tools for initialization or execution. 
+ * @param {Object} params + * @param {ServerRequest} params.req - The request object + * @param {ServerResponse} params.res - The response object + * @param {Object} params.agent - The agent configuration + * @param {AbortSignal} [params.signal] - Abort signal + * @param {Object} [params.tool_resources] - Tool resources + * @param {string} [params.openAIApiKey] - OpenAI API key + * @param {string|null} [params.streamId] - Stream ID for resumable mode + * @param {boolean} [params.definitionsOnly=true] - When true, returns only serializable + * tool definitions without creating full tool instances. Use for event-driven mode + * where tools are loaded on-demand during execution. */ async function loadAgentTools({ req, @@ -377,16 +791,21 @@ async function loadAgentTools({ tool_resources, openAIApiKey, streamId = null, + definitionsOnly = true, }) { + if (definitionsOnly) { + return loadToolDefinitionsWrapper({ req, res, agent, streamId, tool_resources }); + } + if (!agent.tools || agent.tools.length === 0) { - return {}; + return { toolDefinitions: [] }; } else if ( agent.tools && agent.tools.length === 1 && /** Legacy handling for `ocr` as may still exist in existing Agents */ (agent.tools[0] === AgentCapabilities.context || agent.tools[0] === AgentCapabilities.ocr) ) { - return {}; + return { toolDefinitions: [] }; } const appConfig = req.config; @@ -401,8 +820,14 @@ async function loadAgentTools({ const checkCapability = (capability) => { const enabled = enabledCapabilities.has(capability); if (!enabled) { + const isToolCapability = [ + AgentCapabilities.file_search, + AgentCapabilities.execute_code, + AgentCapabilities.web_search, + ].includes(capability); + const suffix = isToolCapability ? ' despite configured tool.' : '.'; logger.warn( - `Capability "${capability}" disabled${capability === AgentCapabilities.tools ? '.' 
: ' despite configured tool.'} User: ${req.user.id} | Agent: ${agent.id}`, + `Capability "${capability}" disabled${suffix} User: ${req.user.id} | Agent: ${agent.id}`, ); } return enabled; @@ -466,6 +891,18 @@ async function loadAgentTools({ imageOutputType: appConfig.imageOutputType, }); + /** Build tool registry from MCP tools and create PTC/tool search tools if configured */ + const deferredToolsEnabled = checkCapability(AgentCapabilities.deferred_tools); + const { toolRegistry, toolDefinitions, additionalTools, hasDeferredTools } = + await buildToolClassification({ + loadedTools, + userId: req.user.id, + agentId: agent.id, + agentToolOptions: agent.tool_options, + deferredToolsEnabled, + loadAuthValues, + }); + const agentTools = []; for (let i = 0; i < loadedTools.length; i++) { const tool = loadedTools[i]; @@ -510,11 +947,16 @@ async function loadAgentTools({ return map; }, {}); + agentTools.push(...additionalTools); + if (!checkCapability(AgentCapabilities.actions)) { return { - tools: agentTools, + toolRegistry, userMCPAuthMap, toolContextMap, + toolDefinitions, + hasDeferredTools, + tools: agentTools, }; } @@ -524,9 +966,12 @@ async function loadAgentTools({ logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`); } return { - tools: agentTools, + toolRegistry, userMCPAuthMap, toolContextMap, + toolDefinitions, + hasDeferredTools, + tools: agentTools, }; } @@ -621,6 +1066,7 @@ async function loadAgentTools({ const zodSchema = zodSchemas[functionName]; if (requestBuilder) { + const _allowedDomains = appConfig?.actions?.allowedDomains; const tool = await createActionTool({ userId: req.user.id, res, @@ -631,6 +1077,7 @@ async function loadAgentTools({ name: toolName, description: functionSig.description, streamId, + useSSRFProtection: !Array.isArray(_allowedDomains) || _allowedDomains.length === 0, }); if (!tool) { @@ -651,14 +1098,303 @@ async function loadAgentTools({ } return { - tools: agentTools, + toolRegistry, 
toolContextMap, userMCPAuthMap, + toolDefinitions, + hasDeferredTools, + tools: agentTools, }; } +/** + * Loads tools for event-driven execution (ON_TOOL_EXECUTE handler). + * This function encapsulates all dependencies needed for tool loading, + * so callers don't need to import processFileURL, uploadImageBuffer, etc. + * + * Handles both regular tools (MCP, built-in) and action tools. + * + * @param {Object} params + * @param {ServerRequest} params.req - The request object + * @param {ServerResponse} params.res - The response object + * @param {AbortSignal} [params.signal] - Abort signal + * @param {Object} params.agent - The agent object + * @param {string[]} params.toolNames - Names of tools to load + * @param {Record>} [params.userMCPAuthMap] - User MCP auth map + * @param {Object} [params.tool_resources] - Tool resources + * @param {string|null} [params.streamId] - Stream ID for web search callbacks + * @returns {Promise<{ loadedTools: Array, configurable: Object }>} + */ +async function loadToolsForExecution({ + req, + res, + signal, + agent, + toolNames, + toolRegistry, + userMCPAuthMap, + tool_resources, + streamId = null, +}) { + const appConfig = req.config; + const allLoadedTools = []; + const configurable = { userMCPAuthMap }; + + const isToolSearch = toolNames.includes(AgentConstants.TOOL_SEARCH); + const isPTC = toolNames.includes(AgentConstants.PROGRAMMATIC_TOOL_CALLING); + + logger.debug( + `[loadToolsForExecution] isToolSearch: ${isToolSearch}, toolRegistry: ${toolRegistry?.size ?? 
'undefined'}`, + ); + + if (isToolSearch && toolRegistry) { + const toolSearchTool = createToolSearch({ + mode: 'local', + toolRegistry, + }); + allLoadedTools.push(toolSearchTool); + configurable.toolRegistry = toolRegistry; + } + + if (isPTC && toolRegistry) { + configurable.toolRegistry = toolRegistry; + try { + const authValues = await loadAuthValues({ + userId: req.user.id, + authFields: [EnvVar.CODE_API_KEY], + }); + const codeApiKey = authValues[EnvVar.CODE_API_KEY]; + + if (codeApiKey) { + const ptcTool = createProgrammaticToolCallingTool({ apiKey: codeApiKey }); + allLoadedTools.push(ptcTool); + } else { + logger.warn('[loadToolsForExecution] PTC requested but CODE_API_KEY not available'); + } + } catch (error) { + logger.error('[loadToolsForExecution] Error creating PTC tool:', error); + } + } + + const specialToolNames = new Set([ + AgentConstants.TOOL_SEARCH, + AgentConstants.PROGRAMMATIC_TOOL_CALLING, + ]); + + let ptcOrchestratedToolNames = []; + if (isPTC && toolRegistry) { + ptcOrchestratedToolNames = Array.from(toolRegistry.keys()).filter( + (name) => !specialToolNames.has(name), + ); + } + + const requestedNonSpecialToolNames = toolNames.filter((name) => !specialToolNames.has(name)); + const allToolNamesToLoad = isPTC + ? [...new Set([...requestedNonSpecialToolNames, ...ptcOrchestratedToolNames])] + : requestedNonSpecialToolNames; + + const actionToolNames = allToolNamesToLoad.filter((name) => name.includes(actionDelimiter)); + const regularToolNames = allToolNamesToLoad.filter((name) => !name.includes(actionDelimiter)); + + /** @type {Record} */ + if (regularToolNames.length > 0) { + const includesWebSearch = regularToolNames.includes(Tools.web_search); + const webSearchCallbacks = includesWebSearch ? 
createOnSearchResults(res, streamId) : undefined; + + const { loadedTools } = await loadTools({ + agent, + signal, + userMCPAuthMap, + functions: true, + tools: regularToolNames, + user: req.user.id, + options: { + req, + res, + tool_resources, + processFileURL, + uploadImageBuffer, + returnMetadata: true, + [Tools.web_search]: webSearchCallbacks, + }, + webSearch: appConfig?.webSearch, + fileStrategy: appConfig?.fileStrategy, + imageOutputType: appConfig?.imageOutputType, + }); + + if (loadedTools) { + allLoadedTools.push(...loadedTools); + } + } + + if (actionToolNames.length > 0 && agent) { + const actionTools = await loadActionToolsForExecution({ + req, + res, + agent, + appConfig, + streamId, + actionToolNames, + }); + allLoadedTools.push(...actionTools); + } + + if (isPTC && allLoadedTools.length > 0) { + const ptcToolMap = new Map(); + for (const tool of allLoadedTools) { + if (tool.name && tool.name !== AgentConstants.PROGRAMMATIC_TOOL_CALLING) { + ptcToolMap.set(tool.name, tool); + } + } + configurable.ptcToolMap = ptcToolMap; + } + + return { + configurable, + loadedTools: allLoadedTools, + }; +} + +/** + * Loads action tools for event-driven execution. + * @param {Object} params + * @param {ServerRequest} params.req - The request object + * @param {ServerResponse} params.res - The response object + * @param {Object} params.agent - The agent object + * @param {Object} params.appConfig - App configuration + * @param {string|null} params.streamId - Stream ID + * @param {string[]} params.actionToolNames - Action tool names to load + * @returns {Promise} Loaded action tools + */ +async function loadActionToolsForExecution({ + req, + res, + agent, + appConfig, + streamId, + actionToolNames, +}) { + const loadedActionTools = []; + + const actionSets = (await loadActionSets({ agent_id: agent.id })) ?? 
[]; + if (actionSets.length === 0) { + return loadedActionTools; + } + + const processedActionSets = new Map(); + const domainMap = new Map(); + const allowedDomains = appConfig?.actions?.allowedDomains; + + for (const action of actionSets) { + const domain = await domainParser(action.metadata.domain, true); + domainMap.set(domain, action); + + const isDomainAllowed = await isActionDomainAllowed(action.metadata.domain, allowedDomains); + if (!isDomainAllowed) { + logger.warn( + `[Actions] Domain "${action.metadata.domain}" not in allowedDomains. ` + + `Add it to librechat.yaml actions.allowedDomains to enable this action.`, + ); + continue; + } + + const validationResult = validateAndParseOpenAPISpec(action.metadata.raw_spec); + if (!validationResult.spec || !validationResult.serverUrl) { + logger.warn(`[Actions] Invalid OpenAPI spec for domain: ${domain}`); + continue; + } + + const domainValidation = validateActionDomain( + action.metadata.domain, + validationResult.serverUrl, + ); + if (!domainValidation.isValid) { + logger.error(`Domain mismatch in stored action: ${domainValidation.message}`, { + userId: req.user.id, + agent_id: agent.id, + action_id: action.action_id, + }); + continue; + } + + const encrypted = { + oauth_client_id: action.metadata.oauth_client_id, + oauth_client_secret: action.metadata.oauth_client_secret, + }; + + const decryptedAction = { ...action }; + decryptedAction.metadata = await decryptMetadata(action.metadata); + + const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction( + validationResult.spec, + true, + ); + + processedActionSets.set(domain, { + action: decryptedAction, + requestBuilders, + functionSignatures, + zodSchemas, + encrypted, + }); + } + + const domainSeparatorRegex = new RegExp(actionDomainSeparator, 'g'); + for (const toolName of actionToolNames) { + let currentDomain = ''; + for (const domain of domainMap.keys()) { + const normalizedDomain = domain.replace(domainSeparatorRegex, '_'); + if 
(toolName.includes(normalizedDomain)) { + currentDomain = domain; + break; + } + } + + if (!currentDomain || !processedActionSets.has(currentDomain)) { + continue; + } + + const { action, encrypted, zodSchemas, requestBuilders, functionSignatures } = + processedActionSets.get(currentDomain); + const normalizedDomain = currentDomain.replace(domainSeparatorRegex, '_'); + const functionName = toolName.replace(`${actionDelimiter}${normalizedDomain}`, ''); + const functionSig = functionSignatures.find((sig) => sig.name === functionName); + const requestBuilder = requestBuilders[functionName]; + const zodSchema = zodSchemas[functionName]; + + if (!requestBuilder) { + continue; + } + + const tool = await createActionTool({ + userId: req.user.id, + res, + action, + streamId, + zodSchema, + encrypted, + requestBuilder, + name: toolName, + description: functionSig?.description ?? '', + useSSRFProtection: !Array.isArray(allowedDomains) || allowedDomains.length === 0, + }); + + if (!tool) { + logger.warn(`[Actions] Failed to create action tool: ${toolName}`); + continue; + } + + loadedActionTools.push(tool); + } + + return loadedActionTools; +} + module.exports = { + loadTools, + isBuiltInTool, getToolkitKey, loadAgentTools, + loadToolsForExecution, processRequiredActions, }; diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index 33e67c8238..10f2d71a18 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -1,11 +1,14 @@ const { logger } = require('@librechat/data-schemas'); const { CacheKeys, Constants } = require('librechat-data-provider'); const { findToken, createToken, updateToken, deleteTokens } = require('~/models'); -const { getMCPManager, getFlowStateManager } = require('~/config'); const { updateMCPServerTools } = require('~/server/services/Config'); +const { getMCPManager, getFlowStateManager } = require('~/config'); const { getLogStores } = require('~/cache'); /** + * Reinitializes an MCP server 
connection and discovers available tools. + * When OAuth is required, uses discovery mode to list tools without full authentication + * (per MCP spec, tool listing should be possible without auth). * @param {Object} params * @param {IUser} params.user - The user from the request object. * @param {string} params.serverName - The name of the MCP server @@ -14,7 +17,7 @@ const { getLogStores } = require('~/cache'); * @param {boolean} [params.forceNew] * @param {number} [params.connectionTimeout] * @param {FlowStateManager} [params.flowManager] - * @param {(authURL: string) => Promise} [params.oauthStart] + * @param {(authURL: string) => Promise} [params.oauthStart] * @param {Record>} [params.userMCPAuthMap] */ async function reinitMCPServer({ @@ -36,10 +39,12 @@ async function reinitMCPServer({ let tools = null; let oauthRequired = false; let oauthUrl = null; + try { const customUserVars = userMCPAuthMap?.[`${Constants.mcp_prefix}${serverName}`]; const flowManager = _flowManager ?? getFlowStateManager(getLogStores(CacheKeys.FLOWS)); const mcpManager = getMCPManager(); + const tokenMethods = { findToken, updateToken, createToken, deleteTokens }; const oauthStart = _oauthStart ?? 
@@ -57,15 +62,10 @@ async function reinitMCPServer({ oauthStart, serverName, flowManager, + tokenMethods, returnOnOAuth, customUserVars, connectionTimeout, - tokenMethods: { - findToken, - updateToken, - createToken, - deleteTokens, - }, }); logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`); @@ -84,9 +84,33 @@ async function reinitMCPServer({ if (isOAuthError || oauthRequired || isOAuthFlowInitiated) { logger.info( - `[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`, + `[MCP Reinitialize] OAuth required for ${serverName}, attempting tool discovery without auth`, ); oauthRequired = true; + + try { + const discoveryResult = await mcpManager.discoverServerTools({ + user, + signal, + serverName, + flowManager, + tokenMethods, + oauthStart, + customUserVars, + connectionTimeout, + }); + + if (discoveryResult.tools && discoveryResult.tools.length > 0) { + tools = discoveryResult.tools; + logger.info( + `[MCP Reinitialize] Discovered ${tools.length} tools for ${serverName} without full auth`, + ); + } + } catch (discoveryErr) { + logger.debug( + `[MCP Reinitialize] Tool discovery failed for ${serverName}: ${discoveryErr?.message ?? 
String(discoveryErr)}`, + ); + } } else { logger.error( `[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`, @@ -97,6 +121,9 @@ async function reinitMCPServer({ if (connection && !oauthRequired) { tools = await connection.fetchTools(); + } + + if (tools && tools.length > 0) { availableTools = await updateMCPServerTools({ userId: user.id, serverName, @@ -109,6 +136,9 @@ async function reinitMCPServer({ ); const getResponseMessage = () => { + if (oauthRequired && tools && tools.length > 0) { + return `MCP server '${serverName}' tools discovered, OAuth required for execution`; + } if (oauthRequired) { return `MCP server '${serverName}' ready for OAuth authentication`; } @@ -120,19 +150,25 @@ async function reinitMCPServer({ const result = { availableTools, - success: Boolean((connection && !oauthRequired) || (oauthRequired && oauthUrl)), + success: Boolean( + (connection && !oauthRequired) || + (oauthRequired && oauthUrl) || + (tools && tools.length > 0), + ), message: getResponseMessage(), oauthRequired, serverName, oauthUrl, tools, }; + logger.debug(`[MCP Reinitialize] Response for ${serverName}:`, { success: result.success, oauthRequired: result.oauthRequired, oauthUrl: result.oauthUrl ? 'present' : null, toolsCount: tools?.length ?? 0, }); + return result; } catch (error) { logger.error( diff --git a/api/server/services/__tests__/ToolService.spec.js b/api/server/services/__tests__/ToolService.spec.js new file mode 100644 index 0000000000..2f00bbc3d6 --- /dev/null +++ b/api/server/services/__tests__/ToolService.spec.js @@ -0,0 +1,149 @@ +const { AgentCapabilities, defaultAgentCapabilities } = require('librechat-data-provider'); + +/** + * Tests for ToolService capability checking logic. + * The actual loadAgentTools function has many dependencies, so we test + * the capability checking logic in isolation. 
+ */ +describe('ToolService - Capability Checking', () => { + describe('checkCapability logic', () => { + /** + * Simulates the checkCapability function from loadAgentTools + */ + const createCheckCapability = (enabledCapabilities, logger = { warn: jest.fn() }) => { + return (capability) => { + const enabled = enabledCapabilities.has(capability); + if (!enabled) { + const isToolCapability = [ + AgentCapabilities.file_search, + AgentCapabilities.execute_code, + AgentCapabilities.web_search, + ].includes(capability); + const suffix = isToolCapability ? ' despite configured tool.' : '.'; + logger.warn(`Capability "${capability}" disabled${suffix}`); + } + return enabled; + }; + }; + + it('should return true when capability is enabled', () => { + const enabledCapabilities = new Set([AgentCapabilities.deferred_tools]); + const checkCapability = createCheckCapability(enabledCapabilities); + + expect(checkCapability(AgentCapabilities.deferred_tools)).toBe(true); + }); + + it('should return false when capability is not enabled', () => { + const enabledCapabilities = new Set([]); + const checkCapability = createCheckCapability(enabledCapabilities); + + expect(checkCapability(AgentCapabilities.deferred_tools)).toBe(false); + }); + + it('should log warning with "despite configured tool" for tool capabilities', () => { + const logger = { warn: jest.fn() }; + const enabledCapabilities = new Set([]); + const checkCapability = createCheckCapability(enabledCapabilities, logger); + + checkCapability(AgentCapabilities.file_search); + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('despite configured tool')); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.execute_code); + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('despite configured tool')); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.web_search); + expect(logger.warn).toHaveBeenCalledWith(expect.stringContaining('despite configured tool')); + }); + + 
it('should log warning without "despite configured tool" for non-tool capabilities', () => { + const logger = { warn: jest.fn() }; + const enabledCapabilities = new Set([]); + const checkCapability = createCheckCapability(enabledCapabilities, logger); + + checkCapability(AgentCapabilities.deferred_tools); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Capability "deferred_tools" disabled.'), + ); + expect(logger.warn).not.toHaveBeenCalledWith( + expect.stringContaining('despite configured tool'), + ); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.tools); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Capability "tools" disabled.'), + ); + expect(logger.warn).not.toHaveBeenCalledWith( + expect.stringContaining('despite configured tool'), + ); + + logger.warn.mockClear(); + checkCapability(AgentCapabilities.actions); + expect(logger.warn).toHaveBeenCalledWith( + expect.stringContaining('Capability "actions" disabled.'), + ); + }); + + it('should not log warning when capability is enabled', () => { + const logger = { warn: jest.fn() }; + const enabledCapabilities = new Set([ + AgentCapabilities.deferred_tools, + AgentCapabilities.file_search, + ]); + const checkCapability = createCheckCapability(enabledCapabilities, logger); + + checkCapability(AgentCapabilities.deferred_tools); + checkCapability(AgentCapabilities.file_search); + + expect(logger.warn).not.toHaveBeenCalled(); + }); + }); + + describe('defaultAgentCapabilities', () => { + it('should include deferred_tools capability by default', () => { + expect(defaultAgentCapabilities).toContain(AgentCapabilities.deferred_tools); + }); + + it('should include all expected default capabilities', () => { + expect(defaultAgentCapabilities).toContain(AgentCapabilities.execute_code); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.file_search); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.web_search); + 
expect(defaultAgentCapabilities).toContain(AgentCapabilities.artifacts); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.actions); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.context); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.tools); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.chain); + expect(defaultAgentCapabilities).toContain(AgentCapabilities.ocr); + }); + }); + + describe('deferredToolsEnabled integration', () => { + it('should correctly determine deferredToolsEnabled from capabilities set', () => { + const createCheckCapability = (enabledCapabilities) => { + return (capability) => enabledCapabilities.has(capability); + }; + + // When deferred_tools is in capabilities + const withDeferred = new Set([AgentCapabilities.deferred_tools, AgentCapabilities.tools]); + const checkWithDeferred = createCheckCapability(withDeferred); + expect(checkWithDeferred(AgentCapabilities.deferred_tools)).toBe(true); + + // When deferred_tools is NOT in capabilities + const withoutDeferred = new Set([AgentCapabilities.tools, AgentCapabilities.actions]); + const checkWithoutDeferred = createCheckCapability(withoutDeferred); + expect(checkWithoutDeferred(AgentCapabilities.deferred_tools)).toBe(false); + }); + + it('should use defaultAgentCapabilities when no capabilities configured', () => { + // Simulates the fallback behavior in loadAgentTools + const endpointsConfig = {}; // No capabilities configured + const enabledCapabilities = new Set( + endpointsConfig?.capabilities ?? 
defaultAgentCapabilities, + ); + + expect(enabledCapabilities.has(AgentCapabilities.deferred_tools)).toBe(true); + }); + }); +}); diff --git a/api/server/services/start/tools.js b/api/server/services/start/tools.js index dd2d69b274..8dc8475f7f 100644 --- a/api/server/services/start/tools.js +++ b/api/server/services/start/tools.js @@ -107,22 +107,33 @@ function loadAndFormatTools({ directory, adminFilter = [], adminIncluded = [] }) }, {}); } +/** + * Checks if a schema is a Zod schema by looking for the _def property + * @param {unknown} schema - The schema to check + * @returns {boolean} True if it's a Zod schema + */ +function isZodSchema(schema) { + return schema && typeof schema === 'object' && '_def' in schema; +} + /** * Formats a `StructuredTool` instance into a format that is compatible * with OpenAI's ChatCompletionFunctions. It uses the `zodToJsonSchema` * function to convert the schema of the `StructuredTool` into a JSON * schema, which is then used as the parameters for the OpenAI function. + * If the schema is already a JSON schema, it is used directly. * * @param {StructuredTool} tool - The StructuredTool to format. * @returns {FunctionTool} The OpenAI Assistant Tool. */ function formatToOpenAIAssistantTool(tool) { + const parameters = isZodSchema(tool.schema) ? 
zodToJsonSchema(tool.schema) : tool.schema; return { type: Tools.function, [Tools.function]: { name: tool.name, description: tool.description, - parameters: zodToJsonSchema(tool.schema), + parameters, }, }; } diff --git a/api/server/socialLogins.js b/api/server/socialLogins.js index 0a89313ba9..a84c33bd52 100644 --- a/api/server/socialLogins.js +++ b/api/server/socialLogins.js @@ -1,8 +1,8 @@ const passport = require('passport'); const session = require('express-session'); -const { isEnabled } = require('@librechat/api'); -const { logger } = require('@librechat/data-schemas'); const { CacheKeys } = require('librechat-data-provider'); +const { isEnabled, shouldUseSecureCookie } = require('@librechat/api'); +const { logger, DEFAULT_SESSION_EXPIRY } = require('@librechat/data-schemas'); const { openIdJwtLogin, facebookLogin, @@ -22,11 +22,16 @@ const { getLogStores } = require('~/cache'); */ async function configureOpenId(app) { logger.info('Configuring OpenID Connect...'); + const sessionExpiry = Number(process.env.SESSION_EXPIRY) || DEFAULT_SESSION_EXPIRY; const sessionOptions = { secret: process.env.OPENID_SESSION_SECRET, resave: false, saveUninitialized: false, store: getLogStores(CacheKeys.OPENID_SESSION), + cookie: { + maxAge: sessionExpiry, + secure: shouldUseSecureCookie(), + }, }; app.use(session(sessionOptions)); app.use(passport.session()); @@ -82,11 +87,16 @@ const configureSocialLogins = async (app) => { process.env.SAML_SESSION_SECRET ) { logger.info('Configuring SAML Connect...'); + const sessionExpiry = Number(process.env.SESSION_EXPIRY) || DEFAULT_SESSION_EXPIRY; const sessionOptions = { secret: process.env.SAML_SESSION_SECRET, resave: false, saveUninitialized: false, store: getLogStores(CacheKeys.SAML_SESSION), + cookie: { + maxAge: sessionExpiry, + secure: shouldUseSecureCookie(), + }, }; app.use(session(sessionOptions)); app.use(passport.session()); diff --git a/api/strategies/index.js b/api/strategies/index.js index 725e04224a..b4f7bd3cac 100644 
--- a/api/strategies/index.js +++ b/api/strategies/index.js @@ -1,14 +1,14 @@ -const appleLogin = require('./appleStrategy'); +const { setupOpenId, getOpenIdConfig } = require('./openidStrategy'); +const openIdJwtLogin = require('./openIdJwtStrategy'); +const facebookLogin = require('./facebookStrategy'); +const discordLogin = require('./discordStrategy'); const passportLogin = require('./localStrategy'); const googleLogin = require('./googleStrategy'); const githubLogin = require('./githubStrategy'); -const discordLogin = require('./discordStrategy'); -const facebookLogin = require('./facebookStrategy'); -const { setupOpenId, getOpenIdConfig } = require('./openidStrategy'); -const jwtLogin = require('./jwtStrategy'); -const ldapLogin = require('./ldapStrategy'); const { setupSaml } = require('./samlStrategy'); -const openIdJwtLogin = require('./openIdJwtStrategy'); +const appleLogin = require('./appleStrategy'); +const ldapLogin = require('./ldapStrategy'); +const jwtLogin = require('./jwtStrategy'); module.exports = { appleLogin, diff --git a/api/strategies/openIdJwtStrategy.js b/api/strategies/openIdJwtStrategy.js index df318ca30e..997dcec397 100644 --- a/api/strategies/openIdJwtStrategy.js +++ b/api/strategies/openIdJwtStrategy.js @@ -84,19 +84,21 @@ const openIdJwtLogin = (openIdConfig) => { /** Read tokens from session (server-side) to avoid large cookie issues */ const sessionTokens = req.session?.openidTokens; let accessToken = sessionTokens?.accessToken; + let idToken = sessionTokens?.idToken; let refreshToken = sessionTokens?.refreshToken; /** Fallback to cookies for backward compatibility */ - if (!accessToken || !refreshToken) { + if (!accessToken || !refreshToken || !idToken) { const cookieHeader = req.headers.cookie; const parsedCookies = cookieHeader ? 
cookies.parse(cookieHeader) : {}; accessToken = accessToken || parsedCookies.openid_access_token; + idToken = idToken || parsedCookies.openid_id_token; refreshToken = refreshToken || parsedCookies.refreshToken; } user.federatedTokens = { access_token: accessToken || rawToken, - id_token: rawToken, + id_token: idToken, refresh_token: refreshToken, expires_at: payload.exp, }; diff --git a/api/strategies/openIdJwtStrategy.spec.js b/api/strategies/openIdJwtStrategy.spec.js new file mode 100644 index 0000000000..566afe5a90 --- /dev/null +++ b/api/strategies/openIdJwtStrategy.spec.js @@ -0,0 +1,183 @@ +const { SystemRoles } = require('librechat-data-provider'); + +// --- Capture the verify callback from JwtStrategy --- +let capturedVerifyCallback; +jest.mock('passport-jwt', () => ({ + Strategy: jest.fn((_opts, verifyCallback) => { + capturedVerifyCallback = verifyCallback; + return { name: 'jwt' }; + }), + ExtractJwt: { + fromAuthHeaderAsBearerToken: jest.fn(() => 'mock-extractor'), + }, +})); +jest.mock('jwks-rsa', () => ({ + passportJwtSecret: jest.fn(() => 'mock-secret-provider'), +})); +jest.mock('https-proxy-agent', () => ({ + HttpsProxyAgent: jest.fn(), +})); +jest.mock('@librechat/data-schemas', () => ({ + logger: { info: jest.fn(), warn: jest.fn(), debug: jest.fn(), error: jest.fn() }, +})); +jest.mock('@librechat/api', () => ({ + isEnabled: jest.fn(() => false), + findOpenIDUser: jest.fn(), + math: jest.fn((val, fallback) => fallback), +})); +jest.mock('~/models', () => ({ + findUser: jest.fn(), + updateUser: jest.fn(), +})); + +const { findOpenIDUser } = require('@librechat/api'); +const { updateUser } = require('~/models'); +const openIdJwtLogin = require('./openIdJwtStrategy'); + +// Helper: build a mock openIdConfig +const mockOpenIdConfig = { + serverMetadata: () => ({ jwks_uri: 'https://example.com/.well-known/jwks.json' }), +}; + +// Helper: invoke the captured verify callback +async function invokeVerify(req, payload) { + return new Promise((resolve, 
reject) => { + capturedVerifyCallback(req, payload, (err, user, info) => { + if (err) { + return reject(err); + } + resolve({ user, info }); + }); + }); +} + +describe('openIdJwtStrategy – token source handling', () => { + const baseUser = { + _id: { toString: () => 'user-abc' }, + role: SystemRoles.USER, + provider: 'openid', + }; + + const payload = { sub: 'oidc-123', email: 'test@example.com', exp: 9999999999 }; + + beforeEach(() => { + jest.clearAllMocks(); + findOpenIDUser.mockResolvedValue({ user: { ...baseUser }, error: null, migration: false }); + updateUser.mockResolvedValue({}); + + // Initialize the strategy so capturedVerifyCallback is set + openIdJwtLogin(mockOpenIdConfig); + }); + + it('should read all tokens from session when available', async () => { + const req = { + headers: { authorization: 'Bearer raw-bearer-token' }, + session: { + openidTokens: { + accessToken: 'session-access', + idToken: 'session-id', + refreshToken: 'session-refresh', + }, + }, + }; + + const { user } = await invokeVerify(req, payload); + + expect(user.federatedTokens).toEqual({ + access_token: 'session-access', + id_token: 'session-id', + refresh_token: 'session-refresh', + expires_at: payload.exp, + }); + }); + + it('should fall back to cookies when session is absent', async () => { + const req = { + headers: { + authorization: 'Bearer raw-bearer-token', + cookie: + 'openid_access_token=cookie-access; openid_id_token=cookie-id; refreshToken=cookie-refresh', + }, + }; + + const { user } = await invokeVerify(req, payload); + + expect(user.federatedTokens).toEqual({ + access_token: 'cookie-access', + id_token: 'cookie-id', + refresh_token: 'cookie-refresh', + expires_at: payload.exp, + }); + }); + + it('should fall back to cookie for idToken only when session lacks it', async () => { + const req = { + headers: { + authorization: 'Bearer raw-bearer-token', + cookie: 'openid_id_token=cookie-id', + }, + session: { + openidTokens: { + accessToken: 'session-access', + // idToken 
intentionally missing + refreshToken: 'session-refresh', + }, + }, + }; + + const { user } = await invokeVerify(req, payload); + + expect(user.federatedTokens).toEqual({ + access_token: 'session-access', + id_token: 'cookie-id', + refresh_token: 'session-refresh', + expires_at: payload.exp, + }); + }); + + it('should use raw Bearer token as access_token fallback when neither session nor cookie has one', async () => { + const req = { + headers: { + authorization: 'Bearer raw-bearer-token', + cookie: 'openid_id_token=cookie-id; refreshToken=cookie-refresh', + }, + }; + + const { user } = await invokeVerify(req, payload); + + expect(user.federatedTokens.access_token).toBe('raw-bearer-token'); + expect(user.federatedTokens.id_token).toBe('cookie-id'); + expect(user.federatedTokens.refresh_token).toBe('cookie-refresh'); + }); + + it('should set id_token to undefined when not available in session or cookies', async () => { + const req = { + headers: { + authorization: 'Bearer raw-bearer-token', + cookie: 'openid_access_token=cookie-access; refreshToken=cookie-refresh', + }, + }; + + const { user } = await invokeVerify(req, payload); + + expect(user.federatedTokens.access_token).toBe('cookie-access'); + expect(user.federatedTokens.id_token).toBeUndefined(); + expect(user.federatedTokens.refresh_token).toBe('cookie-refresh'); + }); + + it('should keep id_token and access_token as distinct values from cookies', async () => { + const req = { + headers: { + authorization: 'Bearer raw-bearer-token', + cookie: + 'openid_access_token=the-access-token; openid_id_token=the-id-token; refreshToken=the-refresh', + }, + }; + + const { user } = await invokeVerify(req, payload); + + expect(user.federatedTokens.access_token).toBe('the-access-token'); + expect(user.federatedTokens.id_token).toBe('the-id-token'); + expect(user.federatedTokens.access_token).not.toBe(user.federatedTokens.id_token); + }); +}); diff --git a/api/strategies/openidStrategy.js b/api/strategies/openidStrategy.js 
index a4369e601b..198c8735ae 100644 --- a/api/strategies/openidStrategy.js +++ b/api/strategies/openidStrategy.js @@ -6,8 +6,8 @@ const client = require('openid-client'); const jwtDecode = require('jsonwebtoken/decode'); const { HttpsProxyAgent } = require('https-proxy-agent'); const { hashToken, logger } = require('@librechat/data-schemas'); -const { CacheKeys, ErrorTypes } = require('librechat-data-provider'); const { Strategy: OpenIDStrategy } = require('openid-client/passport'); +const { CacheKeys, ErrorTypes, SystemRoles } = require('librechat-data-provider'); const { isEnabled, logHeaders, @@ -287,6 +287,367 @@ function convertToUsername(input, defaultValue = '') { return defaultValue; } +/** + * Resolve Azure AD groups when group overage is in effect (groups moved to _claim_names/_claim_sources). + * + * NOTE: Microsoft recommends treating _claim_names/_claim_sources as a signal only and using Microsoft Graph + * to resolve group membership instead of calling the endpoint in _claim_sources directly. + * + * @param {string} accessToken - Access token with Microsoft Graph permissions + * @returns {Promise} Resolved group IDs or null on failure + * @see https://learn.microsoft.com/en-us/entra/identity-platform/access-token-claims-reference#groups-overage-claim + * @see https://learn.microsoft.com/en-us/graph/api/directoryobject-getmemberobjects + */ +async function resolveGroupsFromOverage(accessToken) { + try { + if (!accessToken) { + logger.error('[openidStrategy] Access token missing; cannot resolve group overage'); + return null; + } + + // Use /me/getMemberObjects so least-privileged delegated permission User.Read is sufficient + // when resolving the signed-in user's group membership. 
+ const url = 'https://graph.microsoft.com/v1.0/me/getMemberObjects'; + + logger.debug( + `[openidStrategy] Detected group overage, resolving groups via Microsoft Graph getMemberObjects: ${url}`, + ); + + const fetchOptions = { + method: 'POST', + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ securityEnabledOnly: false }), + }; + + if (process.env.PROXY) { + const { ProxyAgent } = undici; + fetchOptions.dispatcher = new ProxyAgent(process.env.PROXY); + } + + const response = await undici.fetch(url, fetchOptions); + if (!response.ok) { + logger.error( + `[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP ${response.status} ${response.statusText}`, + ); + return null; + } + + const data = await response.json(); + const values = Array.isArray(data?.value) ? data.value : null; + if (!values) { + logger.error( + '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', + ); + return null; + } + const groupIds = values.filter((id) => typeof id === 'string'); + + logger.debug( + `[openidStrategy] Successfully resolved ${groupIds.length} groups via Microsoft Graph getMemberObjects`, + ); + return groupIds; + } catch (err) { + logger.error( + '[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:', + err, + ); + return null; + } +} + +/** + * Process OpenID authentication tokenset and userinfo + * This is the core logic extracted from the passport strategy callback + * Can be reused by both the passport strategy and proxy authentication + * + * @param {Object} tokenset - The OpenID tokenset containing access_token, id_token, etc. + * @param {boolean} existingUsersOnly - If true, only existing users will be processed + * @returns {Promise} The authenticated user object with tokenset + */ +async function processOpenIDAuth(tokenset, existingUsersOnly = false) { + const claims = tokenset.claims ? 
tokenset.claims() : tokenset; + const userinfo = { + ...claims, + }; + + if (tokenset.access_token) { + const providerUserinfo = await getUserInfo(openidConfig, tokenset.access_token, claims.sub); + Object.assign(userinfo, providerUserinfo); + } + + const appConfig = await getAppConfig(); + /** Azure AD sometimes doesn't return email, use preferred_username as fallback */ + const email = userinfo.email || userinfo.preferred_username || userinfo.upn; + if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { + logger.error( + `[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`, + ); + throw new Error('Email domain not allowed'); + } + + const result = await findOpenIDUser({ + findUser, + email: email, + openidId: claims.sub || userinfo.sub, + idOnTheSource: claims.oid || userinfo.oid, + strategyName: 'openidStrategy', + }); + let user = result.user; + const error = result.error; + + if (error) { + throw new Error(ErrorTypes.AUTH_FAILED); + } + + const fullName = getFullName(userinfo); + + const requiredRole = process.env.OPENID_REQUIRED_ROLE; + if (requiredRole) { + const requiredRoles = requiredRole + .split(',') + .map((role) => role.trim()) + .filter(Boolean); + const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; + const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; + + let decodedToken = ''; + if (requiredRoleTokenKind === 'access' && tokenset.access_token) { + decodedToken = jwtDecode(tokenset.access_token); + } else if (requiredRoleTokenKind === 'id' && tokenset.id_token) { + decodedToken = jwtDecode(tokenset.id_token); + } + + let roles = get(decodedToken, requiredRoleParameterPath); + + // Handle Azure AD group overage for ID token groups: when hasgroups or _claim_* indicate overage, + // resolve groups via Microsoft Graph instead of relying on token group values. 
+ if ( + !Array.isArray(roles) && + typeof roles !== 'string' && + requiredRoleTokenKind === 'id' && + requiredRoleParameterPath === 'groups' && + decodedToken && + (decodedToken.hasgroups || + (decodedToken._claim_names?.groups && + decodedToken._claim_sources?.[decodedToken._claim_names.groups])) + ) { + const overageGroups = await resolveGroupsFromOverage(tokenset.access_token); + if (overageGroups) { + roles = overageGroups; + } + } + + if (!roles || (!Array.isArray(roles) && typeof roles !== 'string')) { + logger.error( + `[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`, + ); + const rolesList = + requiredRoles.length === 1 + ? `"${requiredRoles[0]}"` + : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; + throw new Error(`You must have ${rolesList} role to log in.`); + } + + const roleValues = Array.isArray(roles) ? roles : [roles]; + + if (!requiredRoles.some((role) => roleValues.includes(role))) { + const rolesList = + requiredRoles.length === 1 + ? 
`"${requiredRoles[0]}"` + : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; + throw new Error(`You must have ${rolesList} role to log in.`); + } + } + + let username = ''; + if (process.env.OPENID_USERNAME_CLAIM) { + username = userinfo[process.env.OPENID_USERNAME_CLAIM]; + } else { + username = convertToUsername( + userinfo.preferred_username || userinfo.username || userinfo.email, + ); + } + + if (existingUsersOnly && !user) { + throw new Error('User does not exist'); + } + + if (!user) { + user = { + provider: 'openid', + openidId: userinfo.sub, + username, + email: email || '', + emailVerified: userinfo.email_verified || false, + name: fullName, + idOnTheSource: userinfo.oid, + }; + + const balanceConfig = getBalanceConfig(appConfig); + user = await createUser(user, balanceConfig, true, true); + } else { + user.provider = 'openid'; + user.openidId = userinfo.sub; + user.username = username; + user.name = fullName; + user.idOnTheSource = userinfo.oid; + if (email && email !== user.email) { + user.email = email; + user.emailVerified = userinfo.email_verified || false; + } + } + + const adminRole = process.env.OPENID_ADMIN_ROLE; + const adminRoleParameterPath = process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; + const adminRoleTokenKind = process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; + + if (adminRole && adminRoleParameterPath && adminRoleTokenKind) { + let adminRoleObject; + switch (adminRoleTokenKind) { + case 'access': + adminRoleObject = jwtDecode(tokenset.access_token); + break; + case 'id': + adminRoleObject = jwtDecode(tokenset.id_token); + break; + case 'userinfo': + adminRoleObject = userinfo; + break; + default: + logger.error( + `[openidStrategy] Invalid admin role token kind: ${adminRoleTokenKind}. 
Must be one of 'access', 'id', or 'userinfo'.`, + ); + throw new Error('Invalid admin role token kind'); + } + + const adminRoles = get(adminRoleObject, adminRoleParameterPath); + + if ( + adminRoles && + (adminRoles === true || + adminRoles === adminRole || + (Array.isArray(adminRoles) && adminRoles.includes(adminRole))) + ) { + user.role = SystemRoles.ADMIN; + logger.info(`[openidStrategy] User ${username} is an admin based on role: ${adminRole}`); + } else if (user.role === SystemRoles.ADMIN) { + user.role = SystemRoles.USER; + logger.info( + `[openidStrategy] User ${username} demoted from admin - role no longer present in token`, + ); + } + } + + if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) { + /** @type {string | undefined} */ + const imageUrl = userinfo.picture; + + let fileName; + if (crypto) { + fileName = (await hashToken(userinfo.sub)) + '.png'; + } else { + fileName = userinfo.sub + '.png'; + } + + const imageBuffer = await downloadImage( + imageUrl, + openidConfig, + tokenset.access_token, + userinfo.sub, + ); + if (imageBuffer) { + const { saveBuffer } = getStrategyFunctions( + appConfig?.fileStrategy ?? process.env.CDN_PROVIDER, + ); + const imagePath = await saveBuffer({ + fileName, + userId: user._id.toString(), + buffer: imageBuffer, + }); + user.avatar = imagePath ?? 
''; + } + } + + user = await updateUser(user._id, user); + + logger.info( + `[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `, + { + user: { + openidId: user.openidId, + username: user.username, + email: user.email, + name: user.name, + }, + }, + ); + + return { + ...user, + tokenset, + federatedTokens: { + access_token: tokenset.access_token, + id_token: tokenset.id_token, + refresh_token: tokenset.refresh_token, + expires_at: tokenset.expires_at, + }, + }; +} + +/** + * @param {boolean | undefined} [existingUsersOnly] + */ +function createOpenIDCallback(existingUsersOnly) { + return async (tokenset, done) => { + try { + const user = await processOpenIDAuth(tokenset, existingUsersOnly); + done(null, user); + } catch (err) { + if (err.message === 'Email domain not allowed') { + return done(null, false, { message: err.message }); + } + if (err.message === ErrorTypes.AUTH_FAILED) { + return done(null, false, { message: err.message }); + } + if (err.message && err.message.includes('role to log in')) { + return done(null, false, { message: err.message }); + } + logger.error('[openidStrategy] login failed', err); + done(err); + } + }; +} + +/** + * Sets up the OpenID strategy specifically for admin authentication. 
+ * @param {Configuration} openidConfig + */ +const setupOpenIdAdmin = (openidConfig) => { + try { + if (!openidConfig) { + throw new Error('OpenID configuration not initialized'); + } + + const openidAdminLogin = new CustomOpenIDStrategy( + { + config: openidConfig, + scope: process.env.OPENID_SCOPE, + usePKCE: isEnabled(process.env.OPENID_USE_PKCE), + clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300, + callbackURL: process.env.DOMAIN_SERVER + '/api/admin/oauth/openid/callback', + }, + createOpenIDCallback(true), + ); + + passport.use('openidAdmin', openidAdminLogin); + } catch (err) { + logger.error('[openidStrategy] setupOpenIdAdmin', err); + } +}; + /** * Sets up the OpenID strategy for authentication. * This function configures the OpenID client, handles proxy settings, @@ -324,10 +685,6 @@ async function setupOpenId() { }, ); - const requiredRole = process.env.OPENID_REQUIRED_ROLE; - const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH; - const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND; - const usePKCE = isEnabled(process.env.OPENID_USE_PKCE); logger.info(`[openidStrategy] OpenID authentication configuration`, { generateNonce: shouldGenerateNonce, reason: shouldGenerateNonce @@ -335,241 +692,25 @@ async function setupOpenId() { : 'OPENID_GENERATE_NONCE=false - Standard flow without explicit nonce or metadata', }); - // Set of env variables that specify how to set if a user is an admin - // If not set, all users will be treated as regular users - const adminRole = process.env.OPENID_ADMIN_ROLE; - const adminRoleParameterPath = process.env.OPENID_ADMIN_ROLE_PARAMETER_PATH; - const adminRoleTokenKind = process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; - const openidLogin = new CustomOpenIDStrategy( { config: openidConfig, scope: process.env.OPENID_SCOPE, callbackURL: process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL, clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300, - usePKCE, - }, - /** 
- * @param {import('openid-client').TokenEndpointResponseHelpers} tokenset - * @param {import('passport-jwt').VerifyCallback} done - */ - async (tokenset, done) => { - try { - const claims = tokenset.claims(); - const userinfo = { - ...claims, - ...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)), - }; - - const appConfig = await getAppConfig(); - /** Azure AD sometimes doesn't return email, use preferred_username as fallback */ - const email = userinfo.email || userinfo.preferred_username || userinfo.upn; - if (!isEmailDomainAllowed(email, appConfig?.registration?.allowedDomains)) { - logger.error( - `[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${email}]`, - ); - return done(null, false, { message: 'Email domain not allowed' }); - } - - const result = await findOpenIDUser({ - findUser, - email: email, - openidId: claims.sub, - idOnTheSource: claims.oid, - strategyName: 'openidStrategy', - }); - let user = result.user; - const error = result.error; - - if (error) { - return done(null, false, { - message: ErrorTypes.AUTH_FAILED, - }); - } - - const fullName = getFullName(userinfo); - - if (requiredRole) { - const requiredRoles = requiredRole - .split(',') - .map((role) => role.trim()) - .filter(Boolean); - let decodedToken = ''; - if (requiredRoleTokenKind === 'access') { - decodedToken = jwtDecode(tokenset.access_token); - } else if (requiredRoleTokenKind === 'id') { - decodedToken = jwtDecode(tokenset.id_token); - } - - let roles = get(decodedToken, requiredRoleParameterPath); - if (!roles || (!Array.isArray(roles) && typeof roles !== 'string')) { - logger.error( - `[openidStrategy] Key '${requiredRoleParameterPath}' not found or invalid type in ${requiredRoleTokenKind} token!`, - ); - const rolesList = - requiredRoles.length === 1 - ? 
`"${requiredRoles[0]}"` - : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; - return done(null, false, { - message: `You must have ${rolesList} role to log in.`, - }); - } - - if (!requiredRoles.some((role) => roles.includes(role))) { - const rolesList = - requiredRoles.length === 1 - ? `"${requiredRoles[0]}"` - : `one of: ${requiredRoles.map((r) => `"${r}"`).join(', ')}`; - return done(null, false, { - message: `You must have ${rolesList} role to log in.`, - }); - } - } - - let username = ''; - if (process.env.OPENID_USERNAME_CLAIM) { - username = userinfo[process.env.OPENID_USERNAME_CLAIM]; - } else { - username = convertToUsername( - userinfo.preferred_username || userinfo.username || userinfo.email, - ); - } - - if (!user) { - user = { - provider: 'openid', - openidId: userinfo.sub, - username, - email: email || '', - emailVerified: userinfo.email_verified || false, - name: fullName, - idOnTheSource: userinfo.oid, - }; - - const balanceConfig = getBalanceConfig(appConfig); - user = await createUser(user, balanceConfig, true, true); - } else { - user.provider = 'openid'; - user.openidId = userinfo.sub; - user.username = username; - user.name = fullName; - user.idOnTheSource = userinfo.oid; - if (email && email !== user.email) { - user.email = email; - user.emailVerified = userinfo.email_verified || false; - } - } - - if (adminRole && adminRoleParameterPath && adminRoleTokenKind) { - let adminRoleObject; - switch (adminRoleTokenKind) { - case 'access': - adminRoleObject = jwtDecode(tokenset.access_token); - break; - case 'id': - adminRoleObject = jwtDecode(tokenset.id_token); - break; - case 'userinfo': - adminRoleObject = userinfo; - break; - default: - logger.error( - `[openidStrategy] Invalid admin role token kind: ${adminRoleTokenKind}. 
Must be one of 'access', 'id', or 'userinfo'.`, - ); - return done(new Error('Invalid admin role token kind')); - } - - const adminRoles = get(adminRoleObject, adminRoleParameterPath); - - // Accept 3 types of values for the object extracted from adminRoleParameterPath: - // 1. A boolean value indicating if the user is an admin - // 2. A string with a single role name - // 3. An array of role names - - if ( - adminRoles && - (adminRoles === true || - adminRoles === adminRole || - (Array.isArray(adminRoles) && adminRoles.includes(adminRole))) - ) { - user.role = 'ADMIN'; - logger.info( - `[openidStrategy] User ${username} is an admin based on role: ${adminRole}`, - ); - } else if (user.role === 'ADMIN') { - user.role = 'USER'; - logger.info( - `[openidStrategy] User ${username} demoted from admin - role no longer present in token`, - ); - } - } - - if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) { - /** @type {string | undefined} */ - const imageUrl = userinfo.picture; - - let fileName; - if (crypto) { - fileName = (await hashToken(userinfo.sub)) + '.png'; - } else { - fileName = userinfo.sub + '.png'; - } - - const imageBuffer = await downloadImage( - imageUrl, - openidConfig, - tokenset.access_token, - userinfo.sub, - ); - if (imageBuffer) { - const { saveBuffer } = getStrategyFunctions( - appConfig?.fileStrategy ?? process.env.CDN_PROVIDER, - ); - const imagePath = await saveBuffer({ - fileName, - userId: user._id.toString(), - buffer: imageBuffer, - }); - user.avatar = imagePath ?? 
''; - } - } - - user = await updateUser(user._id, user); - - logger.info( - `[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `, - { - user: { - openidId: user.openidId, - username: user.username, - email: user.email, - name: user.name, - }, - }, - ); - - done(null, { - ...user, - tokenset, - federatedTokens: { - access_token: tokenset.access_token, - refresh_token: tokenset.refresh_token, - expires_at: tokenset.expires_at, - }, - }); - } catch (err) { - logger.error('[openidStrategy] login failed', err); - done(err); - } + usePKCE: isEnabled(process.env.OPENID_USE_PKCE), }, + createOpenIDCallback(), ); passport.use('openid', openidLogin); + setupOpenIdAdmin(openidConfig); return openidConfig; } catch (err) { logger.error('[openidStrategy]', err); return null; } } + /** * @function getOpenIdConfig * @description Returns the OpenID client instance. diff --git a/api/strategies/openidStrategy.spec.js b/api/strategies/openidStrategy.spec.js index 9ac22ff42f..b1dc54d77b 100644 --- a/api/strategies/openidStrategy.spec.js +++ b/api/strategies/openidStrategy.spec.js @@ -1,5 +1,6 @@ const fetch = require('node-fetch'); const jwtDecode = require('jsonwebtoken/decode'); +const undici = require('undici'); const { ErrorTypes } = require('librechat-data-provider'); const { findUser, createUser, updateUser } = require('~/models'); const { setupOpenId } = require('./openidStrategy'); @@ -7,6 +8,10 @@ const { setupOpenId } = require('./openidStrategy'); // --- Mocks --- jest.mock('node-fetch'); jest.mock('jsonwebtoken/decode'); +jest.mock('undici', () => ({ + fetch: jest.fn(), + ProxyAgent: jest.fn(), +})); jest.mock('~/server/services/Files/strategies', () => ({ getStrategyFunctions: jest.fn(() => ({ saveBuffer: jest.fn().mockResolvedValue('/fake/path/to/avatar.png'), @@ -64,21 +69,36 @@ jest.mock('openid-client', () => { }); jest.mock('openid-client/passport', () => { - let verifyCallback; + /** Store callbacks by 
strategy name - 'openid' and 'openidAdmin' */ + const verifyCallbacks = {}; + let lastVerifyCallback; + const mockStrategy = jest.fn((options, verify) => { - verifyCallback = verify; + lastVerifyCallback = verify; return { name: 'openid', options, verify }; }); return { Strategy: mockStrategy, - __getVerifyCallback: () => verifyCallback, + /** Get the last registered callback (for backward compatibility) */ + __getVerifyCallback: () => lastVerifyCallback, + /** Store callback by name when passport.use is called */ + __setVerifyCallback: (name, callback) => { + verifyCallbacks[name] = callback; + }, + /** Get callback by strategy name */ + __getVerifyCallbackByName: (name) => verifyCallbacks[name], }; }); -// Mock passport +// Mock passport - capture strategy name and callback jest.mock('passport', () => ({ - use: jest.fn(), + use: jest.fn((name, strategy) => { + const passportMock = require('openid-client/passport'); + if (strategy && strategy.verify) { + passportMock.__setVerifyCallback(name, strategy.verify); + } + }), })); describe('setupOpenId', () => { @@ -159,9 +179,10 @@ describe('setupOpenId', () => { }; fetch.mockResolvedValue(fakeResponse); - // Call the setup function and capture the verify callback + // Call the setup function and capture the verify callback for the regular 'openid' strategy + // (not 'openidAdmin' which requires existing users) await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); }); it('should create a new user with correct username when preferred_username claim exists', async () => { @@ -344,6 +365,25 @@ describe('setupOpenId', () => { expect(details.message).toBe('You must have "requiredRole" role to log in.'); }); + it('should not treat substring matches in string roles as satisfying required role', async () => { + // Arrange – override required role to "read" then re-setup + 
process.env.OPENID_REQUIRED_ROLE = 'read'; + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + // Token contains "bread" which *contains* "read" as a substring + jwtDecode.mockReturnValue({ + roles: 'bread', + }); + + // Act + const { user, details } = await validate(tokenset); + + // Assert – verify that substring match does not grant access + expect(user).toBe(false); + expect(details.message).toBe('You must have "read" role to log in.'); + }); + it('should allow login when single required role is present (backward compatibility)', async () => { // Arrange – ensure single role configuration (as set in beforeEach) // OPENID_REQUIRED_ROLE = 'requiredRole' @@ -362,6 +402,292 @@ describe('setupOpenId', () => { expect(createUser).toHaveBeenCalled(); }); + describe('group overage and groups handling', () => { + it.each([ + ['groups array contains required group', ['group-required', 'other-group'], true, undefined], + [ + 'groups array missing required group', + ['other-group'], + false, + 'You must have "group-required" role to log in.', + ], + ['groups string equals required group', 'group-required', true, undefined], + [ + 'groups string is other group', + 'other-group', + false, + 'You must have "group-required" role to log in.', + ], + ])( + 'uses groups claim directly when %s (no overage)', + async (_label, groupsClaim, expectedAllowed, expectedMessage) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + jwtDecode.mockReturnValue({ + groups: groupsClaim, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(Boolean(user)).toBe(expectedAllowed); + 
expect(details?.message).toBe(expectedMessage); + }, + ); + + it.each([ + ['token kind is not id', { kind: 'access', path: 'groups', decoded: { hasgroups: true } }], + ['parameter path is not groups', { kind: 'id', path: 'roles', decoded: { hasgroups: true } }], + ['decoded token is falsy', { kind: 'id', path: 'groups', decoded: null }], + [ + 'no overage indicators in decoded token', + { + kind: 'id', + path: 'groups', + decoded: { + permissions: ['admin'], + }, + }, + ], + [ + 'only _claim_names present (no _claim_sources)', + { + kind: 'id', + path: 'groups', + decoded: { + _claim_names: { groups: 'src1' }, + permissions: ['admin'], + }, + }, + ], + [ + 'only _claim_sources present (no _claim_names)', + { + kind: 'id', + path: 'groups', + decoded: { + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + permissions: ['admin'], + }, + }, + ], + ])('does not attempt overage resolution when %s', async (_label, cfg) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = cfg.path; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = cfg.kind; + + jwtDecode.mockReturnValue(cfg.decoded); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + const { logger } = require('@librechat/data-schemas'); + const expectedTokenKind = cfg.kind === 'access' ? 
'access token' : 'id token'; + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining(`Key '${cfg.path}' not found in ${expectedTokenKind}!`), + ); + }); + }); + + describe('resolving groups via Microsoft Graph', () => { + it('denies login and does not call Graph when access token is missing', async () => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + hasgroups: true, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + const tokensetWithoutAccess = { + ...tokenset, + access_token: undefined, + }; + + const { user, details } = await validate(tokensetWithoutAccess); + + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + + expect(undici.fetch).not.toHaveBeenCalled(); + expect(logger.error).toHaveBeenCalledWith( + expect.stringContaining('Access token missing; cannot resolve group overage'), + ); + }); + + it.each([ + [ + 'Graph returns HTTP error', + async () => ({ + ok: false, + status: 403, + statusText: 'Forbidden', + json: async () => ({}), + }), + [ + '[openidStrategy] Failed to resolve groups via Microsoft Graph getMemberObjects: HTTP 403 Forbidden', + ], + ], + [ + 'Graph network error', + async () => { + throw new Error('network error'); + }, + [ + '[openidStrategy] Error resolving groups via Microsoft Graph getMemberObjects:', + expect.any(Error), + ], + ], + [ + 'Graph returns unexpected shape (no value)', + async () => ({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({}), + }), + [ + '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', + ], + ], + [ + 'Graph returns invalid value type', + async () => 
({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ value: 'not-an-array' }), + }), + [ + '[openidStrategy] Unexpected response format when resolving groups via Microsoft Graph getMemberObjects', + ], + ], + ])( + 'denies login when overage resolution fails because %s', + async (_label, setupFetch, expectedErrorArgs) => { + process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue({ + hasgroups: true, + permissions: ['admin'], + }); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockImplementation(setupFetch); + + const { user, details } = await validate(tokenset); + + expect(undici.fetch).toHaveBeenCalled(); + expect(user).toBe(false); + expect(details.message).toBe('You must have "group-required" role to log in.'); + + expect(logger.error).toHaveBeenCalledWith(...expectedErrorArgs); + }, + ); + + it.each([ + [ + 'hasgroups overage and Graph contains required group', + { + hasgroups: true, + }, + ['group-required', 'some-other-group'], + true, + ], + [ + '_claim_* overage and Graph contains required group', + { + _claim_names: { groups: 'src1' }, + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + }, + ['group-required', 'some-other-group'], + true, + ], + [ + 'hasgroups overage and Graph does NOT contain required group', + { + hasgroups: true, + }, + ['some-other-group'], + false, + ], + [ + '_claim_* overage and Graph does NOT contain required group', + { + _claim_names: { groups: 'src1' }, + _claim_sources: { src1: { endpoint: 'https://graph.windows.net/ignored' } }, + }, + ['some-other-group'], + false, + ], + ])( + 'resolves groups via Microsoft Graph when %s', + async (_label, decodedTokenValue, graphGroups, expectedAllowed) => { 
+ process.env.OPENID_REQUIRED_ROLE = 'group-required'; + process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH = 'groups'; + process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND = 'id'; + + const { logger } = require('@librechat/data-schemas'); + + jwtDecode.mockReturnValue(decodedTokenValue); + + await setupOpenId(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); + + undici.fetch.mockResolvedValue({ + ok: true, + status: 200, + statusText: 'OK', + json: async () => ({ + value: graphGroups, + }), + }); + + const { user } = await validate(tokenset); + + expect(undici.fetch).toHaveBeenCalledWith( + 'https://graph.microsoft.com/v1.0/me/getMemberObjects', + expect.objectContaining({ + method: 'POST', + headers: expect.objectContaining({ + Authorization: `Bearer ${tokenset.access_token}`, + }), + }), + ); + expect(Boolean(user)).toBe(expectedAllowed); + + expect(logger.debug).toHaveBeenCalledWith( + expect.stringContaining( + `Successfully resolved ${graphGroups.length} groups via Microsoft Graph getMemberObjects`, + ), + ); + }, + ); + }); + it('should attempt to download and save the avatar if picture is provided', async () => { // Act const { user } = await validate(tokenset); @@ -389,7 +715,7 @@ describe('setupOpenId', () => { // Arrange process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); jwtDecode.mockReturnValue({ roles: ['anotherRole', 'aThirdRole'], }); @@ -406,7 +732,7 @@ describe('setupOpenId', () => { // Arrange process.env.OPENID_REQUIRED_ROLE = 'someRole,anotherRole,admin'; await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); 
jwtDecode.mockReturnValue({ roles: ['aThirdRole', 'aFourthRole'], }); @@ -425,7 +751,7 @@ describe('setupOpenId', () => { // Arrange process.env.OPENID_REQUIRED_ROLE = ' someRole , anotherRole , admin '; await setupOpenId(); // Re-initialize the strategy - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); jwtDecode.mockReturnValue({ roles: ['someRole'], }); @@ -449,10 +775,11 @@ describe('setupOpenId', () => { }); it('should attach federatedTokens to user object for token propagation', async () => { - // Arrange - setup tokenset with access token, refresh token, and expiration + // Arrange - setup tokenset with access token, id token, refresh token, and expiration const tokensetWithTokens = { ...tokenset, access_token: 'mock_access_token_abc123', + id_token: 'mock_id_token_def456', refresh_token: 'mock_refresh_token_xyz789', expires_at: 1234567890, }; @@ -464,16 +791,37 @@ describe('setupOpenId', () => { expect(user.federatedTokens).toBeDefined(); expect(user.federatedTokens).toEqual({ access_token: 'mock_access_token_abc123', + id_token: 'mock_id_token_def456', refresh_token: 'mock_refresh_token_xyz789', expires_at: 1234567890, }); }); + it('should include id_token in federatedTokens distinct from access_token', async () => { + // Arrange - use different values for access_token and id_token + const tokensetWithTokens = { + ...tokenset, + access_token: 'the_access_token', + id_token: 'the_id_token', + refresh_token: 'the_refresh_token', + expires_at: 9999999999, + }; + + // Act + const { user } = await validate(tokensetWithTokens); + + // Assert - id_token and access_token must be different values + expect(user.federatedTokens.access_token).toBe('the_access_token'); + expect(user.federatedTokens.id_token).toBe('the_id_token'); + expect(user.federatedTokens.id_token).not.toBe(user.federatedTokens.access_token); + }); + it('should include tokenset along 
with federatedTokens', async () => { // Arrange const tokensetWithTokens = { ...tokenset, access_token: 'test_access_token', + id_token: 'test_id_token', refresh_token: 'test_refresh_token', expires_at: 9999999999, }; @@ -485,7 +833,9 @@ describe('setupOpenId', () => { expect(user.tokenset).toBeDefined(); expect(user.federatedTokens).toBeDefined(); expect(user.tokenset.access_token).toBe('test_access_token'); + expect(user.tokenset.id_token).toBe('test_id_token'); expect(user.federatedTokens.access_token).toBe('test_access_token'); + expect(user.federatedTokens.id_token).toBe('test_id_token'); }); it('should set role to "ADMIN" if OPENID_ADMIN_ROLE is set and user has that role', async () => { @@ -560,7 +910,7 @@ describe('setupOpenId', () => { delete process.env.OPENID_ADMIN_ROLE_TOKEN_KIND; await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); // Simulate an existing admin user const existingAdminUser = { @@ -611,7 +961,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -634,7 +984,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -655,14 +1005,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user, details } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining( - 
"Key 'resource_access.nonexistent.roles' not found or invalid type in id token!", - ), + expect.stringContaining("Key 'resource_access.nonexistent.roles' not found in id token!"), ); expect(user).toBe(false); expect(details.message).toContain('role to log in'); @@ -680,12 +1028,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'org.team.roles' not found or invalid type in id token!"), + expect.stringContaining("Key 'org.team.roles' not found in id token!"), ); expect(user).toBe(false); }); @@ -709,7 +1057,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -739,7 +1087,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate({ ...tokenset, @@ -759,7 +1107,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -776,7 +1124,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -793,7 +1141,7 @@ describe('setupOpenId', () => { }); 
await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -810,7 +1158,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -827,7 +1175,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -847,7 +1195,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); @@ -864,12 +1212,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'access.roles' not found or invalid type in id token!"), + expect.stringContaining("Key 'access.roles' not found in id token!"), ); expect(user).toBe(false); }); @@ -884,12 +1232,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'data.roles' not 
found or invalid type in id token!"), + expect.stringContaining("Key 'data.roles' not found in id token!"), ); expect(user).toBe(false); }); @@ -906,7 +1254,7 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); await expect(validate(tokenset)).rejects.toThrow('Invalid admin role token kind'); @@ -927,12 +1275,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user, details } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roles' not found or invalid type in id token!"), + expect.stringContaining("Key 'roles' not found in id token!"), ); expect(user).toBe(false); expect(details.message).toContain('role to log in'); @@ -948,12 +1296,12 @@ describe('setupOpenId', () => { }); await setupOpenId(); - verifyCallback = require('openid-client/passport').__getVerifyCallback(); + verifyCallback = require('openid-client/passport').__getVerifyCallbackByName('openid'); const { user } = await validate(tokenset); expect(logger.error).toHaveBeenCalledWith( - expect.stringContaining("Key 'roleCount' not found or invalid type in id token!"), + expect.stringContaining("Key 'roleCount' not found in id token!"), ); expect(user).toBe(false); }); diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index 3336a0f82d..18905d6d18 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -1,3 +1,4 @@ +/** Note: No hard-coded values should be used in this file. 
*/ const { EModelEndpoint } = require('librechat-data-provider'); const { maxTokensMap, @@ -626,41 +627,45 @@ describe('matchModelName', () => { describe('Meta Models Tests', () => { describe('getModelMaxTokens', () => { test('should return correct tokens for LLaMa 2 models', () => { - expect(getModelMaxTokens('llama2')).toBe(4000); - expect(getModelMaxTokens('llama2.70b')).toBe(4000); - expect(getModelMaxTokens('llama2-13b')).toBe(4000); - expect(getModelMaxTokens('llama2-70b')).toBe(4000); + const llama2Tokens = maxTokensMap[EModelEndpoint.openAI]['llama2']; + expect(getModelMaxTokens('llama2')).toBe(llama2Tokens); + expect(getModelMaxTokens('llama2.70b')).toBe(llama2Tokens); + expect(getModelMaxTokens('llama2-13b')).toBe(llama2Tokens); + expect(getModelMaxTokens('llama2-70b')).toBe(llama2Tokens); }); test('should return correct tokens for LLaMa 3 models', () => { - expect(getModelMaxTokens('llama3')).toBe(8000); - expect(getModelMaxTokens('llama3.8b')).toBe(8000); - expect(getModelMaxTokens('llama3.70b')).toBe(8000); - expect(getModelMaxTokens('llama3-8b')).toBe(8000); - expect(getModelMaxTokens('llama3-70b')).toBe(8000); + const llama3Tokens = maxTokensMap[EModelEndpoint.openAI]['llama3']; + expect(getModelMaxTokens('llama3')).toBe(llama3Tokens); + expect(getModelMaxTokens('llama3.8b')).toBe(llama3Tokens); + expect(getModelMaxTokens('llama3.70b')).toBe(llama3Tokens); + expect(getModelMaxTokens('llama3-8b')).toBe(llama3Tokens); + expect(getModelMaxTokens('llama3-70b')).toBe(llama3Tokens); }); test('should return correct tokens for LLaMa 3.1 models', () => { - expect(getModelMaxTokens('llama3.1:8b')).toBe(127500); - expect(getModelMaxTokens('llama3.1:70b')).toBe(127500); - expect(getModelMaxTokens('llama3.1:405b')).toBe(127500); - expect(getModelMaxTokens('llama3-1-8b')).toBe(127500); - expect(getModelMaxTokens('llama3-1-70b')).toBe(127500); - expect(getModelMaxTokens('llama3-1-405b')).toBe(127500); + const llama31Tokens = 
maxTokensMap[EModelEndpoint.openAI]['llama3.1:8b']; + expect(getModelMaxTokens('llama3.1:8b')).toBe(llama31Tokens); + expect(getModelMaxTokens('llama3.1:70b')).toBe(llama31Tokens); + expect(getModelMaxTokens('llama3.1:405b')).toBe(llama31Tokens); + expect(getModelMaxTokens('llama3-1-8b')).toBe(llama31Tokens); + expect(getModelMaxTokens('llama3-1-70b')).toBe(llama31Tokens); + expect(getModelMaxTokens('llama3-1-405b')).toBe(llama31Tokens); }); test('should handle partial matches for Meta models', () => { - // Test with full model names - expect(getModelMaxTokens('meta/llama3.1:405b')).toBe(127500); - expect(getModelMaxTokens('meta/llama3.1:70b')).toBe(127500); - expect(getModelMaxTokens('meta/llama3.1:8b')).toBe(127500); - expect(getModelMaxTokens('meta/llama3-1-8b')).toBe(127500); + const llama31Tokens = maxTokensMap[EModelEndpoint.openAI]['llama3.1:8b']; + const llama3Tokens = maxTokensMap[EModelEndpoint.openAI]['llama3']; + const llama2Tokens = maxTokensMap[EModelEndpoint.openAI]['llama2']; + expect(getModelMaxTokens('meta/llama3.1:405b')).toBe(llama31Tokens); + expect(getModelMaxTokens('meta/llama3.1:70b')).toBe(llama31Tokens); + expect(getModelMaxTokens('meta/llama3.1:8b')).toBe(llama31Tokens); + expect(getModelMaxTokens('meta/llama3-1-8b')).toBe(llama31Tokens); - // Test base versions - expect(getModelMaxTokens('meta/llama3.1')).toBe(127500); - expect(getModelMaxTokens('meta/llama3-1')).toBe(127500); - expect(getModelMaxTokens('meta/llama3')).toBe(8000); - expect(getModelMaxTokens('meta/llama2')).toBe(4000); + expect(getModelMaxTokens('meta/llama3.1')).toBe(llama31Tokens); + expect(getModelMaxTokens('meta/llama3-1')).toBe(llama31Tokens); + expect(getModelMaxTokens('meta/llama3')).toBe(llama3Tokens); + expect(getModelMaxTokens('meta/llama2')).toBe(llama2Tokens); }); test('should match Deepseek model variations', () => { @@ -678,18 +683,33 @@ describe('Meta Models Tests', () => { ); }); - test('should return 128000 context tokens for all DeepSeek models', () => { 
- expect(getModelMaxTokens('deepseek-chat')).toBe(128000); - expect(getModelMaxTokens('deepseek-reasoner')).toBe(128000); - expect(getModelMaxTokens('deepseek-r1')).toBe(128000); - expect(getModelMaxTokens('deepseek-v3')).toBe(128000); - expect(getModelMaxTokens('deepseek.r1')).toBe(128000); + test('should return correct context tokens for all DeepSeek models', () => { + const deepseekChatTokens = maxTokensMap[EModelEndpoint.openAI]['deepseek-chat']; + expect(getModelMaxTokens('deepseek-chat')).toBe(deepseekChatTokens); + expect(getModelMaxTokens('deepseek-reasoner')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek-reasoner'], + ); + expect(getModelMaxTokens('deepseek-r1')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek-r1'], + ); + expect(getModelMaxTokens('deepseek-v3')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek'], + ); + expect(getModelMaxTokens('deepseek.r1')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek.r1'], + ); }); test('should handle DeepSeek models with provider prefixes', () => { - expect(getModelMaxTokens('deepseek/deepseek-chat')).toBe(128000); - expect(getModelMaxTokens('openrouter/deepseek-reasoner')).toBe(128000); - expect(getModelMaxTokens('openai/deepseek-v3')).toBe(128000); + expect(getModelMaxTokens('deepseek/deepseek-chat')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek-chat'], + ); + expect(getModelMaxTokens('openrouter/deepseek-reasoner')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek-reasoner'], + ); + expect(getModelMaxTokens('openai/deepseek-v3')).toBe( + maxTokensMap[EModelEndpoint.openAI]['deepseek'], + ); }); }); @@ -728,30 +748,38 @@ describe('Meta Models Tests', () => { const { getModelMaxOutputTokens } = require('@librechat/api'); test('should return correct max output tokens for deepseek-chat', () => { - expect(getModelMaxOutputTokens('deepseek-chat')).toBe(8000); - expect(getModelMaxOutputTokens('deepseek-chat', EModelEndpoint.openAI)).toBe(8000); - 
expect(getModelMaxOutputTokens('deepseek-chat', EModelEndpoint.custom)).toBe(8000); + const expected = maxOutputTokensMap[EModelEndpoint.openAI]['deepseek-chat']; + expect(getModelMaxOutputTokens('deepseek-chat')).toBe(expected); + expect(getModelMaxOutputTokens('deepseek-chat', EModelEndpoint.openAI)).toBe(expected); + expect(getModelMaxOutputTokens('deepseek-chat', EModelEndpoint.custom)).toBe(expected); }); test('should return correct max output tokens for deepseek-reasoner', () => { - expect(getModelMaxOutputTokens('deepseek-reasoner')).toBe(64000); - expect(getModelMaxOutputTokens('deepseek-reasoner', EModelEndpoint.openAI)).toBe(64000); - expect(getModelMaxOutputTokens('deepseek-reasoner', EModelEndpoint.custom)).toBe(64000); + const expected = maxOutputTokensMap[EModelEndpoint.openAI]['deepseek-reasoner']; + expect(getModelMaxOutputTokens('deepseek-reasoner')).toBe(expected); + expect(getModelMaxOutputTokens('deepseek-reasoner', EModelEndpoint.openAI)).toBe(expected); + expect(getModelMaxOutputTokens('deepseek-reasoner', EModelEndpoint.custom)).toBe(expected); }); test('should return correct max output tokens for deepseek-r1', () => { - expect(getModelMaxOutputTokens('deepseek-r1')).toBe(64000); - expect(getModelMaxOutputTokens('deepseek-r1', EModelEndpoint.openAI)).toBe(64000); + const expected = maxOutputTokensMap[EModelEndpoint.openAI]['deepseek-r1']; + expect(getModelMaxOutputTokens('deepseek-r1')).toBe(expected); + expect(getModelMaxOutputTokens('deepseek-r1', EModelEndpoint.openAI)).toBe(expected); }); test('should return correct max output tokens for deepseek base pattern', () => { - expect(getModelMaxOutputTokens('deepseek')).toBe(8000); - expect(getModelMaxOutputTokens('deepseek-v3')).toBe(8000); + const expected = maxOutputTokensMap[EModelEndpoint.openAI]['deepseek']; + expect(getModelMaxOutputTokens('deepseek')).toBe(expected); + expect(getModelMaxOutputTokens('deepseek-v3')).toBe(expected); }); test('should handle DeepSeek models with provider 
prefixes for max output tokens', () => { - expect(getModelMaxOutputTokens('deepseek/deepseek-chat')).toBe(8000); - expect(getModelMaxOutputTokens('openrouter/deepseek-reasoner')).toBe(64000); + expect(getModelMaxOutputTokens('deepseek/deepseek-chat')).toBe( + maxOutputTokensMap[EModelEndpoint.openAI]['deepseek-chat'], + ); + expect(getModelMaxOutputTokens('openrouter/deepseek-reasoner')).toBe( + maxOutputTokensMap[EModelEndpoint.openAI]['deepseek-reasoner'], + ); }); }); @@ -796,68 +824,90 @@ describe('Meta Models Tests', () => { describe('Grok Model Tests - Tokens', () => { describe('getModelMaxTokens', () => { test('should return correct tokens for Grok vision models', () => { - expect(getModelMaxTokens('grok-2-vision-1212')).toBe(32768); - expect(getModelMaxTokens('grok-2-vision')).toBe(32768); - expect(getModelMaxTokens('grok-2-vision-latest')).toBe(32768); + const grok2VisionTokens = maxTokensMap[EModelEndpoint.openAI]['grok-2-vision']; + expect(getModelMaxTokens('grok-2-vision-1212')).toBe(grok2VisionTokens); + expect(getModelMaxTokens('grok-2-vision')).toBe(grok2VisionTokens); + expect(getModelMaxTokens('grok-2-vision-latest')).toBe(grok2VisionTokens); }); test('should return correct tokens for Grok beta models', () => { - expect(getModelMaxTokens('grok-vision-beta')).toBe(8192); - expect(getModelMaxTokens('grok-beta')).toBe(131072); + expect(getModelMaxTokens('grok-vision-beta')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-vision-beta'], + ); + expect(getModelMaxTokens('grok-beta')).toBe(maxTokensMap[EModelEndpoint.openAI]['grok-beta']); }); test('should return correct tokens for Grok text models', () => { - expect(getModelMaxTokens('grok-2-1212')).toBe(131072); - expect(getModelMaxTokens('grok-2')).toBe(131072); - expect(getModelMaxTokens('grok-2-latest')).toBe(131072); + const grok2Tokens = maxTokensMap[EModelEndpoint.openAI]['grok-2']; + expect(getModelMaxTokens('grok-2-1212')).toBe(grok2Tokens); + 
expect(getModelMaxTokens('grok-2')).toBe(grok2Tokens); + expect(getModelMaxTokens('grok-2-latest')).toBe(grok2Tokens); }); test('should return correct tokens for Grok 3 series models', () => { - expect(getModelMaxTokens('grok-3')).toBe(131072); - expect(getModelMaxTokens('grok-3-fast')).toBe(131072); - expect(getModelMaxTokens('grok-3-mini')).toBe(131072); - expect(getModelMaxTokens('grok-3-mini-fast')).toBe(131072); + expect(getModelMaxTokens('grok-3')).toBe(maxTokensMap[EModelEndpoint.openAI]['grok-3']); + expect(getModelMaxTokens('grok-3-fast')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-3-fast'], + ); + expect(getModelMaxTokens('grok-3-mini')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-3-mini'], + ); + expect(getModelMaxTokens('grok-3-mini-fast')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-3-mini-fast'], + ); }); test('should return correct tokens for Grok 4 model', () => { - expect(getModelMaxTokens('grok-4-0709')).toBe(256000); + expect(getModelMaxTokens('grok-4-0709')).toBe(maxTokensMap[EModelEndpoint.openAI]['grok-4']); }); test('should return correct tokens for Grok 4 Fast and Grok 4.1 Fast models', () => { - expect(getModelMaxTokens('grok-4-fast')).toBe(2000000); - expect(getModelMaxTokens('grok-4-1-fast-reasoning')).toBe(2000000); - expect(getModelMaxTokens('grok-4-1-fast-non-reasoning')).toBe(2000000); + const grok4FastTokens = maxTokensMap[EModelEndpoint.openAI]['grok-4-fast']; + const grok41FastTokens = maxTokensMap[EModelEndpoint.openAI]['grok-4-1-fast']; + expect(getModelMaxTokens('grok-4-fast')).toBe(grok4FastTokens); + expect(getModelMaxTokens('grok-4-1-fast-reasoning')).toBe(grok41FastTokens); + expect(getModelMaxTokens('grok-4-1-fast-non-reasoning')).toBe(grok41FastTokens); }); test('should return correct tokens for Grok Code Fast model', () => { - expect(getModelMaxTokens('grok-code-fast-1')).toBe(256000); + expect(getModelMaxTokens('grok-code-fast-1')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-code-fast'], + ); }); 
test('should handle partial matches for Grok models with prefixes', () => { - // Vision models should match before general models - expect(getModelMaxTokens('xai/grok-2-vision-1212')).toBe(32768); - expect(getModelMaxTokens('xai/grok-2-vision')).toBe(32768); - expect(getModelMaxTokens('xai/grok-2-vision-latest')).toBe(32768); - // Beta models - expect(getModelMaxTokens('xai/grok-vision-beta')).toBe(8192); - expect(getModelMaxTokens('xai/grok-beta')).toBe(131072); - // Text models - expect(getModelMaxTokens('xai/grok-2-1212')).toBe(131072); - expect(getModelMaxTokens('xai/grok-2')).toBe(131072); - expect(getModelMaxTokens('xai/grok-2-latest')).toBe(131072); - // Grok 3 models - expect(getModelMaxTokens('xai/grok-3')).toBe(131072); - expect(getModelMaxTokens('xai/grok-3-fast')).toBe(131072); - expect(getModelMaxTokens('xai/grok-3-mini')).toBe(131072); - expect(getModelMaxTokens('xai/grok-3-mini-fast')).toBe(131072); - // Grok 4 model - expect(getModelMaxTokens('xai/grok-4-0709')).toBe(256000); - // Grok 4 Fast and 4.1 Fast models - expect(getModelMaxTokens('xai/grok-4-fast')).toBe(2000000); - expect(getModelMaxTokens('xai/grok-4-1-fast-reasoning')).toBe(2000000); - expect(getModelMaxTokens('xai/grok-4-1-fast-non-reasoning')).toBe(2000000); - // Grok Code Fast model - expect(getModelMaxTokens('xai/grok-code-fast-1')).toBe(256000); + const grok2VisionTokens = maxTokensMap[EModelEndpoint.openAI]['grok-2-vision']; + const grokVisionBetaTokens = maxTokensMap[EModelEndpoint.openAI]['grok-vision-beta']; + const grokBetaTokens = maxTokensMap[EModelEndpoint.openAI]['grok-beta']; + const grok2Tokens = maxTokensMap[EModelEndpoint.openAI]['grok-2']; + const grok3Tokens = maxTokensMap[EModelEndpoint.openAI]['grok-3']; + const grok4Tokens = maxTokensMap[EModelEndpoint.openAI]['grok-4']; + const grok4FastTokens = maxTokensMap[EModelEndpoint.openAI]['grok-4-fast']; + const grok41FastTokens = maxTokensMap[EModelEndpoint.openAI]['grok-4-1-fast']; + const grokCodeFastTokens = 
maxTokensMap[EModelEndpoint.openAI]['grok-code-fast']; + expect(getModelMaxTokens('xai/grok-2-vision-1212')).toBe(grok2VisionTokens); + expect(getModelMaxTokens('xai/grok-2-vision')).toBe(grok2VisionTokens); + expect(getModelMaxTokens('xai/grok-2-vision-latest')).toBe(grok2VisionTokens); + expect(getModelMaxTokens('xai/grok-vision-beta')).toBe(grokVisionBetaTokens); + expect(getModelMaxTokens('xai/grok-beta')).toBe(grokBetaTokens); + expect(getModelMaxTokens('xai/grok-2-1212')).toBe(grok2Tokens); + expect(getModelMaxTokens('xai/grok-2')).toBe(grok2Tokens); + expect(getModelMaxTokens('xai/grok-2-latest')).toBe(grok2Tokens); + expect(getModelMaxTokens('xai/grok-3')).toBe(grok3Tokens); + expect(getModelMaxTokens('xai/grok-3-fast')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-3-fast'], + ); + expect(getModelMaxTokens('xai/grok-3-mini')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-3-mini'], + ); + expect(getModelMaxTokens('xai/grok-3-mini-fast')).toBe( + maxTokensMap[EModelEndpoint.openAI]['grok-3-mini-fast'], + ); + expect(getModelMaxTokens('xai/grok-4-0709')).toBe(grok4Tokens); + expect(getModelMaxTokens('xai/grok-4-fast')).toBe(grok4FastTokens); + expect(getModelMaxTokens('xai/grok-4-1-fast-reasoning')).toBe(grok41FastTokens); + expect(getModelMaxTokens('xai/grok-4-1-fast-non-reasoning')).toBe(grok41FastTokens); + expect(getModelMaxTokens('xai/grok-code-fast-1')).toBe(grokCodeFastTokens); }); }); @@ -1062,46 +1112,251 @@ describe('Claude Model Tests', () => { expect(matchModelName(model, EModelEndpoint.anthropic)).toBe(expectedModel); }); }); + + it('should return correct context length for Claude Opus 4.6 (1M)', () => { + expect(getModelMaxTokens('claude-opus-4-6', EModelEndpoint.anthropic)).toBe( + maxTokensMap[EModelEndpoint.anthropic]['claude-opus-4-6'], + ); + expect(getModelMaxTokens('claude-opus-4-6')).toBe( + maxTokensMap[EModelEndpoint.anthropic]['claude-opus-4-6'], + ); + }); + + it('should return correct max output tokens for Claude Opus 4.6 
(128K)', () => { + const { getModelMaxOutputTokens } = require('@librechat/api'); + expect(getModelMaxOutputTokens('claude-opus-4-6', EModelEndpoint.anthropic)).toBe( + maxOutputTokensMap[EModelEndpoint.anthropic]['claude-opus-4-6'], + ); + }); + + it('should handle Claude Opus 4.6 model name variations', () => { + const modelVariations = [ + 'claude-opus-4-6', + 'claude-opus-4-6-20250801', + 'claude-opus-4-6-latest', + 'anthropic/claude-opus-4-6', + 'claude-opus-4-6/anthropic', + 'claude-opus-4-6-preview', + ]; + + modelVariations.forEach((model) => { + const modelKey = findMatchingPattern(model, maxTokensMap[EModelEndpoint.anthropic]); + expect(modelKey).toBe('claude-opus-4-6'); + expect(getModelMaxTokens(model, EModelEndpoint.anthropic)).toBe( + maxTokensMap[EModelEndpoint.anthropic]['claude-opus-4-6'], + ); + }); + }); + + it('should match model names correctly for Claude Opus 4.6', () => { + const modelVariations = [ + 'claude-opus-4-6', + 'claude-opus-4-6-20250801', + 'claude-opus-4-6-latest', + 'anthropic/claude-opus-4-6', + 'claude-opus-4-6/anthropic', + 'claude-opus-4-6-preview', + ]; + + modelVariations.forEach((model) => { + expect(matchModelName(model, EModelEndpoint.anthropic)).toBe('claude-opus-4-6'); + }); + }); + + it('should return correct context length for Claude Sonnet 4.6 (1M)', () => { + expect(getModelMaxTokens('claude-sonnet-4-6', EModelEndpoint.anthropic)).toBe( + maxTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'], + ); + expect(getModelMaxTokens('claude-sonnet-4-6')).toBe( + maxTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'], + ); + }); + + it('should return correct max output tokens for Claude Sonnet 4.6 (64K)', () => { + const { getModelMaxOutputTokens } = require('@librechat/api'); + expect(getModelMaxOutputTokens('claude-sonnet-4-6', EModelEndpoint.anthropic)).toBe( + maxOutputTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'], + ); + }); + + it('should handle Claude Sonnet 4.6 model name variations', () => { + 
const modelVariations = [ + 'claude-sonnet-4-6', + 'claude-sonnet-4-6-20260101', + 'claude-sonnet-4-6-latest', + 'anthropic/claude-sonnet-4-6', + 'claude-sonnet-4-6/anthropic', + 'claude-sonnet-4-6-preview', + ]; + + modelVariations.forEach((model) => { + const modelKey = findMatchingPattern(model, maxTokensMap[EModelEndpoint.anthropic]); + expect(modelKey).toBe('claude-sonnet-4-6'); + expect(getModelMaxTokens(model, EModelEndpoint.anthropic)).toBe( + maxTokensMap[EModelEndpoint.anthropic]['claude-sonnet-4-6'], + ); + }); + }); + + it('should match model names correctly for Claude Sonnet 4.6', () => { + const modelVariations = [ + 'claude-sonnet-4-6', + 'claude-sonnet-4-6-20260101', + 'claude-sonnet-4-6-latest', + 'anthropic/claude-sonnet-4-6', + 'claude-sonnet-4-6/anthropic', + 'claude-sonnet-4-6-preview', + ]; + + modelVariations.forEach((model) => { + expect(matchModelName(model, EModelEndpoint.anthropic)).toBe('claude-sonnet-4-6'); + }); + }); }); -describe('Kimi Model Tests', () => { +describe('Moonshot/Kimi Model Tests', () => { describe('getModelMaxTokens', () => { - test('should return correct tokens for Kimi models', () => { - expect(getModelMaxTokens('kimi')).toBe(131000); - expect(getModelMaxTokens('kimi-k2')).toBe(131000); - expect(getModelMaxTokens('kimi-vl')).toBe(131000); + test('should return correct tokens for kimi-k2.5 (multi-modal)', () => { + expect(getModelMaxTokens('kimi-k2.5')).toBe(maxTokensMap[EModelEndpoint.openAI]['kimi-k2.5']); + expect(getModelMaxTokens('kimi-k2.5-latest')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2.5'], + ); }); - test('should return correct tokens for Kimi models with provider prefix', () => { - expect(getModelMaxTokens('moonshotai/kimi-k2')).toBe(131000); - expect(getModelMaxTokens('moonshotai/kimi')).toBe(131000); - expect(getModelMaxTokens('moonshotai/kimi-vl')).toBe(131000); + test('should return correct tokens for kimi-k2 series models', () => { + 
expect(getModelMaxTokens('kimi')).toBe(maxTokensMap[EModelEndpoint.openAI]['kimi']); + expect(getModelMaxTokens('kimi-k2')).toBe(maxTokensMap[EModelEndpoint.openAI]['kimi-k2']); + expect(getModelMaxTokens('kimi-k2-turbo')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-turbo'], + ); + expect(getModelMaxTokens('kimi-k2-turbo-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-turbo-preview'], + ); + expect(getModelMaxTokens('kimi-k2-0905')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-0905'], + ); + expect(getModelMaxTokens('kimi-k2-0905-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-0905-preview'], + ); + expect(getModelMaxTokens('kimi-k2-thinking')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-thinking'], + ); + expect(getModelMaxTokens('kimi-k2-thinking-turbo')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-thinking-turbo'], + ); }); - test('should handle partial matches for Kimi models', () => { - expect(getModelMaxTokens('kimi-k2-latest')).toBe(131000); - expect(getModelMaxTokens('kimi-vl-preview')).toBe(131000); - expect(getModelMaxTokens('kimi-2024')).toBe(131000); + test('should return correct tokens for kimi-k2-0711 (smaller context)', () => { + expect(getModelMaxTokens('kimi-k2-0711')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-0711'], + ); + expect(getModelMaxTokens('kimi-k2-0711-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-0711-preview'], + ); + }); + + test('should return correct tokens for kimi-latest', () => { + expect(getModelMaxTokens('kimi-latest')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-latest'], + ); + }); + + test('should return correct tokens for moonshot-v1 series models', () => { + expect(getModelMaxTokens('moonshot')).toBe(maxTokensMap[EModelEndpoint.openAI]['moonshot']); + expect(getModelMaxTokens('moonshot-v1')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1'], + ); + expect(getModelMaxTokens('moonshot-v1-auto')).toBe( + 
maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-auto'], + ); + expect(getModelMaxTokens('moonshot-v1-8k')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-8k'], + ); + expect(getModelMaxTokens('moonshot-v1-8k-vision')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-8k-vision'], + ); + expect(getModelMaxTokens('moonshot-v1-8k-vision-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-8k-vision-preview'], + ); + expect(getModelMaxTokens('moonshot-v1-32k')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-32k'], + ); + expect(getModelMaxTokens('moonshot-v1-32k-vision')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-32k-vision'], + ); + expect(getModelMaxTokens('moonshot-v1-32k-vision-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-32k-vision-preview'], + ); + expect(getModelMaxTokens('moonshot-v1-128k')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-128k'], + ); + expect(getModelMaxTokens('moonshot-v1-128k-vision')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-128k-vision'], + ); + expect(getModelMaxTokens('moonshot-v1-128k-vision-preview')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-128k-vision-preview'], + ); + }); + + test('should return correct tokens for Bedrock moonshot models', () => { + expect(getModelMaxTokens('moonshot.kimi', EModelEndpoint.bedrock)).toBe( + maxTokensMap[EModelEndpoint.bedrock]['moonshot.kimi'], + ); + expect(getModelMaxTokens('moonshot.kimi-k2', EModelEndpoint.bedrock)).toBe( + maxTokensMap[EModelEndpoint.bedrock]['moonshot.kimi-k2'], + ); + expect(getModelMaxTokens('moonshot.kimi-k2.5', EModelEndpoint.bedrock)).toBe( + maxTokensMap[EModelEndpoint.bedrock]['moonshot.kimi-k2.5'], + ); + expect(getModelMaxTokens('moonshot.kimi-k2-thinking', EModelEndpoint.bedrock)).toBe( + maxTokensMap[EModelEndpoint.bedrock]['moonshot.kimi-k2-thinking'], + ); + expect(getModelMaxTokens('moonshot.kimi-k2-0711', EModelEndpoint.bedrock)).toBe( + 
maxTokensMap[EModelEndpoint.bedrock]['moonshot.kimi-k2-0711'], + ); + }); + + test('should handle Moonshot/Kimi models with provider prefixes', () => { + expect(getModelMaxTokens('openrouter/kimi-k2')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2'], + ); + expect(getModelMaxTokens('openrouter/kimi-k2.5')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2.5'], + ); + expect(getModelMaxTokens('openrouter/kimi-k2-turbo')).toBe( + maxTokensMap[EModelEndpoint.openAI]['kimi-k2-turbo'], + ); + expect(getModelMaxTokens('openrouter/moonshot-v1-128k')).toBe( + maxTokensMap[EModelEndpoint.openAI]['moonshot-v1-128k'], + ); }); }); describe('matchModelName', () => { test('should match exact Kimi model names', () => { expect(matchModelName('kimi')).toBe('kimi'); - expect(matchModelName('kimi-k2')).toBe('kimi'); - expect(matchModelName('kimi-vl')).toBe('kimi'); + expect(matchModelName('kimi-k2')).toBe('kimi-k2'); + expect(matchModelName('kimi-k2.5')).toBe('kimi-k2.5'); + expect(matchModelName('kimi-k2-turbo')).toBe('kimi-k2-turbo'); + expect(matchModelName('kimi-k2-0711')).toBe('kimi-k2-0711'); + }); + + test('should match moonshot model names', () => { + expect(matchModelName('moonshot')).toBe('moonshot'); + expect(matchModelName('moonshot-v1-8k')).toBe('moonshot-v1-8k'); + expect(matchModelName('moonshot-v1-32k')).toBe('moonshot-v1-32k'); + expect(matchModelName('moonshot-v1-128k')).toBe('moonshot-v1-128k'); }); test('should match Kimi model variations with provider prefix', () => { - expect(matchModelName('moonshotai/kimi')).toBe('kimi'); - expect(matchModelName('moonshotai/kimi-k2')).toBe('kimi'); - expect(matchModelName('moonshotai/kimi-vl')).toBe('kimi'); + expect(matchModelName('openrouter/kimi')).toBe('kimi'); + expect(matchModelName('openrouter/kimi-k2')).toBe('kimi-k2'); + expect(matchModelName('openrouter/kimi-k2.5')).toBe('kimi-k2.5'); }); test('should match Kimi model variations with suffixes', () => { - 
expect(matchModelName('kimi-k2-latest')).toBe('kimi'); - expect(matchModelName('kimi-vl-preview')).toBe('kimi'); - expect(matchModelName('kimi-2024')).toBe('kimi'); + expect(matchModelName('kimi-k2-latest')).toBe('kimi-k2'); + expect(matchModelName('kimi-k2.5-preview')).toBe('kimi-k2.5'); }); }); }); @@ -1224,44 +1479,80 @@ describe('Qwen3 Model Tests', () => { describe('GLM Model Tests (Zhipu AI)', () => { describe('getModelMaxTokens', () => { test('should return correct tokens for GLM models', () => { - expect(getModelMaxTokens('glm-4.6')).toBe(200000); - expect(getModelMaxTokens('glm-4.5v')).toBe(66000); - expect(getModelMaxTokens('glm-4.5-air')).toBe(131000); - expect(getModelMaxTokens('glm-4.5')).toBe(131000); - expect(getModelMaxTokens('glm-4-32b')).toBe(128000); - expect(getModelMaxTokens('glm-4')).toBe(128000); - expect(getModelMaxTokens('glm4')).toBe(128000); + expect(getModelMaxTokens('glm-4.6')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4.6']); + expect(getModelMaxTokens('glm-4.5v')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4.5v']); + expect(getModelMaxTokens('glm-4.5-air')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5-air'], + ); + expect(getModelMaxTokens('glm-4.5')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4.5']); + expect(getModelMaxTokens('glm-4-32b')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4-32b']); + expect(getModelMaxTokens('glm-4')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4']); + expect(getModelMaxTokens('glm4')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm4']); }); test('should handle partial matches for GLM models with provider prefixes', () => { - expect(getModelMaxTokens('z-ai/glm-4.6')).toBe(200000); - expect(getModelMaxTokens('z-ai/glm-4.5')).toBe(131000); - expect(getModelMaxTokens('z-ai/glm-4.5-air')).toBe(131000); - expect(getModelMaxTokens('z-ai/glm-4.5v')).toBe(66000); - expect(getModelMaxTokens('z-ai/glm-4-32b')).toBe(128000); + expect(getModelMaxTokens('z-ai/glm-4.6')).toBe( + 
maxTokensMap[EModelEndpoint.openAI]['glm-4.6'], + ); + expect(getModelMaxTokens('z-ai/glm-4.5')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5'], + ); + expect(getModelMaxTokens('z-ai/glm-4.5-air')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5-air'], + ); + expect(getModelMaxTokens('z-ai/glm-4.5v')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5v'], + ); + expect(getModelMaxTokens('z-ai/glm-4-32b')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4-32b'], + ); - expect(getModelMaxTokens('zai/glm-4.6')).toBe(200000); - expect(getModelMaxTokens('zai/glm-4.5')).toBe(131000); - expect(getModelMaxTokens('zai/glm-4.5-air')).toBe(131000); - expect(getModelMaxTokens('zai/glm-4.5v')).toBe(66000); + expect(getModelMaxTokens('zai/glm-4.6')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4.6']); + expect(getModelMaxTokens('zai/glm-4.5')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4.5']); + expect(getModelMaxTokens('zai/glm-4.5-air')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5-air'], + ); + expect(getModelMaxTokens('zai/glm-4.5v')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5v'], + ); - expect(getModelMaxTokens('zai-org/GLM-4.6')).toBe(200000); - expect(getModelMaxTokens('zai-org/GLM-4.5')).toBe(131000); - expect(getModelMaxTokens('zai-org/GLM-4.5-Air')).toBe(131000); - expect(getModelMaxTokens('zai-org/GLM-4.5V')).toBe(66000); - expect(getModelMaxTokens('zai-org/GLM-4-32B-0414')).toBe(128000); + expect(getModelMaxTokens('zai-org/GLM-4.6')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.6'], + ); + expect(getModelMaxTokens('zai-org/GLM-4.5')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5'], + ); + expect(getModelMaxTokens('zai-org/GLM-4.5-Air')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5-air'], + ); + expect(getModelMaxTokens('zai-org/GLM-4.5V')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5v'], + ); + expect(getModelMaxTokens('zai-org/GLM-4-32B-0414')).toBe( + 
maxTokensMap[EModelEndpoint.openAI]['glm-4-32b'], + ); }); test('should handle GLM model variations with suffixes', () => { - expect(getModelMaxTokens('glm-4.6-fp8')).toBe(200000); - expect(getModelMaxTokens('zai-org/GLM-4.6-FP8')).toBe(200000); - expect(getModelMaxTokens('zai-org/GLM-4.5-Air-FP8')).toBe(131000); + expect(getModelMaxTokens('glm-4.6-fp8')).toBe(maxTokensMap[EModelEndpoint.openAI]['glm-4.6']); + expect(getModelMaxTokens('zai-org/GLM-4.6-FP8')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.6'], + ); + expect(getModelMaxTokens('zai-org/GLM-4.5-Air-FP8')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5-air'], + ); }); test('should prioritize more specific GLM patterns', () => { - expect(getModelMaxTokens('glm-4.5-air-custom')).toBe(131000); - expect(getModelMaxTokens('glm-4.5-custom')).toBe(131000); - expect(getModelMaxTokens('glm-4.5v-custom')).toBe(66000); + expect(getModelMaxTokens('glm-4.5-air-custom')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5-air'], + ); + expect(getModelMaxTokens('glm-4.5-custom')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5'], + ); + expect(getModelMaxTokens('glm-4.5v-custom')).toBe( + maxTokensMap[EModelEndpoint.openAI]['glm-4.5v'], + ); }); }); diff --git a/bun.lock b/bun.lock index 783bfc762e..600a640c87 100644 --- a/bun.lock +++ b/bun.lock @@ -254,7 +254,7 @@ }, "packages/api": { "name": "@librechat/api", - "version": "1.7.20", + "version": "1.7.22", "devDependencies": { "@babel/preset-env": "^7.21.5", "@babel/preset-react": "^7.18.6", @@ -321,7 +321,7 @@ }, "packages/client": { "name": "@librechat/client", - "version": "0.4.4", + "version": "0.4.51", "devDependencies": { "@babel/core": "^7.28.5", "@babel/preset-env": "^7.28.5", @@ -409,7 +409,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.8.220", + "version": "0.8.231", "dependencies": { "axios": "^1.12.1", "dayjs": "^1.11.13", @@ -447,7 +447,7 @@ }, "packages/data-schemas": { "name": 
"@librechat/data-schemas", - "version": "0.0.33", + "version": "0.0.35", "devDependencies": { "@rollup/plugin-alias": "^5.1.0", "@rollup/plugin-commonjs": "^29.0.0", diff --git a/client/jest.config.cjs b/client/jest.config.cjs index 9a9f9f5451..53d4063a0a 100644 --- a/client/jest.config.cjs +++ b/client/jest.config.cjs @@ -1,4 +1,4 @@ -/** v0.8.2-rc2 */ +/** v0.8.2 */ module.exports = { roots: ['/src'], testEnvironment: 'jsdom', diff --git a/client/package.json b/client/package.json index 81b2fdf255..f6838f5091 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "@librechat/frontend", - "version": "v0.8.2-rc2", + "version": "v0.8.2", "description": "", "type": "module", "scripts": { @@ -77,10 +77,10 @@ "jotai": "^2.12.5", "js-cookie": "^3.0.5", "librechat-data-provider": "*", - "lodash": "^4.17.21", + "lodash": "^4.17.23", "lucide-react": "^0.394.0", "match-sorter": "^8.1.0", - "mermaid": "^11.12.2", + "mermaid": "^11.12.3", "micromark-extension-llm-math": "^3.1.0", "qrcode.react": "^4.2.0", "rc-input-number": "^7.4.2", @@ -148,7 +148,6 @@ "jest-file-loader": "^1.0.3", "jest-junit": "^16.0.0", "postcss": "^8.4.31", - "postcss-loader": "^7.1.0", "postcss-preset-env": "^8.2.0", "tailwindcss": "^3.4.1", "typescript": "^5.3.3", diff --git a/client/public/assets/maskable-icon.png b/client/public/assets/maskable-icon.png index 90e48f870b..b48524b867 100644 Binary files a/client/public/assets/maskable-icon.png and b/client/public/assets/maskable-icon.png differ diff --git a/client/public/assets/web-browser.svg b/client/public/assets/web-browser.svg deleted file mode 100644 index 3f9c85d14b..0000000000 --- a/client/public/assets/web-browser.svg +++ /dev/null @@ -1,86 +0,0 @@ - - - - diff --git a/client/src/Providers/BadgeRowContext.tsx b/client/src/Providers/BadgeRowContext.tsx index 40df795aba..dce1c38a78 100644 --- a/client/src/Providers/BadgeRowContext.tsx +++ b/client/src/Providers/BadgeRowContext.tsx @@ -1,4 +1,4 @@ -import React, { 
createContext, useContext, useEffect, useRef } from 'react'; +import React, { createContext, useContext, useEffect, useMemo, useRef } from 'react'; import { useSetRecoilState } from 'recoil'; import { Tools, Constants, LocalStorageKeys, AgentCapabilities } from 'librechat-data-provider'; import type { TAgentsEndpoint } from 'librechat-data-provider'; @@ -9,11 +9,13 @@ import { useCodeApiKeyForm, useToolToggle, } from '~/hooks'; -import { getTimestampedValue, setTimestamp } from '~/utils/timestamps'; +import { getTimestampedValue } from '~/utils/timestamps'; +import { useGetStartupConfig } from '~/data-provider'; import { ephemeralAgentByConvoId } from '~/store'; interface BadgeRowContextType { conversationId?: string | null; + storageContextKey?: string; agentsConfig?: TAgentsEndpoint | null; webSearch: ReturnType; artifacts: ReturnType; @@ -38,34 +40,70 @@ interface BadgeRowProviderProps { children: React.ReactNode; isSubmitting?: boolean; conversationId?: string | null; + specName?: string | null; } export default function BadgeRowProvider({ children, isSubmitting, conversationId, + specName, }: BadgeRowProviderProps) { - const lastKeyRef = useRef(''); + const lastContextKeyRef = useRef(''); const hasInitializedRef = useRef(false); const { agentsConfig } = useGetAgentsConfig(); + const { data: startupConfig } = useGetStartupConfig(); const key = conversationId ?? Constants.NEW_CONVO; + const hasModelSpecs = (startupConfig?.modelSpecs?.list?.length ?? 0) > 0; + + /** + * Compute the storage context key for non-spec persistence: + * - `__defaults__`: specs configured but none active → shared defaults key + * - undefined: spec active (no persistence) or no specs configured (original behavior) + * + * When a spec is active, tool/MCP state is NOT persisted — the admin's spec + * configuration is always applied fresh. Only non-spec user preferences persist. 
+ */ + const storageContextKey = useMemo(() => { + if (!specName && hasModelSpecs) { + return Constants.spec_defaults_key as string; + } + return undefined; + }, [specName, hasModelSpecs]); + + /** + * Compute the storage suffix for reading localStorage defaults: + * - New conversations read from environment key (spec or non-spec defaults) + * - Existing conversations read from conversation key (per-conversation state) + */ + const isNewConvo = key === Constants.NEW_CONVO; + const storageSuffix = isNewConvo && storageContextKey ? storageContextKey : key; const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(key)); - /** Initialize ephemeralAgent from localStorage on mount and when conversation changes */ + /** Initialize ephemeralAgent from localStorage on mount and when conversation/spec changes. + * Skipped when a spec is active — applyModelSpecEphemeralAgent handles both new conversations + * (pure spec values) and existing conversations (spec values + localStorage overrides). */ useEffect(() => { if (isSubmitting) { return; } - // Check if this is a new conversation or the first load - if (!hasInitializedRef.current || lastKeyRef.current !== key) { + if (specName) { + // Spec active: applyModelSpecEphemeralAgent handles all state (spec base + localStorage + // overrides for existing conversations). Reset init flag so switching back to non-spec + // triggers a fresh re-init. 
+ hasInitializedRef.current = false; + return; + } + // Check if this is a new conversation/spec or the first load + if (!hasInitializedRef.current || lastContextKeyRef.current !== storageSuffix) { hasInitializedRef.current = true; - lastKeyRef.current = key; + lastContextKeyRef.current = storageSuffix; - const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`; - const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`; - const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${key}`; - const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${key}`; + const codeToggleKey = `${LocalStorageKeys.LAST_CODE_TOGGLE_}${storageSuffix}`; + const webSearchToggleKey = `${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${storageSuffix}`; + const fileSearchToggleKey = `${LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_}${storageSuffix}`; + const artifactsToggleKey = `${LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_}${storageSuffix}`; const codeToggleValue = getTimestampedValue(codeToggleKey); const webSearchToggleValue = getTimestampedValue(webSearchToggleKey); @@ -106,39 +144,53 @@ export default function BadgeRowProvider({ } } - /** - * Always set values for all tools (use defaults if not in `localStorage`) - * If `ephemeralAgent` is `null`, create a new object with just our tool values - */ - const finalValues = { - [Tools.execute_code]: initialValues[Tools.execute_code] ?? false, - [Tools.web_search]: initialValues[Tools.web_search] ?? false, - [Tools.file_search]: initialValues[Tools.file_search] ?? false, - [AgentCapabilities.artifacts]: initialValues[AgentCapabilities.artifacts] ?? 
false, - }; + const hasOverrides = Object.keys(initialValues).length > 0; - setEphemeralAgent((prev) => ({ - ...(prev || {}), - ...finalValues, - })); - - Object.entries(finalValues).forEach(([toolKey, value]) => { - if (value !== false) { - let storageKey = artifactsToggleKey; - if (toolKey === Tools.execute_code) { - storageKey = codeToggleKey; - } else if (toolKey === Tools.web_search) { - storageKey = webSearchToggleKey; - } else if (toolKey === Tools.file_search) { - storageKey = fileSearchToggleKey; + /** Read persisted MCP values from localStorage */ + let mcpOverrides: string[] | null = null; + const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${storageSuffix}`; + const mcpRaw = localStorage.getItem(mcpStorageKey); + if (mcpRaw !== null) { + try { + const parsed = JSON.parse(mcpRaw); + if (Array.isArray(parsed) && parsed.length > 0) { + mcpOverrides = parsed; } - // Store the value and set timestamp for existing values - localStorage.setItem(storageKey, JSON.stringify(value)); - setTimestamp(storageKey); + } catch (e) { + console.error('Failed to parse MCP values:', e); } + } + + setEphemeralAgent((prev) => { + if (prev == null) { + /** ephemeralAgent is null — use localStorage defaults */ + if (hasOverrides || mcpOverrides) { + const result = { ...initialValues }; + if (mcpOverrides) { + result.mcp = mcpOverrides; + } + return result; + } + return prev; + } + /** ephemeralAgent already has values (from prior state). + * Only fill in undefined keys from localStorage. */ + let changed = false; + const result = { ...prev }; + for (const [toolKey, value] of Object.entries(initialValues)) { + if (result[toolKey] === undefined) { + result[toolKey] = value; + changed = true; + } + } + if (mcpOverrides && result.mcp === undefined) { + result.mcp = mcpOverrides; + changed = true; + } + return changed ? 
result : prev; }); } - }, [key, isSubmitting, setEphemeralAgent]); + }, [storageSuffix, specName, isSubmitting, setEphemeralAgent]); /** CodeInterpreter hooks */ const codeApiKeyForm = useCodeApiKeyForm({}); @@ -146,6 +198,7 @@ export default function BadgeRowProvider({ const codeInterpreter = useToolToggle({ conversationId, + storageContextKey, setIsDialogOpen: setCodeDialogOpen, toolKey: Tools.execute_code, localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_, @@ -161,6 +214,7 @@ export default function BadgeRowProvider({ const webSearch = useToolToggle({ conversationId, + storageContextKey, toolKey: Tools.web_search, localStorageKey: LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_, setIsDialogOpen: setWebSearchDialogOpen, @@ -173,6 +227,7 @@ export default function BadgeRowProvider({ /** FileSearch hook */ const fileSearch = useToolToggle({ conversationId, + storageContextKey, toolKey: Tools.file_search, localStorageKey: LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_, isAuthenticated: true, @@ -181,12 +236,13 @@ export default function BadgeRowProvider({ /** Artifacts hook - using a custom key since it's not a Tool but a capability */ const artifacts = useToolToggle({ conversationId, + storageContextKey, toolKey: AgentCapabilities.artifacts, localStorageKey: LocalStorageKeys.LAST_ARTIFACTS_TOGGLE_, isAuthenticated: true, }); - const mcpServerManager = useMCPServerManager({ conversationId }); + const mcpServerManager = useMCPServerManager({ conversationId, storageContextKey }); const value: BadgeRowContextType = { webSearch, @@ -194,6 +250,7 @@ export default function BadgeRowProvider({ fileSearch, agentsConfig, conversationId, + storageContextKey, codeApiKeyForm, codeInterpreter, searchApiKeyForm, diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts index 9ac6b440a3..c3ea06f890 100644 --- a/client/src/common/agents-types.ts +++ b/client/src/common/agents-types.ts @@ -1,6 +1,7 @@ import { AgentCapabilities, ArtifactModes } from 
'librechat-data-provider'; import type { AgentModelParameters, + AgentToolOptions, SupportContact, AgentProvider, GraphEdge, @@ -8,6 +9,8 @@ import type { } from 'librechat-data-provider'; import type { OptionWithIcon, ExtendedFile } from './types'; +export type AgentQueryResult = { found: true; agent: Agent } | { found: false }; + export type TAgentOption = OptionWithIcon & Agent & { knowledge_files?: Array<[string, ExtendedFile]>; @@ -33,6 +36,8 @@ export type AgentForm = { model: string | null; model_parameters: AgentModelParameters; tools?: string[]; + /** Per-tool configuration options (deferred loading, allowed callers, etc.) */ + tool_options?: AgentToolOptions; provider?: AgentProvider | OptionWithIcon; /** @deprecated Use edges instead */ agent_ids?: string[]; diff --git a/client/src/common/menus.ts b/client/src/common/menus.ts index 97c2d1b11b..ee7d7292c9 100644 --- a/client/src/common/menus.ts +++ b/client/src/common/menus.ts @@ -15,6 +15,8 @@ export interface MenuItemProps { separate?: boolean; hideOnClick?: boolean; dialog?: React.ReactElement; + ariaLabel?: string; + ariaChecked?: boolean; ref?: React.Ref; className?: string; render?: diff --git a/client/src/components/Auth/LoginForm.tsx b/client/src/components/Auth/LoginForm.tsx index 0a6e1e8614..c51c2002e3 100644 --- a/client/src/components/Auth/LoginForm.tsx +++ b/client/src/components/Auth/LoginForm.tsx @@ -5,6 +5,7 @@ import { ThemeContext, Spinner, Button, isDark } from '@librechat/client'; import type { TLoginUser, TStartupConfig } from 'librechat-data-provider'; import type { TAuthContext } from '~/common'; import { useResendVerificationEmail, useGetStartupConfig } from '~/data-provider'; +import { validateEmail } from '~/utils'; import { useLocalize } from '~/hooks'; type TLoginFormProps = { @@ -96,10 +97,9 @@ const LoginForm: React.FC = ({ onSubmit, startupConfig, error, {...register('email', { required: localize('com_auth_email_required'), maxLength: { value: 120, message: 
localize('com_auth_email_max_length') }, - pattern: { - value: useUsernameLogin ? /\S+/ : /\S+@\S+\.\S+/, - message: localize('com_auth_email_pattern'), - }, + validate: useUsernameLogin + ? undefined + : (value) => validateEmail(value, localize('com_auth_email_pattern')), })} aria-invalid={!!errors.email} className="webkit-dark-styles transition-color peer w-full rounded-2xl border border-border-light bg-surface-primary px-3.5 pb-2.5 pt-3 text-text-primary duration-200 focus:border-green-500 focus:outline-none" diff --git a/client/src/components/Chat/Footer.tsx b/client/src/components/Chat/Footer.tsx index 72aa04be57..75dd853c4f 100644 --- a/client/src/components/Chat/Footer.tsx +++ b/client/src/components/Chat/Footer.tsx @@ -13,30 +13,14 @@ export default function Footer({ className }: { className?: string }) { const termsOfService = config?.interface?.termsOfService; const privacyPolicyRender = privacyPolicy?.externalUrl != null && ( - + {localize('com_ui_privacy_policy')} - {privacyPolicy.openNewTab === true && ( - {' ' + localize('com_ui_opens_new_tab')} - )} ); const termsOfServiceRender = termsOfService?.externalUrl != null && ( - + {localize('com_ui_terms_of_service')} - {termsOfService.openNewTab === true && ( - {' ' + localize('com_ui_opens_new_tab')} - )} ); @@ -67,12 +51,10 @@ export default function Footer({ className }: { className?: string }) { {children} - {' ' + localize('com_ui_opens_new_tab')} ); }, diff --git a/client/src/components/Chat/Input/BadgeRow.tsx b/client/src/components/Chat/Input/BadgeRow.tsx index 5036dcd5e4..6fea6b0d58 100644 --- a/client/src/components/Chat/Input/BadgeRow.tsx +++ b/client/src/components/Chat/Input/BadgeRow.tsx @@ -28,6 +28,7 @@ interface BadgeRowProps { onChange: (badges: Pick[]) => void; onToggle?: (badgeId: string, currentActive: boolean) => void; conversationId?: string | null; + specName?: string | null; isSubmitting?: boolean; isInChat: boolean; } @@ -142,6 +143,7 @@ const dragReducer = (state: DragState, 
action: DragAction): DragState => { function BadgeRow({ showEphemeralBadges, conversationId, + specName, isSubmitting, onChange, onToggle, @@ -320,7 +322,11 @@ function BadgeRow({ }, [dragState.draggedBadge, handleMouseMove, handleMouseUp]); return ( - +
{showEphemeralBadges === true && } {tempBadges.map((badge, index) => ( diff --git a/client/src/components/Chat/Input/ChatForm.tsx b/client/src/components/Chat/Input/ChatForm.tsx index cb1e30a09d..45277e5b9c 100644 --- a/client/src/components/Chat/Input/ChatForm.tsx +++ b/client/src/components/Chat/Input/ChatForm.tsx @@ -258,7 +258,17 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { {endpoint && (
-
+
{ @@ -290,16 +300,6 @@ const ChatForm = memo(({ index = 0 }: { index?: number }) => { 'scrollbar-hover transition-[max-height] duration-200 disabled:cursor-not-allowed', )} /> - {isCollapsed && ( -
- )}
{ } isSubmitting={isSubmitting} conversationId={conversationId} + specName={conversation?.spec} onChange={setBadges} isInChat={ Array.isArray(conversation?.messages) && conversation.messages.length >= 1 diff --git a/client/src/components/Chat/Input/Files/ImagePreview.tsx b/client/src/components/Chat/Input/Files/ImagePreview.tsx index c675c9326c..2714c3677f 100644 --- a/client/src/components/Chat/Input/Files/ImagePreview.tsx +++ b/client/src/components/Chat/Input/Files/ImagePreview.tsx @@ -158,11 +158,11 @@ const ImagePreview = ({ { e.preventDefault(); closeButtonRef.current?.focus(); diff --git a/client/src/components/Chat/Input/MCPSelect.tsx b/client/src/components/Chat/Input/MCPSelect.tsx index 278e603db0..a5356f5094 100644 --- a/client/src/components/Chat/Input/MCPSelect.tsx +++ b/client/src/components/Chat/Input/MCPSelect.tsx @@ -11,7 +11,7 @@ import { useHasAccess } from '~/hooks'; import { cn } from '~/utils'; function MCPSelectContent() { - const { conversationId, mcpServerManager } = useBadgeRowContext(); + const { conversationId, storageContextKey, mcpServerManager } = useBadgeRowContext(); const { localize, isPinned, @@ -128,7 +128,11 @@ function MCPSelectContent() { {configDialogProps && ( - + )} ); diff --git a/client/src/components/Chat/Input/MCPSubMenu.tsx b/client/src/components/Chat/Input/MCPSubMenu.tsx index ca547ca1f7..b0b8fad1bb 100644 --- a/client/src/components/Chat/Input/MCPSubMenu.tsx +++ b/client/src/components/Chat/Input/MCPSubMenu.tsx @@ -15,7 +15,7 @@ interface MCPSubMenuProps { const MCPSubMenu = React.forwardRef( ({ placeholder, ...props }, ref) => { const localize = useLocalize(); - const { mcpServerManager } = useBadgeRowContext(); + const { storageContextKey, mcpServerManager } = useBadgeRowContext(); const { isPinned, mcpValues, @@ -106,7 +106,9 @@ const MCPSubMenu = React.forwardRef(
- {configDialogProps && } + {configDialogProps && ( + + )}
); }, diff --git a/client/src/components/Chat/Input/Mention.tsx b/client/src/components/Chat/Input/Mention.tsx index 2defcc7623..9e56068def 100644 --- a/client/src/components/Chat/Input/Mention.tsx +++ b/client/src/components/Chat/Input/Mention.tsx @@ -12,7 +12,7 @@ import useMentions from '~/hooks/Input/useMentions'; import { removeCharIfLast } from '~/utils'; import MentionItem from './MentionItem'; -const ROW_HEIGHT = 40; +const ROW_HEIGHT = 44; export default function Mention({ conversation, diff --git a/client/src/components/Chat/Input/MentionItem.tsx b/client/src/components/Chat/Input/MentionItem.tsx index fcfb22c312..6c978240ee 100644 --- a/client/src/components/Chat/Input/MentionItem.tsx +++ b/client/src/components/Chat/Input/MentionItem.tsx @@ -25,15 +25,16 @@ export default function MentionItem({ }: MentionItemProps) { return (
diff --git a/client/src/components/Chat/Menus/Endpoints/components/EndpointModelItem.tsx b/client/src/components/Chat/Menus/Endpoints/components/EndpointModelItem.tsx index cb9d24eb61..752788d63a 100644 --- a/client/src/components/Chat/Menus/Endpoints/components/EndpointModelItem.tsx +++ b/client/src/components/Chat/Menus/Endpoints/components/EndpointModelItem.tsx @@ -1,5 +1,6 @@ import React, { useRef, useState, useEffect } from 'react'; -import { EarthIcon, Pin, PinOff } from 'lucide-react'; +import { VisuallyHidden } from '@ariakit/react'; +import { CheckCircle2, EarthIcon, Pin, PinOff } from 'lucide-react'; import { isAgentsEndpoint, isAssistantsEndpoint } from 'librechat-data-provider'; import { useModelSelectorContext } from '../ModelSelectorContext'; import { CustomMenuItem as MenuItem } from '../CustomMenu'; @@ -110,6 +111,7 @@ export function EndpointModelItem({ modelId, endpoint, isSelected }: EndpointMod handleSelectModel(endpoint, modelId ?? '')} + aria-selected={isSelected || undefined} className="group flex w-full cursor-pointer items-center justify-between rounded-lg px-2 text-sm" >
@@ -133,23 +135,10 @@ export function EndpointModelItem({ modelId, endpoint, isSelected }: EndpointMod )} {isSelected && ( -
- - - -
+ <> +
{isSelected && ( -
- - - -
+ <> +
{selectedSpec === spec.name && ( -
- - - -
+ <> +