From 8f4705f683efd13b5c77417e0586bdfef8d741a6 Mon Sep 17 00:00:00 2001 From: "Theo N. Truong" <644650+nhtruong@users.noreply.github.com> Date: Thu, 30 Oct 2025 15:08:04 -0600 Subject: [PATCH 001/207] =?UTF-8?q?=F0=9F=91=91=20feat:=20Distributed=20Le?= =?UTF-8?q?ader=20Election=20with=20Redis=20for=20Multi-instance=20Coordin?= =?UTF-8?q?ation=20(#10189)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🔧 refactor: Move GLOBAL_PREFIX_SEPARATOR to cacheConfig for consistency * 👑 feat: Implement distributed leader election using Redis --- .env.example | 10 + .github/workflows/cache-integration-tests.yml | 13 +- packages/api/package.json | 7 +- ...=> limiterCache.cache_integration.spec.ts} | 0 ...=> sessionCache.cache_integration.spec.ts} | 0 ...> standardCache.cache_integration.spec.ts} | 11 +- ... violationCache.cache_integration.spec.ts} | 0 ...=> redisClients.cache_integration.spec.ts} | 0 packages/api/src/cache/cacheConfig.ts | 1 + packages/api/src/cache/cacheFactory.ts | 4 +- packages/api/src/cache/redisClients.ts | 6 +- packages/api/src/cluster/LeaderElection.ts | 180 ++++++++++++++ .../LeaderElection.cache_integration.spec.ts | 220 ++++++++++++++++++ packages/api/src/cluster/config.ts | 14 ++ packages/api/src/cluster/index.ts | 1 + 15 files changed, 452 insertions(+), 15 deletions(-) rename packages/api/src/cache/__tests__/cacheFactory/{limiterCache.integration.spec.ts => limiterCache.cache_integration.spec.ts} (100%) rename packages/api/src/cache/__tests__/cacheFactory/{sessionCache.integration.spec.ts => sessionCache.cache_integration.spec.ts} (100%) rename packages/api/src/cache/__tests__/cacheFactory/{standardCache.integration.spec.ts => standardCache.cache_integration.spec.ts} (96%) rename packages/api/src/cache/__tests__/cacheFactory/{violationCache.integration.spec.ts => violationCache.cache_integration.spec.ts} (100%) rename packages/api/src/cache/__tests__/{redisClients.integration.spec.ts => redisClients.cache_integration.spec.ts} (100%) create mode 100644 packages/api/src/cluster/LeaderElection.ts create mode 100644 packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts create mode 100644 packages/api/src/cluster/config.ts create mode 100644 packages/api/src/cluster/index.ts diff --git a/.env.example b/.env.example index f1666fb763..24c74487aa 100644 --- a/.env.example +++ b/.env.example @@ -702,6 +702,16 @@ HELP_AND_FAQ_URL=https://librechat.ai # Comma-separated list of CacheKeys (e.g., ROLES,MESSAGES) # FORCED_IN_MEMORY_CACHE_NAMESPACES=ROLES,MESSAGES +# Leader Election Configuration (for multi-instance deployments with Redis) +# Duration in seconds that the leader lease is valid before it expires (default: 25) +# LEADER_LEASE_DURATION=25 +# Interval in seconds at which the leader renews its lease (default: 10) +# LEADER_RENEW_INTERVAL=10 +# Maximum number of retry attempts when renewing the lease fails (default: 3) +# LEADER_RENEW_ATTEMPTS=3 +# Delay in seconds between retry attempts when renewing the lease (default: 0.5) +# LEADER_RENEW_RETRY_DELAY=0.5 + #==================================================# # Others # #==================================================# diff --git a/.github/workflows/cache-integration-tests.yml b/.github/workflows/cache-integration-tests.yml index 2afe68287e..f7ac638282 100644 --- a/.github/workflows/cache-integration-tests.yml +++ b/.github/workflows/cache-integration-tests.yml @@ -8,12 +8,13 @@ on: - release/* paths: - 'packages/api/src/cache/**' + - 'packages/api/src/cluster/**' - 'redis-config/**' - '.github/workflows/cache-integration-tests.yml' jobs: cache_integration_tests: - name: Run Cache Integration Tests + name: Integration Tests that use actual Redis Cache timeout-minutes: 30 runs-on: ubuntu-latest @@ -66,7 +67,15 @@ jobs: USE_REDIS: true REDIS_URI: redis://127.0.0.1:6379 REDIS_CLUSTER_URI: redis://127.0.0.1:7001,redis://127.0.0.1:7002,redis://127.0.0.1:7003 - run: npm run test:cache:integration + run: npm run test:cache-integration:core + + - name: Run cluster integration tests + working-directory: packages/api + env: + NODE_ENV: test + USE_REDIS: true + REDIS_URI: redis://127.0.0.1:6379 + run: npm run test:cache-integration:cluster - name: Stop Redis Cluster if: always() diff --git a/packages/api/package.json b/packages/api/package.json index 7db4a54a21..4d333082a3 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -18,9 +18,10 @@ "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs", "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs", "build:watch:prod": "rollup -c -w --bundleConfigAsCjs", - "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.integration\\.\"", - "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.integration\\.\"", - "test:cache:integration": "jest --testPathPattern=\"src/cache/.*\\.integration\\.spec\\.ts$\" --coverage=false", + "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.\"", + "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.\"", + "test:cache-integration:core": "jest --testPathPattern=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", + "test:cache-integration:cluster": "jest --testPathPattern=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand", "verify": "npm run test:ci", "b:clean": "bun run rimraf dist", "b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs", diff --git a/packages/api/src/cache/__tests__/cacheFactory/limiterCache.integration.spec.ts b/packages/api/src/cache/__tests__/cacheFactory/limiterCache.cache_integration.spec.ts similarity index 100% rename from packages/api/src/cache/__tests__/cacheFactory/limiterCache.integration.spec.ts rename to packages/api/src/cache/__tests__/cacheFactory/limiterCache.cache_integration.spec.ts diff --git a/packages/api/src/cache/__tests__/cacheFactory/sessionCache.integration.spec.ts b/packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts similarity index 100% rename from packages/api/src/cache/__tests__/cacheFactory/sessionCache.integration.spec.ts rename to packages/api/src/cache/__tests__/cacheFactory/sessionCache.cache_integration.spec.ts diff --git a/packages/api/src/cache/__tests__/cacheFactory/standardCache.integration.spec.ts b/packages/api/src/cache/__tests__/cacheFactory/standardCache.cache_integration.spec.ts similarity index 96% rename from packages/api/src/cache/__tests__/cacheFactory/standardCache.integration.spec.ts rename to packages/api/src/cache/__tests__/cacheFactory/standardCache.cache_integration.spec.ts index db40ad636c..b5fcc207da 100644 --- a/packages/api/src/cache/__tests__/cacheFactory/standardCache.integration.spec.ts +++ b/packages/api/src/cache/__tests__/cacheFactory/standardCache.cache_integration.spec.ts @@ -1,11 +1,14 @@ import type { Keyv } from 'keyv'; -// Mock GLOBAL_PREFIX_SEPARATOR -jest.mock('../../redisClients', () => { - const originalModule = jest.requireActual('../../redisClients'); +// Mock GLOBAL_PREFIX_SEPARATOR from cacheConfig +jest.mock('../../cacheConfig', () => { + const originalModule = jest.requireActual('../../cacheConfig'); return { ...originalModule, - GLOBAL_PREFIX_SEPARATOR: '>>', + cacheConfig: { + ...originalModule.cacheConfig, + GLOBAL_PREFIX_SEPARATOR: '>>', + }, }; }); diff --git a/packages/api/src/cache/__tests__/cacheFactory/violationCache.integration.spec.ts b/packages/api/src/cache/__tests__/cacheFactory/violationCache.cache_integration.spec.ts similarity index 100% rename from packages/api/src/cache/__tests__/cacheFactory/violationCache.integration.spec.ts rename to packages/api/src/cache/__tests__/cacheFactory/violationCache.cache_integration.spec.ts diff --git a/packages/api/src/cache/__tests__/redisClients.integration.spec.ts b/packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts similarity index 100% rename from packages/api/src/cache/__tests__/redisClients.integration.spec.ts rename to packages/api/src/cache/__tests__/redisClients.cache_integration.spec.ts diff --git a/packages/api/src/cache/cacheConfig.ts b/packages/api/src/cache/cacheConfig.ts index aebfeef3bd..3e5c1646d2 100644 --- a/packages/api/src/cache/cacheConfig.ts +++ b/packages/api/src/cache/cacheConfig.ts @@ -65,6 +65,7 @@ const cacheConfig = { REDIS_PASSWORD: process.env.REDIS_PASSWORD, REDIS_CA: getRedisCA(), REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR ?? ''] || REDIS_KEY_PREFIX || '', + GLOBAL_PREFIX_SEPARATOR: '::', REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40), REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0), /** Max delay between reconnection attempts in ms */ diff --git a/packages/api/src/cache/cacheFactory.ts b/packages/api/src/cache/cacheFactory.ts index 427b1b38ad..d2244a662a 100644 --- a/packages/api/src/cache/cacheFactory.ts +++ b/packages/api/src/cache/cacheFactory.ts @@ -14,7 +14,7 @@ import { logger } from '@librechat/data-schemas'; import session, { MemoryStore } from 'express-session'; import { RedisStore as ConnectRedis } from 'connect-redis'; import type { SendCommandFn } from 'rate-limit-redis'; -import { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } from './redisClients'; +import { keyvRedisClient, ioredisClient } from './redisClients'; import { cacheConfig } from './cacheConfig'; import { violationFile } from './keyvFiles'; @@ -31,7 +31,7 @@ export const standardCache = (namespace: string, ttl?: number, fallbackStore?: o const keyvRedis = new KeyvRedis(keyvRedisClient); const cache = new Keyv(keyvRedis, { namespace, ttl }); keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX; - keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR; + keyvRedis.keyPrefixSeparator = cacheConfig.GLOBAL_PREFIX_SEPARATOR; cache.on('error', (err) => { logger.error(`Cache error in namespace ${namespace}:`, err); diff --git a/packages/api/src/cache/redisClients.ts b/packages/api/src/cache/redisClients.ts index 6c11c1a0a8..6f0e27d772 100644 --- a/packages/api/src/cache/redisClients.ts +++ b/packages/api/src/cache/redisClients.ts @@ -5,8 +5,6 @@ import { createClient, createCluster } from '@keyv/redis'; import type { RedisClientType, RedisClusterType } from '@redis/client'; import { cacheConfig } from './cacheConfig'; -const GLOBAL_PREFIX_SEPARATOR = '::'; - const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri)) || []; const username = urls?.[0]?.username || cacheConfig.REDIS_USERNAME; const password = urls?.[0]?.password || cacheConfig.REDIS_PASSWORD; @@ -18,7 +16,7 @@ if (cacheConfig.USE_REDIS) { username: username, password: password, tls: ca ? { ca } : undefined, - keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`, + keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${cacheConfig.GLOBAL_PREFIX_SEPARATOR}`, maxListeners: cacheConfig.REDIS_MAX_LISTENERS, retryStrategy: (times: number) => { if ( @@ -192,4 +190,4 @@ if (cacheConfig.USE_REDIS) { }); } -export { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR }; +export { ioredisClient, keyvRedisClient }; diff --git a/packages/api/src/cluster/LeaderElection.ts b/packages/api/src/cluster/LeaderElection.ts new file mode 100644 index 0000000000..726c56b185 --- /dev/null +++ b/packages/api/src/cluster/LeaderElection.ts @@ -0,0 +1,180 @@ +import { keyvRedisClient } from '~/cache/redisClients'; +import { cacheConfig as cache } from '~/cache/cacheConfig'; +import { clusterConfig as cluster } from './config'; +import { logger } from '@librechat/data-schemas'; + +/** + * Distributed leader election implementation using Redis for coordination across multiple server instances. + * + * Leadership election: + * - During bootup, every server attempts to become the leader by calling isLeader() + * - Uses atomic Redis SET NX (set if not exists) to ensure only ONE server can claim leadership + * - The first server to successfully set the key becomes the leader; others become followers + * - Works with any number of servers (1 to infinite) - single server always becomes leader + * + * Leadership maintenance: + * - Leader holds a key in Redis with a 25-second lease duration + * - Leader renews this lease every 10 seconds to maintain leadership + * - If leader crashes, the lease eventually expires, and the key disappears + * - On shutdown, leader deletes its key to allow immediate re-election + * - Followers check for leadership and attempt to claim it when the key is empty + */ +export class LeaderElection { + // We can't use Keyv namespace here because we need direct Redis access for atomic operations + static readonly LEADER_KEY = `${cache.REDIS_KEY_PREFIX}${cache.GLOBAL_PREFIX_SEPARATOR}LeadingServerUUID`; + private static _instance = new LeaderElection(); + + readonly UUID: string = crypto.randomUUID(); + private refreshTimer: NodeJS.Timeout | undefined = undefined; + + // DO NOT create new instances of this class directly. + // Use the exported isLeader() function which uses a singleton instance. + constructor() { + if (LeaderElection._instance) return LeaderElection._instance; + + process.on('SIGTERM', () => this.resign()); + process.on('SIGINT', () => this.resign()); + LeaderElection._instance = this; + } + + /** + * Checks if this instance is the current leader. + * If no leader exists, waits upto 2 seconds (randomized to avoid thundering herd) then attempts self-election. + * Always returns true in non-Redis mode (single-instance deployment). + */ + public async isLeader(): Promise { + if (!cache.USE_REDIS) return true; + + try { + const currentLeader = await LeaderElection.getLeaderUUID(); + // If we own the leadership lock, return true. + // However, in case the leadership refresh retries have been exhausted, something has gone wrong. + // This server is not considered the leader anymore, similar to a crash, to avoid split-brain scenario. + if (currentLeader === this.UUID) return this.refreshTimer != null; + if (currentLeader != null) return false; // someone holds leadership lock + + const delay = Math.random() * 2000; + await new Promise((resolve) => setTimeout(resolve, delay)); + return await this.electSelf(); + } catch (error) { + logger.error('Failed to check leadership status:', error); + return false; + } + } + + /** + * Steps down from leadership by stopping the refresh timer and releasing the leader key. + * Atomically deletes the leader key (only if we still own it) so another server can become leader immediately. + */ + public async resign(): Promise { + if (!cache.USE_REDIS) return; + + try { + this.clearRefreshTimer(); + + // Lua script for atomic check-and-delete (only delete if we still own it) + const script = ` + if redis.call("get", KEYS[1]) == ARGV[1] then + redis.call("del", KEYS[1]) + end + `; + + await keyvRedisClient!.eval(script, { + keys: [LeaderElection.LEADER_KEY], + arguments: [this.UUID], + }); + } catch (error) { + logger.error('Failed to release leadership lock:', error); + } + } + + /** + * Gets the UUID of the current leader from Redis. + * Returns null if no leader exists or in non-Redis mode. + * Useful for testing and observability. + */ + public static async getLeaderUUID(): Promise { + if (!cache.USE_REDIS) return null; + return await keyvRedisClient!.get(LeaderElection.LEADER_KEY); + } + + /** + * Clears the refresh timer to stop leadership maintenance. + * Called when resigning or failing to refresh leadership. + * Calling this directly to simulate a crash in testing. + */ + public clearRefreshTimer(): void { + clearInterval(this.refreshTimer); + this.refreshTimer = undefined; + } + + /** + * Attempts to claim leadership using atomic Redis SET NX (set if not exists). + * If successful, starts a refresh timer to maintain leadership by extending the lease duration. + * The NX flag ensures only one server can become leader even if multiple attempt simultaneously. + */ + private async electSelf(): Promise { + try { + const result = await keyvRedisClient!.set(LeaderElection.LEADER_KEY, this.UUID, { + NX: true, + EX: cluster.LEADER_LEASE_DURATION, + }); + + if (result !== 'OK') return false; + + this.clearRefreshTimer(); + this.refreshTimer = setInterval(async () => { + await this.renewLeadership(); + }, cluster.LEADER_RENEW_INTERVAL * 1000); + this.refreshTimer.unref(); + + return true; + } catch (error) { + logger.error('Leader election failed:', error); + return false; + } + } + + /** + * Renews leadership by extending the lease duration on the leader key. + * Uses Lua script to atomically verify we still own the key before renewing (prevents race conditions). + * If we've lost leadership (key was taken by another server), stops the refresh timer. + * This is called every 10 seconds by the refresh timer. + */ + private async renewLeadership(attempts: number = 1): Promise { + try { + // Lua script for atomic check-and-renew + const script = ` + if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) + else + return 0 + end + `; + + const result = await keyvRedisClient!.eval(script, { + keys: [LeaderElection.LEADER_KEY], + arguments: [this.UUID, cluster.LEADER_LEASE_DURATION.toString()], + }); + + if (result === 0) { + logger.warn('Lost leadership, clearing refresh timer'); + this.clearRefreshTimer(); + } + } catch (error) { + logger.error(`Failed to renew leadership (attempts No.${attempts}):`, error); + if (attempts <= cluster.LEADER_RENEW_ATTEMPTS) { + await new Promise((resolve) => + setTimeout(resolve, cluster.LEADER_RENEW_RETRY_DELAY * 1000), + ); + await this.renewLeadership(attempts + 1); + } else { + logger.error('Exceeded maximum attempts to renew leadership.'); + this.clearRefreshTimer(); + } + } + } +} + +const defaultElection = new LeaderElection(); +export const isLeader = (): Promise => defaultElection.isLeader(); diff --git a/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts new file mode 100644 index 0000000000..60bc1b439b --- /dev/null +++ b/packages/api/src/cluster/__tests__/LeaderElection.cache_integration.spec.ts @@ -0,0 +1,220 @@ +import { expect } from '@playwright/test'; + +describe('LeaderElection with Redis', () => { + let LeaderElection: typeof import('../LeaderElection').LeaderElection; + let instances: InstanceType[] = []; + let keyvRedisClient: Awaited['keyvRedisClient']; + let ioredisClient: Awaited['ioredisClient']; + + beforeAll(async () => { + // Set up environment variables for Redis + process.env.USE_REDIS = 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = 'LeaderElection-IntegrationTest'; + + // Import modules after setting env vars + const leaderElectionModule = await import('../LeaderElection'); + const redisClients = await import('~/cache/redisClients'); + + LeaderElection = leaderElectionModule.LeaderElection; + keyvRedisClient = redisClients.keyvRedisClient; + ioredisClient = redisClients.ioredisClient; + + // Ensure Redis is connected + if (!keyvRedisClient) { + throw new Error('Redis client is not initialized'); + } + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) { + await keyvRedisClient.connect(); + } + + // Increase max listeners to handle many instances in tests + process.setMaxListeners(200); + }); + + afterEach(async () => { + await Promise.all(instances.map((instance) => instance.resign())); + instances = []; + + // Clean up: clear the leader key directly from Redis + if (keyvRedisClient) { + await keyvRedisClient.del(LeaderElection.LEADER_KEY); + } + }); + + afterAll(async () => { + // Close both Redis clients to prevent hanging + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + if (ioredisClient?.status === 'ready') await ioredisClient.quit(); + }); + + describe('Test Case 1: Simulate shutdown of the leader', () => { + it('should elect a new leader after the current leader resigns', async () => { + // Create 100 instances + instances = Array.from({ length: 100 }, () => new LeaderElection()); + + // Call isLeader on all instances and get leadership status + const resultsWithInstances = await Promise.all( + instances.map(async (instance) => ({ + instance, + isLeader: await instance.isLeader(), + })), + ); + + // Find leader and followers + const leaders = resultsWithInstances.filter((r) => r.isLeader); + const followers = resultsWithInstances.filter((r) => !r.isLeader); + const leader = leaders[0].instance; + const nextLeader = followers[0].instance; + + // Verify only one is leader + expect(leaders.length).toBe(1); + + // Verify getLeaderUUID matches the leader's UUID + expect(await LeaderElection.getLeaderUUID()).toBe(leader.UUID); + + // Leader resigns + await leader.resign(); + + // Verify getLeaderUUID returns null after resignation + expect(await LeaderElection.getLeaderUUID()).toBeNull(); + + // Next instance to call isLeader should become the new leader + expect(await nextLeader.isLeader()).toBe(true); + }, 30000); // 30 second timeout for 100 instances + }); + + describe('Test Case 2: Simulate crash of the leader', () => { + it('should allow re-election after leader crashes (lease expires)', async () => { + // Mock config with short lease duration + const clusterConfigModule = await import('../config'); + const originalConfig = { ...clusterConfigModule.clusterConfig }; + + // Override config values for this test + Object.assign(clusterConfigModule.clusterConfig, { + LEADER_LEASE_DURATION: 2, + LEADER_RENEW_INTERVAL: 4, + }); + + try { + // Create 1 instance with mocked config + const instance = new LeaderElection(); + instances.push(instance); + + // Become leader + expect(await instance.isLeader()).toBe(true); + + // Verify leader UUID is set + expect(await LeaderElection.getLeaderUUID()).toBe(instance.UUID); + + // Simulate crash by clearing refresh timer + instance.clearRefreshTimer(); + + // The instance no longer considers itself leader even though it still holds the key + expect(await LeaderElection.getLeaderUUID()).toBe(instance.UUID); + expect(await instance.isLeader()).toBe(false); + + // Wait for lease to expire (3 seconds > 2 second lease) + await new Promise((resolve) => setTimeout(resolve, 3000)); + + // Verify leader UUID is null after lease expiration + expect(await LeaderElection.getLeaderUUID()).toBeNull(); + } finally { + // Restore original config values + Object.assign(clusterConfigModule.clusterConfig, originalConfig); + } + }, 15000); // 15 second timeout + }); + + describe('Test Case 3: Stress testing', () => { + it('should ensure only one instance becomes leader even when multiple instances call electSelf() at once', async () => { + // Create 10 instances + instances = Array.from({ length: 10 }, () => new LeaderElection()); + + // Call electSelf on all instances in parallel + const results = await Promise.all(instances.map((instance) => instance['electSelf']())); + + // Verify only one returned true + const successCount = results.filter((success) => success).length; + expect(successCount).toBe(1); + + // Find the winning instance + const winnerInstance = instances.find((_, index) => results[index]); + + // Verify getLeaderUUID matches the winner's UUID + expect(await LeaderElection.getLeaderUUID()).toBe(winnerInstance?.UUID); + }, 15000); // 15 second timeout + }); +}); + +describe('LeaderElection without Redis', () => { + let LeaderElection: typeof import('../LeaderElection').LeaderElection; + let instances: InstanceType[] = []; + + beforeAll(async () => { + // Set up environment variables for non-Redis mode + process.env.USE_REDIS = 'false'; + + // Reset all modules to force re-evaluation with new env vars + jest.resetModules(); + + // Import modules after setting env vars and resetting modules + const leaderElectionModule = await import('../LeaderElection'); + LeaderElection = leaderElectionModule.LeaderElection; + }); + + afterEach(async () => { + await Promise.all(instances.map((instance) => instance.resign())); + instances = []; + }); + + afterAll(() => { + // Restore environment variables + process.env.USE_REDIS = 'true'; + + // Reset all modules to ensure next test runs get fresh imports + jest.resetModules(); + }); + + it('should allow all instances to be leaders when USE_REDIS is false', async () => { + // Create 10 instances + instances = Array.from({ length: 10 }, () => new LeaderElection()); + + // Call isLeader on all instances + const results = await Promise.all(instances.map((instance) => instance.isLeader())); + + // Verify all instances report themselves as leaders + expect(results.every((isLeader) => isLeader)).toBe(true); + expect(results.filter((isLeader) => isLeader).length).toBe(10); + }); + + it('should return null for getLeaderUUID when USE_REDIS is false', async () => { + // Create a few instances + instances = Array.from({ length: 3 }, () => new LeaderElection()); + + // Call isLeader on all instances to make them "leaders" + await Promise.all(instances.map((instance) => instance.isLeader())); + + // Verify getLeaderUUID returns null in non-Redis mode + expect(await LeaderElection.getLeaderUUID()).toBeNull(); + }); + + it('should allow resign() to be called without throwing errors', async () => { + // Create multiple instances + instances = Array.from({ length: 5 }, () => new LeaderElection()); + + // Make them all leaders + await Promise.all(instances.map((instance) => instance.isLeader())); + + // Call resign on all instances - should not throw + await expect( + Promise.all(instances.map((instance) => instance.resign())), + ).resolves.not.toThrow(); + + // Verify they're still leaders after resigning (since there's no shared state) + const results = await Promise.all(instances.map((instance) => instance.isLeader())); + expect(results.every((isLeader) => isLeader)).toBe(true); + }); +}); diff --git a/packages/api/src/cluster/config.ts b/packages/api/src/cluster/config.ts new file mode 100644 index 0000000000..d4a99b3217 --- /dev/null +++ b/packages/api/src/cluster/config.ts @@ -0,0 +1,14 @@ +import { math } from '~/utils'; + +const clusterConfig = { + /** Duration in seconds that the leader lease is valid before it expires */ + LEADER_LEASE_DURATION: math(process.env.LEADER_LEASE_DURATION, 25), + /** Interval in seconds at which the leader renews its lease */ + LEADER_RENEW_INTERVAL: math(process.env.LEADER_RENEW_INTERVAL, 10), + /** Maximum number of retry attempts when renewing the lease fails */ + LEADER_RENEW_ATTEMPTS: math(process.env.LEADER_RENEW_ATTEMPTS, 3), + /** Delay in seconds between retry attempts when renewing the lease */ + LEADER_RENEW_RETRY_DELAY: math(process.env.LEADER_RENEW_RETRY_DELAY, 0.5), +}; + +export { clusterConfig }; diff --git a/packages/api/src/cluster/index.ts b/packages/api/src/cluster/index.ts new file mode 100644 index 0000000000..71925a87ef --- /dev/null +++ b/packages/api/src/cluster/index.ts @@ -0,0 +1 @@ +export { isLeader } from './LeaderElection'; From ea45d0b9c6ced68cfe5e76d59b0c61bee9f36bd7 Mon Sep 17 00:00:00 2001 From: Federico Ruggi Date: Thu, 30 Oct 2025 22:09:56 +0100 Subject: [PATCH 002/207] =?UTF-8?q?=F0=9F=8F=B7=EF=B8=8F=20fix:=20Add=20us?= =?UTF-8?q?er=20ID=20to=20MCP=20tools=20cache=20keys=20(#10201)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add user id to mcp tools cache key * tests * clean up redundant tests * remove unused imports --- api/app/clients/tools/util/handleTools.js | 2 +- api/models/Agent.js | 3 +- api/models/Agent.spec.js | 6 ++-- api/server/controllers/mcp.js | 4 +-- api/server/routes/__tests__/mcp.spec.js | 1 + api/server/routes/mcp.js | 1 + .../Config/__tests__/getCachedTools.spec.js | 10 ++++++ api/server/services/Config/getCachedTools.js | 34 +++++++++++-------- api/server/services/Config/mcp.js | 20 ++++++----- api/server/services/Tools/mcp.js | 1 + 10 files changed, 52 insertions(+), 30 deletions(-) create mode 100644 api/server/services/Config/__tests__/getCachedTools.spec.js diff --git a/api/app/clients/tools/util/handleTools.js b/api/app/clients/tools/util/handleTools.js index 698014cbe0..e32ca6bc44 100644 --- a/api/app/clients/tools/util/handleTools.js +++ b/api/app/clients/tools/util/handleTools.js @@ -448,7 +448,7 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })} } if (!availableTools) { try { - availableTools = await getMCPServerTools(serverName); + availableTools = await getMCPServerTools(safeUser.id, serverName); } catch (error) { logger.error(`Error fetching available tools for MCP server ${serverName}:`, error); } diff --git a/api/models/Agent.js b/api/models/Agent.js index f5f740ba7b..b802ca187b 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -79,6 +79,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet /** @type {TEphemeralAgent | null} */ const ephemeralAgent = req.body.ephemeralAgent; const mcpServers = new Set(ephemeralAgent?.mcp); + const userId = req.user?.id; // note: userId cannot be undefined at runtime if (modelSpec?.mcpServers) { for (const mcpServer of modelSpec.mcpServers) { mcpServers.add(mcpServer); @@ -102,7 +103,7 @@ const loadEphemeralAgent = async ({ req, spec, agent_id, endpoint, model_paramet if (addedServers.has(mcpServer)) { continue; } - const serverTools = await getMCPServerTools(mcpServer); + const serverTools = await getMCPServerTools(userId, mcpServer); if (!serverTools) { tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`); addedServers.add(mcpServer); diff --git a/api/models/Agent.spec.js b/api/models/Agent.spec.js index f95db65013..6c7db6121e 100644 --- a/api/models/Agent.spec.js +++ b/api/models/Agent.spec.js @@ -1931,7 +1931,7 @@ describe('models/Agent', () => { }); // Mock getMCPServerTools to return tools for each server - getMCPServerTools.mockImplementation(async (server) => { + getMCPServerTools.mockImplementation(async (_userId, server) => { if (server === 'server1') { return { tool1_mcp_server1: {} }; } else if (server === 'server2') { @@ -2125,7 +2125,7 @@ describe('models/Agent', () => { getCachedTools.mockResolvedValue(availableTools); // Mock getMCPServerTools to return all tools for server1 - getMCPServerTools.mockImplementation(async (server) => { + getMCPServerTools.mockImplementation(async (_userId, server) => { if (server === 'server1') { return availableTools; // All 100 tools belong to server1 } @@ -2674,7 +2674,7 @@ describe('models/Agent', () => { }); // Mock getMCPServerTools to return only tools matching the server - getMCPServerTools.mockImplementation(async (server) => { + getMCPServerTools.mockImplementation(async (_userId, server) => { if (server === 'server1') { // Only return tool that correctly matches server1 format return { tool_mcp_server1: {} }; diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 839d9bd17b..9e520d392e 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -32,7 +32,7 @@ const getMCPTools = async (req, res) => { const mcpServers = {}; const cachePromises = configuredServers.map((serverName) => - getMCPServerTools(serverName).then((tools) => ({ serverName, tools })), + getMCPServerTools(userId, serverName).then((tools) => ({ serverName, tools })), ); const cacheResults = await Promise.all(cachePromises); @@ -52,7 +52,7 @@ const getMCPTools = async (req, res) => { if (Object.keys(serverTools).length > 0) { // Cache asynchronously without blocking - cacheMCPServerTools({ serverName, serverTools }).catch((err) => + cacheMCPServerTools({ userId, serverName, serverTools }).catch((err) => logger.error(`[getMCPTools] Failed to cache tools for ${serverName}:`, err), ); } diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 64c95c58ee..8ae92cdd3d 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -47,6 +47,7 @@ jest.mock('~/models', () => ({ jest.mock('~/server/services/Config', () => ({ setCachedTools: jest.fn(), getCachedTools: jest.fn(), + getMCPServerTools: jest.fn(), loadCustomConfig: jest.fn(), })); diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index e8415fd801..9b66b10e52 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -205,6 +205,7 @@ router.get('/:serverName/oauth/callback', async (req, res) => { const tools = await userConnection.fetchTools(); await updateMCPServerTools({ + userId: flowState.userId, serverName, tools, }); diff --git a/api/server/services/Config/__tests__/getCachedTools.spec.js b/api/server/services/Config/__tests__/getCachedTools.spec.js new file mode 100644 index 0000000000..48ab6e0737 --- /dev/null +++ b/api/server/services/Config/__tests__/getCachedTools.spec.js @@ -0,0 +1,10 @@ +const { ToolCacheKeys } = require('../getCachedTools'); + +describe('getCachedTools - Cache Isolation Security', () => { + describe('ToolCacheKeys.MCP_SERVER', () => { + it('should generate cache keys that include userId', () => { + const key = ToolCacheKeys.MCP_SERVER('user123', 'github'); + expect(key).toBe('tools:mcp:user123:github'); + }); + }); +}); diff --git a/api/server/services/Config/getCachedTools.js b/api/server/services/Config/getCachedTools.js index 59a0c8cc5d..841ca04c94 100644 --- a/api/server/services/Config/getCachedTools.js +++ b/api/server/services/Config/getCachedTools.js @@ -7,24 +7,25 @@ const getLogStores = require('~/cache/getLogStores'); const ToolCacheKeys = { /** Global tools available to all users */ GLOBAL: 'tools:global', - /** MCP tools cached by server name */ - MCP_SERVER: (serverName) => `tools:mcp:${serverName}`, + /** MCP tools cached by user ID and server name */ + MCP_SERVER: (userId, serverName) => `tools:mcp:${userId}:${serverName}`, }; /** * Retrieves available tools from cache * @function getCachedTools * @param {Object} options - Options for retrieving tools + * @param {string} [options.userId] - User ID for user-specific MCP tools * @param {string} [options.serverName] - MCP server name to get cached tools for * @returns {Promise} The available tools object or null if not cached */ async function getCachedTools(options = {}) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const { serverName } = options; + const { userId, serverName } = options; // Return MCP server-specific tools if requested - if (serverName) { - return await cache.get(ToolCacheKeys.MCP_SERVER(serverName)); + if (serverName && userId) { + return await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName)); } // Default to global tools @@ -36,17 +37,18 @@ async function getCachedTools(options = {}) { * @function setCachedTools * @param {Object} tools - The tools object to cache * @param {Object} options - Options for caching tools + * @param {string} [options.userId] - User ID for user-specific MCP tools * @param {string} [options.serverName] - MCP server name for server-specific tools * @param {number} [options.ttl] - Time to live in milliseconds * @returns {Promise} Whether the operation was successful */ async function setCachedTools(tools, options = {}) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const { serverName, ttl } = options; + const { userId, serverName, ttl } = options; - // Cache by MCP server if specified - if (serverName) { - return await cache.set(ToolCacheKeys.MCP_SERVER(serverName), tools, ttl); + // Cache by MCP server if specified (requires userId) + if (serverName && userId) { + return await cache.set(ToolCacheKeys.MCP_SERVER(userId, serverName), tools, ttl); } // Default to global cache @@ -57,13 +59,14 @@ async function setCachedTools(tools, options = {}) { * Invalidates cached tools * @function invalidateCachedTools * @param {Object} options - Options for invalidating tools + * @param {string} [options.userId] - User ID for user-specific MCP tools * @param {string} [options.serverName] - MCP server name to invalidate * @param {boolean} [options.invalidateGlobal=false] - Whether to invalidate global tools * @returns {Promise} */ async function invalidateCachedTools(options = {}) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const { serverName, invalidateGlobal = false } = options; + const { userId, serverName, invalidateGlobal = false } = options; const keysToDelete = []; @@ -71,22 +74,23 @@ async function invalidateCachedTools(options = {}) { keysToDelete.push(ToolCacheKeys.GLOBAL); } - if (serverName) { - keysToDelete.push(ToolCacheKeys.MCP_SERVER(serverName)); + if (serverName && userId) { + keysToDelete.push(ToolCacheKeys.MCP_SERVER(userId, serverName)); } await Promise.all(keysToDelete.map((key) => cache.delete(key))); } /** - * Gets MCP tools for a specific server from cache or merges with global tools + * Gets MCP tools for a specific server from cache * @function getMCPServerTools + * @param {string} userId - The user ID * @param {string} serverName - The MCP server name * @returns {Promise} The available tools for the server */ -async function getMCPServerTools(serverName) { +async function getMCPServerTools(userId, serverName) { const cache = getLogStores(CacheKeys.CONFIG_STORE); - const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(serverName)); + const serverTools = await cache.get(ToolCacheKeys.MCP_SERVER(userId, serverName)); if (serverTools) { return serverTools; diff --git a/api/server/services/Config/mcp.js b/api/server/services/Config/mcp.js index 75824d1b30..7f4210f8c9 100644 --- a/api/server/services/Config/mcp.js +++ b/api/server/services/Config/mcp.js @@ -6,11 +6,12 @@ const { getLogStores } = require('~/cache'); /** * Updates MCP tools in the cache for a specific server * @param {Object} params - Parameters for updating MCP tools + * @param {string} params.userId - User ID for user-specific caching * @param {string} params.serverName - MCP server name * @param {Array} params.tools - Array of tool objects from MCP server * @returns {Promise} */ -async function updateMCPServerTools({ serverName, tools }) { +async function updateMCPServerTools({ userId, serverName, tools }) { try { const serverTools = {}; const mcpDelimiter = Constants.mcp_delimiter; @@ -27,14 +28,16 @@ async function updateMCPServerTools({ serverName, tools }) { }; } - await setCachedTools(serverTools, { serverName }); + await setCachedTools(serverTools, { userId, serverName }); const cache = getLogStores(CacheKeys.CONFIG_STORE); await cache.delete(CacheKeys.TOOLS); - logger.debug(`[MCP Cache] Updated ${tools.length} tools for server ${serverName}`); + logger.debug( + `[MCP Cache] Updated ${tools.length} tools for server ${serverName} (user: ${userId})`, + ); return serverTools; } catch (error) { - logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error); + logger.error(`[MCP Cache] Failed to update tools for ${serverName} (user: ${userId}):`, error); throw error; } } @@ -65,21 +68,22 @@ async function mergeAppTools(appTools) { /** * Caches MCP server tools (no longer merges with global) * @param {object} params + * @param {string} params.userId - User ID for user-specific caching * @param {string} params.serverName * @param {import('@librechat/api').LCAvailableTools} params.serverTools * @returns {Promise} */ -async function cacheMCPServerTools({ serverName, serverTools }) { +async function cacheMCPServerTools({ userId, serverName, serverTools }) { try { const count = Object.keys(serverTools).length; if (!count) { return; } // Only cache server-specific tools, no merging with global - await setCachedTools(serverTools, { serverName }); - logger.debug(`Cached ${count} MCP server tools for ${serverName}`); + await setCachedTools(serverTools, { userId, serverName }); + logger.debug(`Cached ${count} MCP server tools for ${serverName} (user: ${userId})`); } catch (error) { - logger.error(`Failed to cache MCP server tools for ${serverName}:`, error); + logger.error(`Failed to cache MCP server tools for ${serverName} (user: ${userId}):`, error); throw error; } } diff --git a/api/server/services/Tools/mcp.js b/api/server/services/Tools/mcp.js index e6d293800d..521560aad4 100644 --- a/api/server/services/Tools/mcp.js +++ b/api/server/services/Tools/mcp.js @@ -98,6 +98,7 @@ async function reinitMCPServer({ if (connection && !oauthRequired) { tools = await connection.fetchTools(); availableTools = await updateMCPServerTools({ + userId: user.id, serverName, tools, }); From c0f1cfcaba12cc79ca64d5865cb4c368110333bf Mon Sep 17 00:00:00 2001 From: Marco Beretta <81851188+berry-13@users.noreply.github.com> Date: Thu, 30 Oct 2025 22:14:38 +0100 Subject: [PATCH 003/207] =?UTF-8?q?=F0=9F=92=A1=20feat:=20Improve=20Reason?= =?UTF-8?q?ing=20Content=20UI,=20copy-to-clipboard,=20and=20error=20handli?= =?UTF-8?q?ng=20(#10278)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ feat: Refactor error handling and improve loading states in MessageContent component * ✨ feat: Enhance Thinking and ContentParts components with improved hover functionality and clipboard support * fix: Adjust padding in Thinking and ContentParts components for consistent layout * ✨ feat: Add response label and improve message editing UI with contextual indicators * ✨ feat: Add isEditing prop to Feedback and Fork components for improved editing state handling * refactor: Remove isEditing prop from Feedback and Fork components for cleaner state management * refactor: Migrate state management from Recoil to Jotai for font size and show thinking features * refactor: Separate ToggleSwitch into RecoilToggle and JotaiToggle components for improved clarity and state management * refactor: Remove unnecessary comments in ToggleSwitch and MessageContent components for cleaner code * chore: reorder import statements in Thinking.tsx * chore: reorder import statement in EditTextPart.tsx * chore: reorder import statement * chore: Reorganize imports in ToggleSwitch.tsx --------- Co-authored-by: Danny Avila --- client/src/components/Artifacts/Thinking.tsx | 171 ++++++++++++++---- .../Chat/Messages/Content/ContentParts.tsx | 111 +++++++++--- .../Chat/Messages/Content/EditMessage.tsx | 2 +- .../Chat/Messages/Content/MessageContent.tsx | 148 ++++++++------- .../Messages/Content/Parts/EditTextPart.tsx | 17 ++ .../Chat/Messages/Content/Parts/Reasoning.tsx | 30 ++- .../components/Nav/SettingsTabs/Chat/Chat.tsx | 3 +- .../Nav/SettingsTabs/Chat/ShowThinking.tsx | 8 +- .../Nav/SettingsTabs/ToggleSwitch.tsx | 74 +++++++- client/src/locales/en/translation.json | 3 + client/src/store/fontSize.ts | 51 +----- client/src/store/jotai-utils.ts | 88 +++++++++ client/src/store/showThinking.ts | 8 + 13 files changed, 528 insertions(+), 186 deletions(-) create mode 100644 client/src/store/jotai-utils.ts create mode 100644 client/src/store/showThinking.ts diff --git a/client/src/components/Artifacts/Thinking.tsx b/client/src/components/Artifacts/Thinking.tsx index 08e241c6e8..25d5810e16 100644 --- a/client/src/components/Artifacts/Thinking.tsx +++ b/client/src/components/Artifacts/Thinking.tsx @@ -1,72 +1,171 @@ import { useState, useMemo, memo, useCallback } from 'react'; -import { useRecoilValue } from 'recoil'; -import { Atom, ChevronDown } from 'lucide-react'; +import { useAtomValue } from 'jotai'; +import { Lightbulb, ChevronDown } from 'lucide-react'; +import { Clipboard, CheckMark } from '@librechat/client'; import type { MouseEvent, FC } from 'react'; +import { showThinkingAtom } from '~/store/showThinking'; +import { fontSizeAtom } from '~/store/fontSize'; import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; -import store from '~/store'; -const BUTTON_STYLES = { - base: 'group mt-3 flex w-fit items-center justify-center rounded-xl bg-surface-tertiary px-3 py-2 text-xs leading-[18px] animate-thinking-appear', - icon: 'icon-sm ml-1.5 transform-gpu text-text-primary transition-transform duration-200', -} as const; +/** + * ThinkingContent - Displays the actual thinking/reasoning content + * Used by both legacy text-based messages and modern content parts + */ +export const ThinkingContent: FC<{ + children: React.ReactNode; +}> = memo(({ children }) => { + const fontSize = useAtomValue(fontSizeAtom); -const CONTENT_STYLES = { - wrapper: 'relative pl-3 text-text-secondary', - border: - 'absolute left-0 h-[calc(100%-10px)] border-l-2 border-border-medium dark:border-border-heavy', - partBorder: - 'absolute left-0 h-[calc(100%)] border-l-2 border-border-medium dark:border-border-heavy', - text: 'whitespace-pre-wrap leading-[26px]', -} as const; - -export const ThinkingContent: FC<{ children: React.ReactNode; isPart?: boolean }> = memo( - ({ isPart, children }) => ( -
-
-

{children}

+ return ( +
+

{children}

- ), -); + ); +}); +/** + * ThinkingButton - Toggle button for expanding/collapsing thinking content + * Shows lightbulb icon by default, chevron on hover + * Shared between legacy Thinking component and modern ContentParts + */ export const ThinkingButton = memo( ({ isExpanded, onClick, label, + content, + isContentHovered = false, }: { isExpanded: boolean; onClick: (e: MouseEvent) => void; label: string; - }) => ( - - ), + content?: string; + isContentHovered?: boolean; + }) => { + const localize = useLocalize(); + const fontSize = useAtomValue(fontSizeAtom); + + const [isButtonHovered, setIsButtonHovered] = useState(false); + const [isCopied, setIsCopied] = useState(false); + + const isHovered = useMemo( + () => isButtonHovered || isContentHovered, + [isButtonHovered, isContentHovered], + ); + + const handleCopy = useCallback( + (e: MouseEvent) => { + e.stopPropagation(); + if (content) { + navigator.clipboard.writeText(content); + setIsCopied(true); + setTimeout(() => setIsCopied(false), 2000); + } + }, + [content], + ); + + return ( +
+ + {content && ( + + )} +
+ ); + }, ); +/** + * Thinking Component (LEGACY SYSTEM) + * + * Used for simple text-based messages with `:::thinking:::` markers. + * This handles the old message format where text contains embedded thinking blocks. + * + * Pattern: `:::thinking\n{content}\n:::\n{response}` + * + * Used by: + * - MessageContent.tsx for plain text messages + * - Legacy message format compatibility + * - User messages when manually adding thinking content + * + * For modern structured content (agents/assistants), see Reasoning.tsx component. + */ const Thinking: React.ElementType = memo(({ children }: { children: React.ReactNode }) => { const localize = useLocalize(); - const showThinking = useRecoilValue(store.showThinking); + const showThinking = useAtomValue(showThinkingAtom); const [isExpanded, setIsExpanded] = useState(showThinking); + const [isContentHovered, setIsContentHovered] = useState(false); const handleClick = useCallback((e: MouseEvent) => { e.preventDefault(); setIsExpanded((prev) => !prev); }, []); + const handleContentEnter = useCallback(() => setIsContentHovered(true), []); + const handleContentLeave = useCallback(() => setIsContentHovered(false), []); + const label = useMemo(() => localize('com_ui_thoughts'), [localize]); + // Extract text content for copy functionality + const textContent = useMemo(() => { + if (typeof children === 'string') { + return children; + } + return ''; + }, [children]); + if (children == null) { return null; } return ( - <> -
- +
+
+
- {children} + {children}
- +
); }); diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index d9efa34cc4..157e57aa4a 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -1,5 +1,5 @@ -import { memo, useMemo, useState } from 'react'; -import { useRecoilState } from 'recoil'; +import { memo, useMemo, useState, useCallback } from 'react'; +import { useAtom } from 'jotai'; import { ContentTypes } from 'librechat-data-provider'; import type { TMessageContentParts, @@ -9,12 +9,12 @@ import type { } from 'librechat-data-provider'; import { ThinkingButton } from '~/components/Artifacts/Thinking'; import { MessageContext, SearchContext } from '~/Providers'; +import { showThinkingAtom } from '~/store/showThinking'; import MemoryArtifacts from './MemoryArtifacts'; import Sources from '~/components/Web/Sources'; import { mapAttachments } from '~/utils/map'; import { EditTextPart } from './Parts'; import { useLocalize } from '~/hooks'; -import store from '~/store'; import Part from './Part'; type ContentPartsProps = { @@ -53,12 +53,16 @@ const ContentParts = memo( setSiblingIdx, }: ContentPartsProps) => { const localize = useLocalize(); - const [showThinking, setShowThinking] = useRecoilState(store.showThinking); + const [showThinking, setShowThinking] = useAtom(showThinkingAtom); const [isExpanded, setIsExpanded] = useState(showThinking); + const [isContentHovered, setIsContentHovered] = useState(false); const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]); const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; + const handleContentEnter = useCallback(() => setIsContentHovered(true), []); + const handleContentLeave = useCallback(() => setIsContentHovered(false), []); + const hasReasoningParts = useMemo(() => { const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false; const allThinkPartsHaveContent = @@ -78,6 +82,23 @@ const ContentParts = memo( return hasThinkPart && allThinkPartsHaveContent; }, [content]); + // Extract all reasoning text for copy functionality + const reasoningContent = useMemo(() => { + if (!content) { + return ''; + } + return content + .filter((part) => part?.type === ContentTypes.THINK) + .map((part) => { + if (typeof part?.think === 'string') { + return part.think.replace(/<\/?think>/g, '').trim(); + } + return ''; + }) + .filter(Boolean) + .join('\n\n'); + }, [content]); + if (!content) { return null; } @@ -127,40 +148,74 @@ const ContentParts = memo( {hasReasoningParts && ( -
- - setIsExpanded((prev) => { - const val = !prev; - setShowThinking(val); - return val; - }) - } - label={ - effectiveIsSubmitting && isLast - ? localize('com_ui_thinking') - : localize('com_ui_thoughts') - } - /> +
+
+ + setIsExpanded((prev) => { + const val = !prev; + setShowThinking(val); + return val; + }) + } + label={ + effectiveIsSubmitting && isLast + ? localize('com_ui_thinking') + : localize('com_ui_thoughts') + } + content={reasoningContent} + isContentHovered={isContentHovered} + /> +
+ {content + .filter((part) => part?.type === ContentTypes.THINK) + .map((part) => { + const originalIdx = content.indexOf(part); + return ( + + + + ); + })}
)} {content - .filter((part) => part) - .map((part, idx) => { + .filter((part) => part && part.type !== ContentTypes.THINK) + .map((part) => { + const originalIdx = content.indexOf(part); const toolCallId = (part?.[ContentTypes.TOOL_CALL] as Agents.ToolCall | undefined)?.id ?? ''; const attachments = attachmentMap[toolCallId]; return ( ); diff --git a/client/src/components/Chat/Messages/Content/EditMessage.tsx b/client/src/components/Chat/Messages/Content/EditMessage.tsx index b24d67cf3d..e578c2a56c 100644 --- a/client/src/components/Chat/Messages/Content/EditMessage.tsx +++ b/client/src/components/Chat/Messages/Content/EditMessage.tsx @@ -151,7 +151,7 @@ const EditMessage = ({ return ( -
+
{ diff --git a/client/src/components/Chat/Messages/Content/MessageContent.tsx b/client/src/components/Chat/Messages/Content/MessageContent.tsx index f36375234c..9204d3738d 100644 --- a/client/src/components/Chat/Messages/Content/MessageContent.tsx +++ b/client/src/components/Chat/Messages/Content/MessageContent.tsx @@ -14,57 +14,79 @@ import Markdown from './Markdown'; import { cn } from '~/utils'; import store from '~/store'; +const ERROR_CONNECTION_TEXT = 'Error connecting to server, try refreshing the page.'; +const DELAYED_ERROR_TIMEOUT = 5500; +const UNFINISHED_DELAY = 250; + +const parseThinkingContent = (text: string) => { + const thinkingMatch = text.match(/:::thinking([\s\S]*?):::/); + return { + thinkingContent: thinkingMatch ? thinkingMatch[1].trim() : '', + regularContent: thinkingMatch ? text.replace(/:::thinking[\s\S]*?:::/, '').trim() : text, + }; +}; + +const LoadingFallback = () => ( +
+
+
+

+ +

+
+
+
+); + +const ErrorBox = ({ + children, + className = '', +}: { + children: React.ReactNode; + className?: string; +}) => ( +
+ {children} +
+); + +const ConnectionError = ({ message }: { message?: TMessage }) => { + const localize = useLocalize(); + + return ( + }> + + +
+ {localize('com_ui_error_connection')} +
+
+
+
+ ); +}; + export const ErrorMessage = ({ text, message, className = '', -}: Pick & { - message?: TMessage; -}) => { - const localize = useLocalize(); - if (text === 'Error connecting to server, try refreshing the page.') { - console.log('error message', message); - return ( - -
-
-

- -

-
-
-
- } - > - - -
- {localize('com_ui_error_connection')} -
-
-
- - ); +}: Pick & { message?: TMessage }) => { + if (text === ERROR_CONNECTION_TEXT) { + return ; } + return ( -
+ -
+
); }; @@ -72,27 +94,29 @@ export const ErrorMessage = ({ const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => { const { isSubmitting = false, isLatestMessage = false } = useMessageContext(); const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown); + const showCursorState = useMemo( () => showCursor === true && isSubmitting, [showCursor, isSubmitting], ); - let content: React.ReactElement; - if (!isCreatedByUser) { - content = ; - } else if (enableUserMsgMarkdown) { - content = ; - } else { - content = <>{text}; - } + const content = useMemo(() => { + if (!isCreatedByUser) { + return ; + } + if (enableUserMsgMarkdown) { + return ; + } + return <>{text}; + }, [isCreatedByUser, enableUserMsgMarkdown, text, isLatestMessage]); return (
0 && 'result-streaming', isCreatedByUser && !enableUserMsgMarkdown && 'whitespace-pre-wrap', isCreatedByUser ? 'dark:text-gray-20' : 'dark:text-gray-100', )} @@ -103,7 +127,6 @@ const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplay ); }; -// Unfinished Message Component export const UnfinishedMessage = ({ message }: { message: TMessage }) => ( { - const thinkingMatch = text.match(/:::thinking([\s\S]*?):::/); - return { - thinkingContent: thinkingMatch ? thinkingMatch[1].trim() : '', - regularContent: thinkingMatch ? text.replace(/:::thinking[\s\S]*?:::/, '').trim() : text, - }; - }, [text]); - + const { thinkingContent, regularContent } = useMemo(() => parseThinkingContent(text), [text]); const showRegularCursor = useMemo(() => isLast && isSubmitting, [isLast, isSubmitting]); const unfinishedMessage = useMemo( () => !isSubmitting && unfinished ? ( - + @@ -146,8 +162,10 @@ const MessageContent = ({ ); if (error) { - return ; - } else if (edit) { + return ; + } + + if (edit) { return ; } diff --git a/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx b/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx index 5422d9733d..10f61fd8af 100644 --- a/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/EditTextPart.tsx @@ -3,6 +3,7 @@ import { useForm } from 'react-hook-form'; import { TextareaAutosize } from '@librechat/client'; import { ContentTypes } from 'librechat-data-provider'; import { useRecoilState, useRecoilValue } from 'recoil'; +import { Lightbulb, MessageSquare } from 'lucide-react'; import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query'; import type { Agents } from 'librechat-data-provider'; import type { TEditProps } from '~/common'; @@ -153,6 +154,22 @@ const EditTextPart = ({ return ( + {part.type === ContentTypes.THINK && ( +
+ + + {localize('com_ui_thoughts')} + +
+ )} + {part.type !== ContentTypes.THINK && ( +
+ + + {localize('com_ui_response')} + +
+ )}
content" }, ...] }` + * + * Used by: + * - ContentParts.tsx → Part.tsx for structured messages + * - Agent/Assistant responses (OpenAI Assistants, custom agents) + * - O-series models (o1, o3) with reasoning capabilities + * - Modern Claude responses with thinking blocks + * + * Key differences from legacy Thinking.tsx: + * - Works with content parts array instead of plain text + * - Strips `` tags instead of `:::thinking:::` markers + * - Uses shared ThinkingButton via ContentParts.tsx + * - Controlled by MessageContext isExpanded state + * + * For legacy text-based messages, see Thinking.tsx component. + */ const Reasoning = memo(({ reasoning }: ReasoningProps) => { const { isExpanded, nextType } = useMessageContext(); + + // Strip tags from the reasoning content (modern format) const reasoningText = useMemo(() => { return reasoning .replace(/^\s*/, '') @@ -21,18 +45,20 @@ const Reasoning = memo(({ reasoning }: ReasoningProps) => { return null; } + // Note: The toggle button is rendered separately in ContentParts.tsx + // This component only handles the collapsible content area return (
- {reasoningText} + {reasoningText}
); diff --git a/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx b/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx index 70858a9b72..5c703922fc 100644 --- a/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx +++ b/client/src/components/Nav/SettingsTabs/Chat/Chat.tsx @@ -1,4 +1,5 @@ import { memo } from 'react'; +import { showThinkingAtom } from '~/store/showThinking'; import FontSizeSelector from './FontSizeSelector'; import { ForkSettings } from './ForkSettings'; import ChatDirection from './ChatDirection'; @@ -28,7 +29,7 @@ const toggleSwitchConfigs = [ key: 'centerFormOnLanding', }, { - stateAtom: store.showThinking, + stateAtom: showThinkingAtom, localizationKey: 'com_nav_show_thinking', switchId: 'showThinking', hoverCardText: undefined, diff --git a/client/src/components/Nav/SettingsTabs/Chat/ShowThinking.tsx b/client/src/components/Nav/SettingsTabs/Chat/ShowThinking.tsx index 949453cb5c..905efcc98c 100644 --- a/client/src/components/Nav/SettingsTabs/Chat/ShowThinking.tsx +++ b/client/src/components/Nav/SettingsTabs/Chat/ShowThinking.tsx @@ -1,18 +1,18 @@ -import { useRecoilState } from 'recoil'; +import { useAtom } from 'jotai'; import { Switch, InfoHoverCard, ESide } from '@librechat/client'; +import { showThinkingAtom } from '~/store/showThinking'; import { useLocalize } from '~/hooks'; -import store from '~/store'; export default function SaveDraft({ onCheckedChange, }: { onCheckedChange?: (value: boolean) => void; }) { - const [showThinking, setSaveDrafts] = useRecoilState(store.showThinking); + const [showThinking, setShowThinking] = useAtom(showThinkingAtom); const localize = useLocalize(); const handleCheckedChange = (value: boolean) => { - setSaveDrafts(value); + setShowThinking(value); if (onCheckedChange) { onCheckedChange(value); } diff --git a/client/src/components/Nav/SettingsTabs/ToggleSwitch.tsx b/client/src/components/Nav/SettingsTabs/ToggleSwitch.tsx index 391ab0a494..2bbe0d941f 100644 --- a/client/src/components/Nav/SettingsTabs/ToggleSwitch.tsx +++ b/client/src/components/Nav/SettingsTabs/ToggleSwitch.tsx @@ -1,3 +1,4 @@ +import { WritableAtom, useAtom } from 'jotai'; import { RecoilState, useRecoilState } from 'recoil'; import { Switch, InfoHoverCard, ESide } from '@librechat/client'; import { useLocalize } from '~/hooks'; @@ -6,7 +7,7 @@ type LocalizeFn = ReturnType; type LocalizeKey = Parameters[0]; interface ToggleSwitchProps { - stateAtom: RecoilState; + stateAtom: RecoilState | WritableAtom; localizationKey: LocalizeKey; hoverCardText?: LocalizeKey; switchId: string; @@ -16,13 +17,18 @@ interface ToggleSwitchProps { strongLabel?: boolean; } -const ToggleSwitch: React.FC = ({ +function isRecoilState(atom: unknown): atom is RecoilState { + return atom != null && typeof atom === 'object' && 'key' in atom; +} + +const RecoilToggle: React.FC< + Omit & { stateAtom: RecoilState } +> = ({ stateAtom, localizationKey, hoverCardText, switchId, onCheckedChange, - showSwitch = true, disabled = false, strongLabel = false, }) => { @@ -36,9 +42,47 @@ const ToggleSwitch: React.FC = ({ const labelId = `${switchId}-label`; - if (!showSwitch) { - return null; - } + return ( +
+
+
+ {strongLabel ? {localize(localizationKey)} : localize(localizationKey)} +
+ {hoverCardText && } +
+ +
+ ); +}; + +const JotaiToggle: React.FC< + Omit & { stateAtom: WritableAtom } +> = ({ + stateAtom, + localizationKey, + hoverCardText, + switchId, + onCheckedChange, + disabled = false, + strongLabel = false, +}) => { + const [switchState, setSwitchState] = useAtom(stateAtom); + const localize = useLocalize(); + + const handleCheckedChange = (value: boolean) => { + setSwitchState(value); + onCheckedChange?.(value); + }; + + const labelId = `${switchId}-label`; return (
@@ -52,13 +96,29 @@ const ToggleSwitch: React.FC = ({ id={switchId} checked={switchState} onCheckedChange={handleCheckedChange} + disabled={disabled} className="ml-4" data-testid={switchId} aria-labelledby={labelId} - disabled={disabled} />
); }; +const ToggleSwitch: React.FC = (props) => { + const { stateAtom, showSwitch = true } = props; + + if (!showSwitch) { + return null; + } + + const isRecoil = isRecoilState(stateAtom); + + if (isRecoil) { + return } />; + } + + return } />; +}; + export default ToggleSwitch; diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index 5becd0ce93..acbba5527b 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -789,6 +789,8 @@ "com_ui_copy_stack_trace": "Copy stack trace", "com_ui_copy_to_clipboard": "Copy to clipboard", "com_ui_copy_url_to_clipboard": "Copy URL to clipboard", + "com_ui_copy_stack_trace": "Copy stack trace", + "com_ui_copy_thoughts_to_clipboard": "Copy thoughts to clipboard", "com_ui_create": "Create", "com_ui_create_link": "Create link", "com_ui_create_memory": "Create Memory", @@ -1222,6 +1224,7 @@ "com_ui_terms_of_service": "Terms of service", "com_ui_thinking": "Thinking...", "com_ui_thoughts": "Thoughts", + "com_ui_response": "Response", "com_ui_token": "token", "com_ui_token_exchange_method": "Token Exchange Method", "com_ui_token_url": "Token URL", diff --git a/client/src/store/fontSize.ts b/client/src/store/fontSize.ts index 4b1a0666f3..19ec56e815 100644 --- a/client/src/store/fontSize.ts +++ b/client/src/store/fontSize.ts @@ -1,54 +1,21 @@ -import { atom } from 'jotai'; -import { atomWithStorage } from 'jotai/utils'; import { applyFontSize } from '@librechat/client'; +import { createStorageAtomWithEffect, initializeFromStorage } from './jotai-utils'; const DEFAULT_FONT_SIZE = 'text-base'; /** - * Base storage atom for font size + * This atom stores the user's font size preference */ -const fontSizeStorageAtom = atomWithStorage('fontSize', DEFAULT_FONT_SIZE, undefined, { - getOnInit: true, -}); - -/** - * Derived atom that applies font size changes to the DOM - * Read: returns the current font size - * Write: updates storage and applies the font size to the DOM - */ -export const fontSizeAtom = atom( - (get) => get(fontSizeStorageAtom), - (get, set, newValue: string) => { - set(fontSizeStorageAtom, newValue); - if (typeof window !== 'undefined' && typeof document !== 'undefined') { - applyFontSize(newValue); - } - }, +export const fontSizeAtom = createStorageAtomWithEffect( + 'fontSize', + DEFAULT_FONT_SIZE, + applyFontSize, ); /** * Initialize font size on app load + * This function applies the saved font size from localStorage to the DOM */ -export const initializeFontSize = () => { - if (typeof window === 'undefined' || typeof document === 'undefined') { - return; - } - - const savedValue = localStorage.getItem('fontSize'); - - if (savedValue !== null) { - try { - const parsedValue = JSON.parse(savedValue); - applyFontSize(parsedValue); - } catch (error) { - console.error( - 'Error parsing localStorage key "fontSize", resetting to default. Error:', - error, - ); - localStorage.setItem('fontSize', JSON.stringify(DEFAULT_FONT_SIZE)); - applyFontSize(DEFAULT_FONT_SIZE); - } - } else { - applyFontSize(DEFAULT_FONT_SIZE); - } +export const initializeFontSize = (): void => { + initializeFromStorage('fontSize', DEFAULT_FONT_SIZE, applyFontSize); }; diff --git a/client/src/store/jotai-utils.ts b/client/src/store/jotai-utils.ts new file mode 100644 index 0000000000..d3ca9d817c --- /dev/null +++ b/client/src/store/jotai-utils.ts @@ -0,0 +1,88 @@ +import { atom } from 'jotai'; +import { atomWithStorage } from 'jotai/utils'; + +/** + * Create a simple atom with localStorage persistence + * Uses Jotai's atomWithStorage with getOnInit for proper SSR support + * + * @param key - localStorage key + * @param defaultValue - default value if no saved value exists + * @returns Jotai atom with localStorage persistence + */ +export function createStorageAtom(key: string, defaultValue: T) { + return atomWithStorage(key, defaultValue, undefined, { + getOnInit: true, + }); +} + +/** + * Create an atom with localStorage persistence and side effects + * Useful when you need to apply changes to the DOM or trigger other actions + * + * @param key - localStorage key + * @param defaultValue - default value if no saved value exists + * @param onWrite - callback function to run when the value changes + * @returns Jotai atom with localStorage persistence and side effects + */ +export function createStorageAtomWithEffect( + key: string, + defaultValue: T, + onWrite: (value: T) => void, +) { + const baseAtom = createStorageAtom(key, defaultValue); + + return atom( + (get) => get(baseAtom), + (get, set, newValue: T) => { + set(baseAtom, newValue); + if (typeof window !== 'undefined') { + onWrite(newValue); + } + }, + ); +} + +/** + * Initialize a value from localStorage and optionally apply it + * Useful for applying saved values on app startup (e.g., theme, fontSize) + * + * @param key - localStorage key + * @param defaultValue - default value if no saved value exists + * @param onInit - optional callback to run with the loaded value + * @returns The loaded value (or default if none exists) + */ +export function initializeFromStorage( + key: string, + defaultValue: T, + onInit?: (value: T) => void, +): T { + if (typeof window === 'undefined' || typeof localStorage === 'undefined') { + return defaultValue; + } + + try { + const savedValue = localStorage.getItem(key); + const value = savedValue ? (JSON.parse(savedValue) as T) : defaultValue; + + if (onInit) { + onInit(value); + } + + return value; + } catch (error) { + console.error(`Error initializing ${key} from localStorage, using default. Error:`, error); + + // Reset corrupted value + try { + localStorage.setItem(key, JSON.stringify(defaultValue)); + } catch (setError) { + console.error(`Error resetting corrupted ${key} in localStorage:`, setError); + } + + if (onInit) { + onInit(defaultValue); + } + + return defaultValue; + } +} diff --git a/client/src/store/showThinking.ts b/client/src/store/showThinking.ts new file mode 100644 index 0000000000..313d6d14c2 --- /dev/null +++ b/client/src/store/showThinking.ts @@ -0,0 +1,8 @@ +import { createStorageAtom } from './jotai-utils'; + +const DEFAULT_SHOW_THINKING = false; + +/** + * This atom controls whether AI reasoning/thinking content is expanded by default. + */ +export const showThinkingAtom = createStorageAtom('showThinking', DEFAULT_SHOW_THINKING); From 9b4c4cafb6a2303f5db0eebd7ab4403a874a4243 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 31 Oct 2025 13:05:12 -0400 Subject: [PATCH 004/207] =?UTF-8?q?=F0=9F=A7=A0=20refactor:=20Improve=20Re?= =?UTF-8?q?asoning=20Component=20Structure=20and=20UX=20(#10320)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: Reasoning components with independent toggle buttons - Refactored ThinkingButton to remove unnecessary state and props. - Updated ContentParts to simplify content rendering and remove hover handling. - Improved Reasoning component to include independent toggle functionality for each THINK part. - Adjusted styles for better layout consistency and user experience. * refactor: isolate hover effects for Reasoning - Updated ThinkingButton to improve hover effects and layout consistency. - Refactored Reasoning component to include a new wrapper class for better styling. - Adjusted icon visibility and transitions for a smoother user experience. * fix: Prevent rendering of empty messages in Chat component - Added a check to skip rendering if the message text is only whitespace, improving the user interface by avoiding empty containers. * chore: Replace div with fragment in Thinking component for cleaner markup * chore: move Thinking component to Content Parts directory * refactor: prevent rendering of whitespace-only text in Part component only for edge cases --- .../Chat/Messages/Content/ContentParts.tsx | 166 ++++-------------- .../Chat/Messages/Content/MessageContent.tsx | 2 +- .../components/Chat/Messages/Content/Part.tsx | 6 +- .../Chat/Messages/Content/Parts/Reasoning.tsx | 67 +++++-- .../Messages/Content/Parts}/Thinking.tsx | 30 +--- 5 files changed, 95 insertions(+), 176 deletions(-) rename client/src/components/{Artifacts => Chat/Messages/Content/Parts}/Thinking.tsx (82%) diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index 157e57aa4a..14883b4b94 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -1,5 +1,4 @@ -import { memo, useMemo, useState, useCallback } from 'react'; -import { useAtom } from 'jotai'; +import { memo, useMemo } from 'react'; import { ContentTypes } from 'librechat-data-provider'; import type { TMessageContentParts, @@ -7,14 +6,11 @@ import type { TAttachment, Agents, } from 'librechat-data-provider'; -import { ThinkingButton } from '~/components/Artifacts/Thinking'; import { MessageContext, SearchContext } from '~/Providers'; -import { showThinkingAtom } from '~/store/showThinking'; import MemoryArtifacts from './MemoryArtifacts'; import Sources from '~/components/Web/Sources'; import { mapAttachments } from '~/utils/map'; import { EditTextPart } from './Parts'; -import { useLocalize } from '~/hooks'; import Part from './Part'; type ContentPartsProps = { @@ -52,53 +48,10 @@ const ContentParts = memo( siblingIdx, setSiblingIdx, }: ContentPartsProps) => { - const localize = useLocalize(); - const [showThinking, setShowThinking] = useAtom(showThinkingAtom); - const [isExpanded, setIsExpanded] = useState(showThinking); - const [isContentHovered, setIsContentHovered] = useState(false); const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]); const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; - const handleContentEnter = useCallback(() => setIsContentHovered(true), []); - const handleContentLeave = useCallback(() => setIsContentHovered(false), []); - - const hasReasoningParts = useMemo(() => { - const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false; - const allThinkPartsHaveContent = - content?.every((part) => { - if (part?.type !== ContentTypes.THINK) { - return true; - } - - if (typeof part.think === 'string') { - const cleanedContent = part.think.replace(/<\/?think>/g, '').trim(); - return cleanedContent.length > 0; - } - - return false; - }) ?? false; - - return hasThinkPart && allThinkPartsHaveContent; - }, [content]); - - // Extract all reasoning text for copy functionality - const reasoningContent = useMemo(() => { - if (!content) { - return ''; - } - return content - .filter((part) => part?.type === ContentTypes.THINK) - .map((part) => { - if (typeof part?.think === 'string') { - return part.think.replace(/<\/?think>/g, '').trim(); - } - return ''; - }) - .filter(Boolean) - .join('\n\n'); - }, [content]); - if (!content) { return null; } @@ -147,91 +100,40 @@ const ContentParts = memo( - {hasReasoningParts && ( -
-
- - setIsExpanded((prev) => { - const val = !prev; - setShowThinking(val); - return val; - }) - } - label={ - effectiveIsSubmitting && isLast - ? localize('com_ui_thinking') - : localize('com_ui_thoughts') - } - content={reasoningContent} - isContentHovered={isContentHovered} - /> -
- {content - .filter((part) => part?.type === ContentTypes.THINK) - .map((part) => { - const originalIdx = content.indexOf(part); - return ( - - - - ); - })} -
- )} - {content - .filter((part) => part && part.type !== ContentTypes.THINK) - .map((part) => { - const originalIdx = content.indexOf(part); - const toolCallId = - (part?.[ContentTypes.TOOL_CALL] as Agents.ToolCall | undefined)?.id ?? ''; - const attachments = attachmentMap[toolCallId]; + {content.map((part, idx) => { + if (!part) { + return null; + } - return ( - - - - ); - })} + const toolCallId = + (part?.[ContentTypes.TOOL_CALL] as Agents.ToolCall | undefined)?.id ?? ''; + const partAttachments = attachmentMap[toolCallId]; + + return ( + + + + ); + })}
); diff --git a/client/src/components/Chat/Messages/Content/MessageContent.tsx b/client/src/components/Chat/Messages/Content/MessageContent.tsx index 9204d3738d..7a823a07e9 100644 --- a/client/src/components/Chat/Messages/Content/MessageContent.tsx +++ b/client/src/components/Chat/Messages/Content/MessageContent.tsx @@ -4,10 +4,10 @@ import { DelayedRender } from '@librechat/client'; import type { TMessage } from 'librechat-data-provider'; import type { TMessageContentProps, TDisplayProps } from '~/common'; import Error from '~/components/Messages/Content/Error'; -import Thinking from '~/components/Artifacts/Thinking'; import { useMessageContext } from '~/Providers'; import MarkdownLite from './MarkdownLite'; import EditMessage from './EditMessage'; +import Thinking from './Parts/Thinking'; import { useLocalize } from '~/hooks'; import Container from './Container'; import Markdown from './Markdown'; diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index aa9f4da82d..b8d70f33e4 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -65,6 +65,10 @@ const Part = memo( if (part.tool_call_ids != null && !text) { return null; } + /** Skip rendering if text is only whitespace to avoid empty Container */ + if (!isLast && text.length > 0 && /^\s*$/.test(text)) { + return null; + } return ( @@ -75,7 +79,7 @@ const Part = memo( if (typeof reasoning !== 'string') { return null; } - return ; + return ; } else if (part.type === ContentTypes.TOOL_CALL) { const toolCall = part[ContentTypes.TOOL_CALL]; diff --git a/client/src/components/Chat/Messages/Content/Parts/Reasoning.tsx b/client/src/components/Chat/Messages/Content/Parts/Reasoning.tsx index 8f7da551d0..0c1d0cc944 100644 --- a/client/src/components/Chat/Messages/Content/Parts/Reasoning.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/Reasoning.tsx @@ -1,11 +1,16 @@ -import { memo, useMemo } from 'react'; +import { memo, useMemo, useState, useCallback } from 'react'; +import { useAtom } from 'jotai'; +import type { MouseEvent } from 'react'; import { ContentTypes } from 'librechat-data-provider'; -import { ThinkingContent } from '~/components/Artifacts/Thinking'; +import { ThinkingContent, ThinkingButton } from './Thinking'; +import { showThinkingAtom } from '~/store/showThinking'; import { useMessageContext } from '~/Providers'; +import { useLocalize } from '~/hooks'; import { cn } from '~/utils'; type ReasoningProps = { reasoning: string; + isLast: boolean; }; /** @@ -25,13 +30,16 @@ type ReasoningProps = { * Key differences from legacy Thinking.tsx: * - Works with content parts array instead of plain text * - Strips `` tags instead of `:::thinking:::` markers - * - Uses shared ThinkingButton via ContentParts.tsx - * - Controlled by MessageContext isExpanded state + * - Each THINK part has its own independent toggle button + * - Can be interleaved with other content types * * For legacy text-based messages, see Thinking.tsx component. */ -const Reasoning = memo(({ reasoning }: ReasoningProps) => { - const { isExpanded, nextType } = useMessageContext(); +const Reasoning = memo(({ reasoning, isLast }: ReasoningProps) => { + const localize = useLocalize(); + const [showThinking] = useAtom(showThinkingAtom); + const [isExpanded, setIsExpanded] = useState(showThinking); + const { isSubmitting, isLatestMessage, nextType } = useMessageContext(); // Strip tags from the reasoning content (modern format) const reasoningText = useMemo(() => { @@ -41,24 +49,45 @@ const Reasoning = memo(({ reasoning }: ReasoningProps) => { .trim(); }, [reasoning]); + const handleClick = useCallback((e: MouseEvent) => { + e.preventDefault(); + setIsExpanded((prev) => !prev); + }, []); + + const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false; + + const label = useMemo( + () => + effectiveIsSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts'), + [effectiveIsSubmitting, localize, isLast], + ); + if (!reasoningText) { return null; } - // Note: The toggle button is rendered separately in ContentParts.tsx - // This component only handles the collapsible content area return ( -
-
- {reasoningText} +
+
+ +
+
+
+ {reasoningText} +
); diff --git a/client/src/components/Artifacts/Thinking.tsx b/client/src/components/Chat/Messages/Content/Parts/Thinking.tsx similarity index 82% rename from client/src/components/Artifacts/Thinking.tsx rename to client/src/components/Chat/Messages/Content/Parts/Thinking.tsx index 25d5810e16..f84f4edc9d 100644 --- a/client/src/components/Artifacts/Thinking.tsx +++ b/client/src/components/Chat/Messages/Content/Parts/Thinking.tsx @@ -35,25 +35,17 @@ export const ThinkingButton = memo( onClick, label, content, - isContentHovered = false, }: { isExpanded: boolean; onClick: (e: MouseEvent) => void; label: string; content?: string; - isContentHovered?: boolean; }) => { const localize = useLocalize(); const fontSize = useAtomValue(fontSizeAtom); - const [isButtonHovered, setIsButtonHovered] = useState(false); const [isCopied, setIsCopied] = useState(false); - const isHovered = useMemo( - () => isButtonHovered || isContentHovered, - [isButtonHovered, isContentHovered], - ); - const handleCopy = useCallback( (e: MouseEvent) => { e.stopPropagation(); @@ -71,23 +63,20 @@ export const ThinkingButton = memo( {content && ( @@ -132,16 +121,12 @@ const Thinking: React.ElementType = memo(({ children }: { children: React.ReactN const localize = useLocalize(); const showThinking = useAtomValue(showThinkingAtom); const [isExpanded, setIsExpanded] = useState(showThinking); - const [isContentHovered, setIsContentHovered] = useState(false); const handleClick = useCallback((e: MouseEvent) => { e.preventDefault(); setIsExpanded((prev) => !prev); }, []); - const handleContentEnter = useCallback(() => setIsContentHovered(true), []); - const handleContentLeave = useCallback(() => setIsContentHovered(false), []); - const label = useMemo(() => localize('com_ui_thoughts'), [localize]); // Extract text content for copy functionality @@ -157,14 +142,13 @@ const Thinking: React.ElementType = memo(({ children }: { children: React.ReactN } return ( -
+ <>
{children}
-
+ ); }); From 961f87cfda20a11416b5589613769eb6dc16e547 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:36:32 -0400 Subject: [PATCH 005/207] =?UTF-8?q?=F0=9F=8C=8D=20i18n:=20Update=20transla?= =?UTF-8?q?tion.json=20with=20latest=20translations=20(#10323)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- client/src/locales/en/translation.json | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client/src/locales/en/translation.json b/client/src/locales/en/translation.json index acbba5527b..d379f85032 100644 --- a/client/src/locales/en/translation.json +++ b/client/src/locales/en/translation.json @@ -787,10 +787,9 @@ "com_ui_copy_code": "Copy code", "com_ui_copy_link": "Copy link", "com_ui_copy_stack_trace": "Copy stack trace", + "com_ui_copy_thoughts_to_clipboard": "Copy thoughts to clipboard", "com_ui_copy_to_clipboard": "Copy to clipboard", "com_ui_copy_url_to_clipboard": "Copy URL to clipboard", - "com_ui_copy_stack_trace": "Copy stack trace", - "com_ui_copy_thoughts_to_clipboard": "Copy thoughts to clipboard", "com_ui_create": "Create", "com_ui_create_link": "Create link", "com_ui_create_memory": "Create Memory", @@ -1121,6 +1120,7 @@ "com_ui_reset_var": "Reset {{0}}", "com_ui_reset_zoom": "Reset Zoom", "com_ui_resource": "resource", + "com_ui_response": "Response", "com_ui_result": "Result", "com_ui_revoke": "Revoke", "com_ui_revoke_info": "Revoke all user provided credentials", @@ -1224,7 +1224,6 @@ "com_ui_terms_of_service": "Terms of service", "com_ui_thinking": "Thinking...", "com_ui_thoughts": "Thoughts", - "com_ui_response": "Response", "com_ui_token": "token", "com_ui_token_exchange_method": "Token Exchange Method", "com_ui_token_url": "Token URL", From ce7e6edad8b87367fc61340658ef08d4b74acb96 Mon Sep 17 00:00:00 2001 From: "Theo N. Truong" <644650+nhtruong@users.noreply.github.com> Date: Fri, 31 Oct 2025 13:00:21 -0600 Subject: [PATCH 006/207] =?UTF-8?q?=F0=9F=94=84=20refactor:=20MCP=20Regist?= =?UTF-8?q?ry=20System=20with=20Distributed=20Caching=20(#10191)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: Restructure MCP registry system with caching - Split MCPServersRegistry into modular components: - MCPServerInspector: handles server inspection and health checks - MCPServersInitializer: manages server initialization logic - MCPServersRegistry: simplified registry coordination - Add distributed caching layer: - ServerConfigsCacheRedis: Redis-backed configuration cache - ServerConfigsCacheInMemory: in-memory fallback cache - RegistryStatusCache: distributed leader election state - Add promise utilities (withTimeout) replacing Promise.race patterns - Add comprehensive cache integration tests for all cache implementations - Remove unused MCPManager.getAllToolFunctions method * fix: Update OAuth flow to include user-specific headers * chore: Update Jest configuration to ignore additional test files - Added patterns to ignore files ending with .helper.ts and .helper.d.ts in testPathIgnorePatterns for cleaner test runs. * fix: oauth headers in callback * chore: Update Jest testPathIgnorePatterns to exclude helper files - Modified testPathIgnorePatterns in package.json to ignore files ending with .helper.ts and .helper.d.ts for cleaner test execution. * ci: update test mocks --------- Co-authored-by: Danny Avila --- .github/workflows/cache-integration-tests.yml | 9 + api/server/controllers/UserController.js | 12 +- api/server/controllers/mcp.js | 3 +- api/server/routes/__tests__/mcp.spec.js | 120 ++-- api/server/routes/config.js | 10 +- api/server/routes/mcp.js | 22 +- api/server/services/MCP.js | 3 +- api/server/services/MCP.spec.js | 13 +- api/server/services/initializeMCPs.js | 2 +- packages/api/jest.config.mjs | 10 +- packages/api/package.json | 5 +- packages/api/src/index.ts | 1 + packages/api/src/mcp/MCPConnectionFactory.ts | 12 +- packages/api/src/mcp/MCPManager.ts | 76 +-- packages/api/src/mcp/MCPServersRegistry.ts | 230 ------- packages/api/src/mcp/UserConnectionManager.ts | 14 +- .../api/src/mcp/__tests__/MCPManager.test.ts | 282 ++++++++- .../mcp/__tests__/MCPServersRegistry.test.ts | 595 ------------------ .../MCPServersRegistry.parsedConfigs.yml | 67 -- .../MCPServersRegistry.rawConfigs.yml | 53 -- packages/api/src/mcp/connection.ts | 13 +- .../oauth/OAuthReconnectionManager.test.ts | 37 +- .../src/mcp/oauth/OAuthReconnectionManager.ts | 5 +- .../src/mcp/registry/MCPServerInspector.ts | 123 ++++ .../src/mcp/registry/MCPServersInitializer.ts | 96 +++ .../src/mcp/registry/MCPServersRegistry.ts | 91 +++ .../__tests__/MCPServerInspector.test.ts | 338 ++++++++++ ...rversInitializer.cache_integration.spec.ts | 301 +++++++++ .../__tests__/MCPServersInitializer.test.ts | 292 +++++++++ ...PServersRegistry.cache_integration.spec.ts | 227 +++++++ .../__tests__/MCPServersRegistry.test.ts | 175 ++++++ .../__tests__/mcpConnectionsMock.helper.ts | 55 ++ .../mcp/registry/cache/BaseRegistryCache.ts | 26 + .../mcp/registry/cache/RegistryStatusCache.ts | 37 ++ .../cache/ServerConfigsCacheFactory.ts | 31 + .../cache/ServerConfigsCacheInMemory.ts | 46 ++ .../registry/cache/ServerConfigsCacheRedis.ts | 80 +++ ...istryStatusCache.cache_integration.spec.ts | 73 +++ .../ServerConfigsCacheFactory.test.ts | 70 +++ .../ServerConfigsCacheInMemory.test.ts | 173 +++++ ...onfigsCacheRedis.cache_integration.spec.ts | 278 ++++++++ packages/api/src/mcp/types/index.ts | 2 + packages/api/src/utils/index.ts | 1 + packages/api/src/utils/promise.spec.ts | 115 ++++ packages/api/src/utils/promise.ts | 42 ++ 45 files changed, 3116 insertions(+), 1150 deletions(-) delete mode 100644 packages/api/src/mcp/MCPServersRegistry.ts delete mode 100644 packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts delete mode 100644 packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml delete mode 100644 packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml create mode 100644 packages/api/src/mcp/registry/MCPServerInspector.ts create mode 100644 packages/api/src/mcp/registry/MCPServersInitializer.ts create mode 100644 packages/api/src/mcp/registry/MCPServersRegistry.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts create mode 100644 packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts create mode 100644 packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts create mode 100644 packages/api/src/mcp/registry/cache/BaseRegistryCache.ts create mode 100644 packages/api/src/mcp/registry/cache/RegistryStatusCache.ts create mode 100644 packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts create mode 100644 packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts create mode 100644 packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts create mode 100644 packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts create mode 100644 packages/api/src/utils/promise.spec.ts create mode 100644 packages/api/src/utils/promise.ts diff --git a/.github/workflows/cache-integration-tests.yml b/.github/workflows/cache-integration-tests.yml index f7ac638282..bdd3f2e83d 100644 --- a/.github/workflows/cache-integration-tests.yml +++ b/.github/workflows/cache-integration-tests.yml @@ -9,6 +9,7 @@ on: paths: - 'packages/api/src/cache/**' - 'packages/api/src/cluster/**' + - 'packages/api/src/mcp/**' - 'redis-config/**' - '.github/workflows/cache-integration-tests.yml' @@ -77,6 +78,14 @@ jobs: REDIS_URI: redis://127.0.0.1:6379 run: npm run test:cache-integration:cluster + - name: Run mcp integration tests + working-directory: packages/api + env: + NODE_ENV: test + USE_REDIS: true + REDIS_URI: redis://127.0.0.1:6379 + run: npm run test:cache-integration:mcp + - name: Stop Redis Cluster if: always() working-directory: redis-config diff --git a/api/server/controllers/UserController.js b/api/server/controllers/UserController.js index 31295387ed..b488864a93 100644 --- a/api/server/controllers/UserController.js +++ b/api/server/controllers/UserController.js @@ -28,6 +28,7 @@ const { getMCPManager, getFlowStateManager } = require('~/config'); const { getAppConfig } = require('~/server/services/Config'); const { deleteToolCalls } = require('~/models/ToolCall'); const { getLogStores } = require('~/cache'); +const { mcpServersRegistry } = require('@librechat/api'); const getUserController = async (req, res) => { const appConfig = await getAppConfig({ role: req.user?.role }); @@ -198,7 +199,7 @@ const updateUserPluginsController = async (req, res) => { // If auth was updated successfully, disconnect MCP sessions as they might use these credentials if (pluginKey.startsWith(Constants.mcp_prefix)) { try { - const mcpManager = getMCPManager(user.id); + const mcpManager = getMCPManager(); if (mcpManager) { // Extract server name from pluginKey (format: "mcp_") const serverName = pluginKey.replace(Constants.mcp_prefix, ''); @@ -295,10 +296,11 @@ const maybeUninstallOAuthMCP = async (userId, pluginKey, appConfig) => { } const serverName = pluginKey.replace(Constants.mcp_prefix, ''); - const mcpManager = getMCPManager(userId); - const serverConfig = mcpManager.getRawConfig(serverName) ?? appConfig?.mcpServers?.[serverName]; - - if (!mcpManager.getOAuthServers().has(serverName)) { + const serverConfig = + (await mcpServersRegistry.getServerConfig(serverName, userId)) ?? + appConfig?.mcpServers?.[serverName]; + const oauthServers = await mcpServersRegistry.getOAuthServers(); + if (!oauthServers.has(serverName)) { // this server does not use OAuth, so nothing to do here as well return; } diff --git a/api/server/controllers/mcp.js b/api/server/controllers/mcp.js index 9e520d392e..e113b01f17 100644 --- a/api/server/controllers/mcp.js +++ b/api/server/controllers/mcp.js @@ -10,6 +10,7 @@ const { getAppConfig, } = require('~/server/services/Config'); const { getMCPManager } = require('~/config'); +const { mcpServersRegistry } = require('@librechat/api'); /** * Get all MCP tools available to the user @@ -65,7 +66,7 @@ const getMCPTools = async (req, res) => { // Get server config once const serverConfig = appConfig.mcpConfig[serverName]; - const rawServerConfig = mcpManager.getRawConfig(serverName); + const rawServerConfig = await mcpServersRegistry.getServerConfig(serverName, userId); // Initialize server object with all server-level data const server = { diff --git a/api/server/routes/__tests__/mcp.spec.js b/api/server/routes/__tests__/mcp.spec.js index 8ae92cdd3d..43e086f7b3 100644 --- a/api/server/routes/__tests__/mcp.spec.js +++ b/api/server/routes/__tests__/mcp.spec.js @@ -15,6 +15,10 @@ jest.mock('@librechat/api', () => ({ storeTokens: jest.fn(), }, getUserMCPAuthMap: jest.fn(), + mcpServersRegistry: { + getServerConfig: jest.fn(), + getOAuthServers: jest.fn(), + }, })); jest.mock('@librechat/data-schemas', () => ({ @@ -115,7 +119,7 @@ describe('MCP Routes', () => { }); describe('GET /:serverName/oauth/initiate', () => { - const { MCPOAuthHandler } = require('@librechat/api'); + const { MCPOAuthHandler, mcpServersRegistry } = require('@librechat/api'); const { getLogStores } = require('~/cache'); it('should initiate OAuth flow successfully', async () => { @@ -128,13 +132,9 @@ describe('MCP Routes', () => { }), }; - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({}), - }; - getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - require('~/config').getMCPManager.mockReturnValue(mockMcpManager); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); MCPOAuthHandler.initiateOAuthFlow.mockResolvedValue({ authorizationUrl: 'https://oauth.example.com/auth', @@ -288,6 +288,7 @@ describe('MCP Routes', () => { }); it('should handle OAuth callback successfully', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -307,6 +308,7 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); @@ -321,7 +323,6 @@ describe('MCP Routes', () => { }; const mockMcpManager = { getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -379,6 +380,7 @@ describe('MCP Routes', () => { }); it('should handle system-level OAuth completion', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -398,14 +400,10 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({}), - }; - require('~/config').getMCPManager.mockReturnValue(mockMcpManager); - const response = await request(app).get('/api/mcp/test-server/oauth/callback').query({ code: 'test-auth-code', state: 'test-flow-id', @@ -417,6 +415,7 @@ describe('MCP Routes', () => { }); it('should handle reconnection failure after OAuth', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -436,12 +435,12 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); const mockMcpManager = { getUserConnection: jest.fn().mockRejectedValue(new Error('Reconnection failed')), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -461,6 +460,7 @@ describe('MCP Routes', () => { }); it('should redirect to error page if token storage fails', async () => { + const { mcpServersRegistry } = require('@librechat/api'); const mockFlowManager = { completeFlow: jest.fn().mockResolvedValue(), deleteFlow: jest.fn().mockResolvedValue(true), @@ -480,6 +480,7 @@ describe('MCP Routes', () => { MCPOAuthHandler.getFlowState.mockResolvedValue(mockFlowState); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockRejectedValue(new Error('store failed')); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); getLogStores.mockReturnValue({}); require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager); @@ -730,12 +731,14 @@ describe('MCP Routes', () => { }); describe('POST /:serverName/reinitialize', () => { + const { mcpServersRegistry } = require('@librechat/api'); + it('should return 404 when server is not found in configuration', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue(null), disconnectUserConnection: jest.fn().mockResolvedValue(), }; + mcpServersRegistry.getServerConfig.mockResolvedValue(null); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -750,9 +753,6 @@ describe('MCP Routes', () => { it('should handle OAuth requirement during reinitialize', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: {}, - }), disconnectUserConnection: jest.fn().mockResolvedValue(), mcpConfigs: {}, getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => { @@ -763,6 +763,9 @@ describe('MCP Routes', () => { }), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: {}, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -788,12 +791,12 @@ describe('MCP Routes', () => { it('should return 500 when reinitialize fails with non-OAuth error', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({}), disconnectUserConnection: jest.fn().mockResolvedValue(), mcpConfigs: {}, getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({}); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -809,11 +812,12 @@ describe('MCP Routes', () => { it('should return 500 when unexpected error occurs', async () => { const mockMcpManager = { - getRawConfig: jest.fn().mockImplementation(() => { - throw new Error('Config loading failed'); - }), + disconnectUserConnection: jest.fn(), }; + mcpServersRegistry.getServerConfig.mockImplementation(() => { + throw new Error('Config loading failed'); + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).post('/api/mcp/test-server/reinitialize'); @@ -846,11 +850,11 @@ describe('MCP Routes', () => { }; const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }), disconnectUserConnection: jest.fn().mockResolvedValue(), getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({ endpoint: 'http://test-server.com' }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -891,16 +895,16 @@ describe('MCP Routes', () => { }; const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - endpoint: 'http://test-server.com', - customUserVars: { - API_KEY: 'some-env-var', - }, - }), disconnectUserConnection: jest.fn().mockResolvedValue(), getUserConnection: jest.fn().mockResolvedValue(mockUserConnection), }; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + endpoint: 'http://test-server.com', + customUserVars: { + API_KEY: 'some-env-var', + }, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); require('~/config').getFlowStateManager.mockReturnValue({}); require('~/cache').getLogStores.mockReturnValue({}); @@ -1105,17 +1109,17 @@ describe('MCP Routes', () => { describe('GET /:serverName/auth-values', () => { const { getUserPluginAuthValue } = require('~/server/services/PluginService'); + const { mcpServersRegistry } = require('@librechat/api'); it('should return auth value flags for server', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: { - API_KEY: 'some-env-var', - SECRET_TOKEN: 'another-env-var', - }, - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: { + API_KEY: 'some-env-var', + SECRET_TOKEN: 'another-env-var', + }, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce(''); @@ -1135,10 +1139,9 @@ describe('MCP Routes', () => { }); it('should return 404 when server is not found in configuration', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue(null), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue(null); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/non-existent-server/auth-values'); @@ -1150,14 +1153,13 @@ describe('MCP Routes', () => { }); it('should handle errors when checking auth values', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: { - API_KEY: 'some-env-var', - }, - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: { + API_KEY: 'some-env-var', + }, + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); getUserPluginAuthValue.mockRejectedValue(new Error('Database error')); @@ -1174,12 +1176,11 @@ describe('MCP Routes', () => { }); it('should return 500 when auth values check throws unexpected error', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockImplementation(() => { - throw new Error('Config loading failed'); - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockImplementation(() => { + throw new Error('Config loading failed'); + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1189,12 +1190,11 @@ describe('MCP Routes', () => { }); it('should handle customUserVars that is not an object', async () => { - const mockMcpManager = { - getRawConfig: jest.fn().mockReturnValue({ - customUserVars: 'not-an-object', - }), - }; + const mockMcpManager = {}; + mcpServersRegistry.getServerConfig.mockResolvedValue({ + customUserVars: 'not-an-object', + }); require('~/config').getMCPManager.mockReturnValue(mockMcpManager); const response = await request(app).get('/api/mcp/test-server/auth-values'); @@ -1221,7 +1221,7 @@ describe('MCP Routes', () => { describe('GET /:serverName/oauth/callback - Edge Cases', () => { it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => { - const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api'); + const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api'); const mockTokens = { access_token: 'edge-access-token', refresh_token: 'edge-refresh-token', @@ -1239,6 +1239,7 @@ describe('MCP Routes', () => { }); MCPOAuthHandler.completeOAuthFlow = jest.fn().mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); const mockFlowManager = { completeFlow: jest.fn(), @@ -1249,7 +1250,6 @@ describe('MCP Routes', () => { getUserConnection: jest.fn().mockResolvedValue({ fetchTools: jest.fn().mockResolvedValue([]), }), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); @@ -1264,7 +1264,7 @@ describe('MCP Routes', () => { it('should handle null cached tools in OAuth callback (triggers || {} fallback)', async () => { const { getCachedTools } = require('~/server/services/Config'); getCachedTools.mockResolvedValue(null); - const { MCPOAuthHandler, MCPTokenStorage } = require('@librechat/api'); + const { MCPOAuthHandler, MCPTokenStorage, mcpServersRegistry } = require('@librechat/api'); const mockTokens = { access_token: 'edge-access-token', refresh_token: 'edge-refresh-token', @@ -1290,6 +1290,7 @@ describe('MCP Routes', () => { }); MCPOAuthHandler.completeOAuthFlow.mockResolvedValue(mockTokens); MCPTokenStorage.storeTokens.mockResolvedValue(); + mcpServersRegistry.getServerConfig.mockResolvedValue({}); const mockMcpManager = { getUserConnection: jest.fn().mockResolvedValue({ @@ -1297,7 +1298,6 @@ describe('MCP Routes', () => { .fn() .mockResolvedValue([{ name: 'test-tool', description: 'Test tool' }]), }), - getRawConfig: jest.fn().mockReturnValue({}), }; require('~/config').getMCPManager.mockReturnValue(mockMcpManager); diff --git a/api/server/routes/config.js b/api/server/routes/config.js index bae5f764b0..f1d2332047 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -12,6 +12,7 @@ const { getAppConfig } = require('~/server/services/Config/app'); const { getProjectByName } = require('~/models/Project'); const { getMCPManager } = require('~/config'); const { getLogStores } = require('~/cache'); +const { mcpServersRegistry } = require('@librechat/api'); const router = express.Router(); const emailLoginEnabled = @@ -125,7 +126,7 @@ router.get('/', async function (req, res) { payload.minPasswordLength = minPasswordLength; } - const getMCPServers = () => { + const getMCPServers = async () => { try { if (appConfig?.mcpConfig == null) { return; @@ -134,9 +135,8 @@ router.get('/', async function (req, res) { if (!mcpManager) { return; } - const mcpServers = mcpManager.getAllServers(); + const mcpServers = await mcpServersRegistry.getAllServerConfigs(); if (!mcpServers) return; - const oauthServers = mcpManager.getOAuthServers(); for (const serverName in mcpServers) { if (!payload.mcpServers) { payload.mcpServers = {}; @@ -145,7 +145,7 @@ router.get('/', async function (req, res) { payload.mcpServers[serverName] = removeNullishValues({ startup: serverConfig?.startup, chatMenu: serverConfig?.chatMenu, - isOAuth: oauthServers?.has(serverName), + isOAuth: serverConfig.requiresOAuth, customUserVars: serverConfig?.customUserVars, }); } @@ -154,7 +154,7 @@ router.get('/', async function (req, res) { } }; - getMCPServers(); + await getMCPServers(); const webSearchConfig = appConfig?.webSearch; if ( webSearchConfig != null && diff --git a/api/server/routes/mcp.js b/api/server/routes/mcp.js index 9b66b10e52..8d6d91e8d9 100644 --- a/api/server/routes/mcp.js +++ b/api/server/routes/mcp.js @@ -6,6 +6,7 @@ const { MCPOAuthHandler, MCPTokenStorage, getUserMCPAuthMap, + mcpServersRegistry, } = require('@librechat/api'); const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config'); const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP'); @@ -61,11 +62,12 @@ router.get('/:serverName/oauth/initiate', requireJwtAuth, async (req, res) => { return res.status(400).json({ error: 'Invalid flow state' }); } + const oauthHeaders = await getOAuthHeaders(serverName, userId); const { authorizationUrl, flowId: oauthFlowId } = await MCPOAuthHandler.initiateOAuthFlow( serverName, serverUrl, userId, - getOAuthHeaders(serverName), + oauthHeaders, oauthConfig, ); @@ -133,12 +135,8 @@ router.get('/:serverName/oauth/callback', async (req, res) => { }); logger.debug('[MCP OAuth] Completing OAuth flow'); - const tokens = await MCPOAuthHandler.completeOAuthFlow( - flowId, - code, - flowManager, - getOAuthHeaders(serverName), - ); + const oauthHeaders = await getOAuthHeaders(serverName, flowState.userId); + const tokens = await MCPOAuthHandler.completeOAuthFlow(flowId, code, flowManager, oauthHeaders); logger.info('[MCP OAuth] OAuth flow completed, tokens received in callback route'); /** Persist tokens immediately so reconnection uses fresh credentials */ @@ -356,7 +354,7 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => { logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`); const mcpManager = getMCPManager(); - const serverConfig = mcpManager.getRawConfig(serverName); + const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -505,8 +503,7 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { return res.status(401).json({ error: 'User not authenticated' }); } - const mcpManager = getMCPManager(); - const serverConfig = mcpManager.getRawConfig(serverName); + const serverConfig = await mcpServersRegistry.getServerConfig(serverName, user.id); if (!serverConfig) { return res.status(404).json({ error: `MCP server '${serverName}' not found in configuration`, @@ -545,9 +542,8 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => { } }); -function getOAuthHeaders(serverName) { - const mcpManager = getMCPManager(); - const serverConfig = mcpManager.getRawConfig(serverName); +async function getOAuthHeaders(serverName, userId) { + const serverConfig = await mcpServersRegistry.getServerConfig(serverName, userId); return serverConfig?.oauth_headers ?? {}; } diff --git a/api/server/services/MCP.js b/api/server/services/MCP.js index b7975b12fa..e91e5e7904 100644 --- a/api/server/services/MCP.js +++ b/api/server/services/MCP.js @@ -25,6 +25,7 @@ const { findToken, createToken, updateToken } = require('~/models'); const { reinitMCPServer } = require('./Tools/mcp'); const { getAppConfig } = require('./Config'); const { getLogStores } = require('~/cache'); +const { mcpServersRegistry } = require('@librechat/api'); /** * @param {object} params @@ -450,7 +451,7 @@ async function getMCPSetupData(userId) { logger.error(`[MCP][User: ${userId}] Error getting app connections:`, error); } const userConnections = mcpManager.getUserConnections(userId) || new Map(); - const oauthServers = mcpManager.getOAuthServers(); + const oauthServers = await mcpServersRegistry.getOAuthServers(); return { mcpConfig, diff --git a/api/server/services/MCP.spec.js b/api/server/services/MCP.spec.js index 7b192995e3..18857c4893 100644 --- a/api/server/services/MCP.spec.js +++ b/api/server/services/MCP.spec.js @@ -50,6 +50,9 @@ jest.mock('@librechat/api', () => ({ sendEvent: jest.fn(), normalizeServerName: jest.fn((name) => name), convertWithResolvedRefs: jest.fn((params) => params), + mcpServersRegistry: { + getOAuthServers: jest.fn(() => Promise.resolve(new Set())), + }, })); jest.mock('librechat-data-provider', () => ({ @@ -100,6 +103,7 @@ describe('tests for the new helper functions used by the MCP connection status e let mockGetFlowStateManager; let mockGetLogStores; let mockGetOAuthReconnectionManager; + let mockMcpServersRegistry; beforeEach(() => { jest.clearAllMocks(); @@ -108,6 +112,7 @@ describe('tests for the new helper functions used by the MCP connection status e mockGetFlowStateManager = require('~/config').getFlowStateManager; mockGetLogStores = require('~/cache').getLogStores; mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager; + mockMcpServersRegistry = require('@librechat/api').mcpServersRegistry; }); describe('getMCPSetupData', () => { @@ -125,8 +130,8 @@ describe('tests for the new helper functions used by the MCP connection status e mockGetMCPManager.mockReturnValue({ appConnections: { getAll: jest.fn(() => new Map()) }, getUserConnections: jest.fn(() => new Map()), - getOAuthServers: jest.fn(() => new Set()), }); + mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set()); }); it('should successfully return MCP setup data', async () => { @@ -139,9 +144,9 @@ describe('tests for the new helper functions used by the MCP connection status e const mockMCPManager = { appConnections: { getAll: jest.fn(() => mockAppConnections) }, getUserConnections: jest.fn(() => mockUserConnections), - getOAuthServers: jest.fn(() => mockOAuthServers), }; mockGetMCPManager.mockReturnValue(mockMCPManager); + mockMcpServersRegistry.getOAuthServers.mockResolvedValue(mockOAuthServers); const result = await getMCPSetupData(mockUserId); @@ -149,7 +154,7 @@ describe('tests for the new helper functions used by the MCP connection status e expect(mockGetMCPManager).toHaveBeenCalledWith(mockUserId); expect(mockMCPManager.appConnections.getAll).toHaveBeenCalled(); expect(mockMCPManager.getUserConnections).toHaveBeenCalledWith(mockUserId); - expect(mockMCPManager.getOAuthServers).toHaveBeenCalled(); + expect(mockMcpServersRegistry.getOAuthServers).toHaveBeenCalled(); expect(result).toEqual({ mcpConfig: mockConfig.mcpServers, @@ -170,9 +175,9 @@ describe('tests for the new helper functions used by the MCP connection status e const mockMCPManager = { appConnections: { getAll: jest.fn(() => null) }, getUserConnections: jest.fn(() => null), - getOAuthServers: jest.fn(() => new Set()), }; mockGetMCPManager.mockReturnValue(mockMCPManager); + mockMcpServersRegistry.getOAuthServers.mockResolvedValue(new Set()); const result = await getMCPSetupData(mockUserId); diff --git a/api/server/services/initializeMCPs.js b/api/server/services/initializeMCPs.js index 397fc85202..7fdb128683 100644 --- a/api/server/services/initializeMCPs.js +++ b/api/server/services/initializeMCPs.js @@ -15,7 +15,7 @@ async function initializeMCPs() { const mcpManager = await createMCPManager(mcpServers); try { - const mcpTools = mcpManager.getAppToolFunctions() || {}; + const mcpTools = (await mcpManager.getAppToolFunctions()) || {}; await mergeAppTools(mcpTools); logger.info( diff --git a/packages/api/jest.config.mjs b/packages/api/jest.config.mjs index 1533a3d213..10fa4554e4 100644 --- a/packages/api/jest.config.mjs +++ b/packages/api/jest.config.mjs @@ -1,7 +1,13 @@ export default { collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!/node_modules/'], coveragePathIgnorePatterns: ['/node_modules/', '/dist/'], - testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'], + testPathIgnorePatterns: [ + '/node_modules/', + '/dist/', + '\\.dev\\.ts$', + '\\.helper\\.ts$', + '\\.helper\\.d\\.ts$', + ], coverageReporters: ['text', 'cobertura'], testResultsProcessor: 'jest-junit', moduleNameMapper: { @@ -18,4 +24,4 @@ export default { // }, restoreMocks: true, testTimeout: 15000, -}; \ No newline at end of file +}; diff --git a/packages/api/package.json b/packages/api/package.json index 4d333082a3..86c2d3f42a 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -18,10 +18,11 @@ "build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs", "build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs", "build:watch:prod": "rollup -c -w --bundleConfigAsCjs", - "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.\"", - "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.\"", + "test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", + "test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"", "test:cache-integration:core": "jest --testPathPattern=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "test:cache-integration:cluster": "jest --testPathPattern=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand", + "test:cache-integration:mcp": "jest --testPathPattern=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false", "verify": "npm run test:ci", "b:clean": "bun run rimraf dist", "b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs", diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index e839a335a4..02d09797d3 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -3,6 +3,7 @@ export * from './cdn'; /* Auth */ export * from './auth'; /* MCP */ +export * from './mcp/registry/MCPServersRegistry'; export * from './mcp/MCPManager'; export * from './mcp/connection'; export * from './mcp/oauth'; diff --git a/packages/api/src/mcp/MCPConnectionFactory.ts b/packages/api/src/mcp/MCPConnectionFactory.ts index 5f4447b2bd..4425788cc9 100644 --- a/packages/api/src/mcp/MCPConnectionFactory.ts +++ b/packages/api/src/mcp/MCPConnectionFactory.ts @@ -9,6 +9,7 @@ import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; import { sanitizeUrlForLogging } from './utils'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils'; +import { withTimeout } from '~/utils/promise'; /** * Factory for creating MCP connections with optional OAuth authentication. @@ -231,14 +232,11 @@ export class MCPConnectionFactory { /** Attempts to establish connection with timeout handling */ protected async attemptToConnect(connection: MCPConnection): Promise { const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000; - const connectionTimeout = new Promise((_, reject) => - setTimeout( - () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), - connectTimeout, - ), + await withTimeout( + this.connectTo(connection), + connectTimeout, + `Connection timeout after ${connectTimeout}ms`, ); - const connectionAttempt = this.connectTo(connection); - await Promise.race([connectionAttempt, connectionTimeout]); if (await connection.isConnected()) return; logger.error(`${this.logPrefix} Failed to establish connection.`); diff --git a/packages/api/src/mcp/MCPManager.ts b/packages/api/src/mcp/MCPManager.ts index c6bfe77b8f..1e0d483f17 100644 --- a/packages/api/src/mcp/MCPManager.ts +++ b/packages/api/src/mcp/MCPManager.ts @@ -5,11 +5,14 @@ import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.j import type { TokenMethods } from '@librechat/data-schemas'; import type { FlowStateManager } from '~/flow/manager'; import type { TUser } from 'librechat-data-provider'; -import type { MCPOAuthTokens } from '~/mcp/oauth'; +import type { MCPOAuthTokens } from './oauth'; import type { RequestBody } from '~/types'; import type * as t from './types'; -import { UserConnectionManager } from '~/mcp/UserConnectionManager'; -import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; +import { UserConnectionManager } from './UserConnectionManager'; +import { ConnectionsRepository } from './ConnectionsRepository'; +import { MCPServerInspector } from './registry/MCPServerInspector'; +import { MCPServersInitializer } from './registry/MCPServersInitializer'; +import { mcpServersRegistry as registry } from './registry/MCPServersRegistry'; import { formatToolContent } from './parsers'; import { MCPConnection } from './connection'; import { processMCPEnv } from '~/utils/env'; @@ -24,8 +27,8 @@ export class MCPManager extends UserConnectionManager { /** Creates and initializes the singleton MCPManager instance */ public static async createInstance(configs: t.MCPServers): Promise { if (MCPManager.instance) throw new Error('MCPManager has already been initialized.'); - MCPManager.instance = new MCPManager(configs); - await MCPManager.instance.initialize(); + MCPManager.instance = new MCPManager(); + await MCPManager.instance.initialize(configs); return MCPManager.instance; } @@ -36,9 +39,10 @@ export class MCPManager extends UserConnectionManager { } /** Initializes the MCPManager by setting up server registry and app connections */ - public async initialize() { - await this.serversRegistry.initialize(); - this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs); + public async initialize(configs: t.MCPServers) { + await MCPServersInitializer.initialize(configs); + const appConfigs = await registry.sharedAppServers.getAll(); + this.appConnections = new ConnectionsRepository(appConfigs); } /** Retrieves an app-level or user-specific connection based on provided arguments */ @@ -62,36 +66,18 @@ export class MCPManager extends UserConnectionManager { } } - /** Get servers that require OAuth */ - public getOAuthServers(): Set { - return this.serversRegistry.oauthServers; - } - - /** Get all servers */ - public getAllServers(): t.MCPServers { - return this.serversRegistry.rawConfigs; - } - /** Returns all available tool functions from app-level connections */ - public getAppToolFunctions(): t.LCAvailableTools { - return this.serversRegistry.toolFunctions; + public async getAppToolFunctions(): Promise { + const toolFunctions: t.LCAvailableTools = {}; + const configs = await registry.getAllServerConfigs(); + for (const config of Object.values(configs)) { + if (config.toolFunctions != null) { + Object.assign(toolFunctions, config.toolFunctions); + } + } + return toolFunctions; } - /** Returns all available tool functions from all connections available to user */ - public async getAllToolFunctions(userId: string): Promise { - const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions(); - const userConnections = this.getUserConnections(userId); - if (!userConnections || userConnections.size === 0) { - return allToolFunctions; - } - - for (const [serverName, connection] of userConnections.entries()) { - const toolFunctions = await this.serversRegistry.getToolFunctions(serverName, connection); - Object.assign(allToolFunctions, toolFunctions); - } - - return allToolFunctions; - } /** Returns all available tool functions from all connections available to user */ public async getServerToolFunctions( userId: string, @@ -99,7 +85,7 @@ export class MCPManager extends UserConnectionManager { ): Promise { try { if (this.appConnections?.has(serverName)) { - return this.serversRegistry.getToolFunctions( + return MCPServerInspector.getToolFunctions( serverName, await this.appConnections.get(serverName), ); @@ -113,7 +99,7 @@ export class MCPManager extends UserConnectionManager { return null; } - return this.serversRegistry.getToolFunctions(serverName, userConnections.get(serverName)!); + return MCPServerInspector.getToolFunctions(serverName, userConnections.get(serverName)!); } catch (error) { logger.warn( `[getServerToolFunctions] Error getting tool functions for server ${serverName}`, @@ -128,8 +114,14 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names. If not provided or empty, returns all servers. * @returns Object mapping server names to their instructions */ - public getInstructions(serverNames?: string[]): Record { - const instructions = this.serversRegistry.serverInstructions; + private async getInstructions(serverNames?: string[]): Promise> { + const instructions: Record = {}; + const configs = await registry.getAllServerConfigs(); + for (const [serverName, config] of Object.entries(configs)) { + if (config.serverInstructions != null) { + instructions[serverName] = config.serverInstructions as string; + } + } if (!serverNames) return instructions; return pick(instructions, serverNames); } @@ -139,9 +131,9 @@ export class MCPManager extends UserConnectionManager { * @param serverNames Optional array of server names to include. If not provided, includes all servers. * @returns Formatted instructions string ready for context injection */ - public formatInstructionsForContext(serverNames?: string[]): string { + public async formatInstructionsForContext(serverNames?: string[]): Promise { /** Instructions for specified servers or all stored instructions */ - const instructionsToInclude = this.getInstructions(serverNames); + const instructionsToInclude = await this.getInstructions(serverNames); if (Object.keys(instructionsToInclude).length === 0) { return ''; @@ -225,7 +217,7 @@ Please follow these instructions when using tools from the respective MCP server ); } - const rawConfig = this.getRawConfig(serverName) as t.MCPOptions; + const rawConfig = (await registry.getServerConfig(serverName, userId)) as t.MCPOptions; const currentOptions = processMCPEnv({ user, options: rawConfig, diff --git a/packages/api/src/mcp/MCPServersRegistry.ts b/packages/api/src/mcp/MCPServersRegistry.ts deleted file mode 100644 index 668ad7d2c0..0000000000 --- a/packages/api/src/mcp/MCPServersRegistry.ts +++ /dev/null @@ -1,230 +0,0 @@ -import mapValues from 'lodash/mapValues'; -import { logger } from '@librechat/data-schemas'; -import { Constants } from 'librechat-data-provider'; -import type { JsonSchemaType } from '@librechat/data-schemas'; -import type { MCPConnection } from '~/mcp/connection'; -import type * as t from '~/mcp/types'; -import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; -import { detectOAuthRequirement } from '~/mcp/oauth'; -import { sanitizeUrlForLogging } from '~/mcp/utils'; -import { processMCPEnv, isEnabled } from '~/utils'; - -const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000; - -function getMCPInitTimeout(): number { - return process.env.MCP_INIT_TIMEOUT_MS != null - ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) - : DEFAULT_MCP_INIT_TIMEOUT_MS; -} - -/** - * Manages MCP server configurations and metadata discovery. - * Fetches server capabilities, OAuth requirements, and tool definitions for registry. - * Determines which servers are for app-level connections. - * Has its own connections repository. All connections are disconnected after initialization. - */ -export class MCPServersRegistry { - private initialized: boolean = false; - private connections: ConnectionsRepository; - private initTimeoutMs: number; - - public readonly rawConfigs: t.MCPServers; - public readonly parsedConfigs: Record; - - public oauthServers: Set = new Set(); - public serverInstructions: Record = {}; - public toolFunctions: t.LCAvailableTools = {}; - public appServerConfigs: t.MCPServers = {}; - - constructor(configs: t.MCPServers) { - this.rawConfigs = configs; - this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con })); - this.connections = new ConnectionsRepository(configs); - this.initTimeoutMs = getMCPInitTimeout(); - } - - /** Initializes all startup-enabled servers by gathering their metadata asynchronously */ - public async initialize(): Promise { - if (this.initialized) return; - this.initialized = true; - - const serverNames = Object.keys(this.parsedConfigs); - - await Promise.allSettled( - serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)), - ); - } - - /** Wraps server initialization with a timeout to prevent hanging */ - private async initializeServerWithTimeout(serverName: string): Promise { - let timeoutId: NodeJS.Timeout | null = null; - - try { - await Promise.race([ - this.initializeServer(serverName), - new Promise((_, reject) => { - timeoutId = setTimeout(() => { - reject(new Error('Server initialization timed out')); - }, this.initTimeoutMs); - }), - ]); - } catch (error) { - logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error); - throw error; - } finally { - if (timeoutId != null) { - clearTimeout(timeoutId); - } - } - } - - /** Initializes a single server with all its metadata and adds it to appropriate collections */ - private async initializeServer(serverName: string): Promise { - const start = Date.now(); - - const config = this.parsedConfigs[serverName]; - - // 1. Detect OAuth requirements if not already specified - try { - await this.fetchOAuthRequirement(serverName); - - if (config.startup !== false && !config.requiresOAuth) { - await Promise.allSettled([ - this.fetchServerInstructions(serverName).catch((error) => - logger.warn(`${this.prefix(serverName)} Failed to fetch server instructions:`, error), - ), - this.fetchServerCapabilities(serverName).catch((error) => - logger.warn(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error), - ), - ]); - } - } catch (error) { - logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error); - } - - // 2. Fetch tool functions for this server if a connection was established - const getToolFunctions = async (): Promise => { - try { - const loadedConns = await this.connections.getLoaded(); - const conn = loadedConns.get(serverName); - if (conn == null) { - return null; - } - return this.getToolFunctions(serverName, conn); - } catch (error) { - logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error); - return null; - } - }; - const toolFunctions = await getToolFunctions(); - - // 3. Disconnect this server's connection if it was established (fire-and-forget) - void this.connections.disconnect(serverName); - - // 4. Side effects - // 4.1 Add to OAuth servers if needed - if (config.requiresOAuth) { - this.oauthServers.add(serverName); - } - // 4.2 Add server instructions if available - if (config.serverInstructions != null) { - this.serverInstructions[serverName] = config.serverInstructions as string; - } - // 4.3 Add to app server configs if eligible (startup enabled, non-OAuth servers) - if (config.startup !== false && config.requiresOAuth === false) { - this.appServerConfigs[serverName] = this.rawConfigs[serverName]; - } - // 4.4 Add tool functions if available - if (toolFunctions != null) { - Object.assign(this.toolFunctions, toolFunctions); - } - - const duration = Date.now() - start; - this.logUpdatedConfig(serverName, duration); - } - - /** Converts server tools to LibreChat-compatible tool functions format */ - public async getToolFunctions( - serverName: string, - conn: MCPConnection, - ): Promise { - const { tools }: t.MCPToolListResponse = await conn.client.listTools(); - - const toolFunctions: t.LCAvailableTools = {}; - tools.forEach((tool) => { - const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`; - toolFunctions[name] = { - type: 'function', - ['function']: { - name, - description: tool.description, - parameters: tool.inputSchema as JsonSchemaType, - }, - }; - }); - - return toolFunctions; - } - - /** Determines if server requires OAuth if not already specified in the config */ - private async fetchOAuthRequirement(serverName: string): Promise { - const config = this.parsedConfigs[serverName]; - if (config.requiresOAuth != null) return config.requiresOAuth; - if (config.url == null) return (config.requiresOAuth = false); - if (config.startup === false) return (config.requiresOAuth = false); - - const result = await detectOAuthRequirement(config.url); - config.requiresOAuth = result.requiresOAuth; - config.oauthMetadata = result.metadata; - return config.requiresOAuth; - } - - /** Retrieves server instructions from MCP server if enabled in the config */ - private async fetchServerInstructions(serverName: string): Promise { - const config = this.parsedConfigs[serverName]; - if (!config.serverInstructions) return; - - // If it's a string that's not "true", it's a custom instruction - if (typeof config.serverInstructions === 'string' && !isEnabled(config.serverInstructions)) { - return; - } - - // Fetch from server if true (boolean) or "true" (string) - const conn = await this.connections.get(serverName); - config.serverInstructions = conn.client.getInstructions(); - if (!config.serverInstructions) { - logger.warn(`${this.prefix(serverName)} No server instructions available`); - } - } - - /** Fetches server capabilities and available tools list */ - private async fetchServerCapabilities(serverName: string): Promise { - const config = this.parsedConfigs[serverName]; - const conn = await this.connections.get(serverName); - const capabilities = conn.client.getServerCapabilities(); - if (!capabilities) return; - config.capabilities = JSON.stringify(capabilities); - if (!capabilities.tools) return; - const tools = await conn.client.listTools(); - config.tools = tools.tools.map((tool) => tool.name).join(', '); - } - - // Logs server configuration summary after initialization - private logUpdatedConfig(serverName: string, initDuration: number): void { - const prefix = this.prefix(serverName); - const config = this.parsedConfigs[serverName]; - logger.info(`${prefix} -------------------------------------------------┐`); - logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`); - logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`); - logger.info(`${prefix} Capabilities: ${config.capabilities}`); - logger.info(`${prefix} Tools: ${config.tools}`); - logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`); - logger.info(`${prefix} Initialized in: ${initDuration}ms`); - logger.info(`${prefix} -------------------------------------------------┘`); - } - - // Returns formatted log prefix for server messages - private prefix(serverName: string): string { - return `[MCP][${serverName}]`; - } -} diff --git a/packages/api/src/mcp/UserConnectionManager.ts b/packages/api/src/mcp/UserConnectionManager.ts index 7f5862b2a8..21c177dc7c 100644 --- a/packages/api/src/mcp/UserConnectionManager.ts +++ b/packages/api/src/mcp/UserConnectionManager.ts @@ -1,7 +1,7 @@ import { logger } from '@librechat/data-schemas'; import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js'; import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; -import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; +import { mcpServersRegistry as serversRegistry } from '~/mcp/registry/MCPServersRegistry'; import { MCPConnection } from './connection'; import type * as t from './types'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; @@ -14,7 +14,6 @@ import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; * https://github.com/danny-avila/LibreChat/discussions/8790 */ export abstract class UserConnectionManager { - protected readonly serversRegistry: MCPServersRegistry; // Connections shared by all users. public appConnections: ConnectionsRepository | null = null; // Connections per userId -> serverName -> connection @@ -23,15 +22,6 @@ export abstract class UserConnectionManager { protected userLastActivity: Map = new Map(); protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable) - constructor(serverConfigs: t.MCPServers) { - this.serversRegistry = new MCPServersRegistry(serverConfigs); - } - - /** fetches am MCP Server config from the registry */ - public getRawConfig(serverName: string): t.MCPOptions | undefined { - return this.serversRegistry.rawConfigs[serverName]; - } - /** Updates the last activity timestamp for a user */ protected updateUserLastActivity(userId: string): void { const now = Date.now(); @@ -106,7 +96,7 @@ export abstract class UserConnectionManager { logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`); } - const config = this.serversRegistry.parsedConfigs[serverName]; + const config = await serversRegistry.getServerConfig(serverName, userId); if (!config) { throw new McpError( ErrorCode.InvalidRequest, diff --git a/packages/api/src/mcp/__tests__/MCPManager.test.ts b/packages/api/src/mcp/__tests__/MCPManager.test.ts index 4d60a16954..ff0ba8ad3b 100644 --- a/packages/api/src/mcp/__tests__/MCPManager.test.ts +++ b/packages/api/src/mcp/__tests__/MCPManager.test.ts @@ -1,7 +1,9 @@ import { logger } from '@librechat/data-schemas'; import type * as t from '~/mcp/types'; import { MCPManager } from '~/mcp/MCPManager'; -import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; +import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry'; +import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; import { MCPConnection } from '../connection'; @@ -15,7 +17,24 @@ jest.mock('@librechat/data-schemas', () => ({ }, })); -jest.mock('~/mcp/MCPServersRegistry'); +jest.mock('~/mcp/registry/MCPServersRegistry', () => ({ + mcpServersRegistry: { + sharedAppServers: { + getAll: jest.fn(), + }, + getServerConfig: jest.fn(), + getAllServerConfigs: jest.fn(), + getOAuthServers: jest.fn(), + }, +})); + +jest.mock('~/mcp/registry/MCPServersInitializer', () => ({ + MCPServersInitializer: { + initialize: jest.fn(), + }, +})); + +jest.mock('~/mcp/registry/MCPServerInspector'); jest.mock('~/mcp/ConnectionsRepository'); const mockLogger = logger as jest.Mocked; @@ -28,20 +47,12 @@ describe('MCPManager', () => { // Reset MCPManager singleton state (MCPManager as unknown as { instance: null }).instance = null; jest.clearAllMocks(); - }); - function mockRegistry( - registryConfig: Partial, - ): jest.MockedClass { - const mock = { - initialize: jest.fn().mockResolvedValue(undefined), - getToolFunctions: jest.fn().mockResolvedValue(null), - ...registryConfig, - }; - return (MCPServersRegistry as jest.MockedClass).mockImplementation( - () => mock as unknown as MCPServersRegistry, - ); - } + // Set up default mock implementations + (MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined); + (mcpServersRegistry.sharedAppServers.getAll as jest.Mock).mockResolvedValue({}); + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({}); + }); function mockAppConnections( appConnectionsConfig: Partial, @@ -66,12 +77,229 @@ describe('MCPManager', () => { }; } + describe('getAppToolFunctions', () => { + it('should return empty object when no servers have tool functions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { type: 'stdio', command: 'test', args: [] }, + server2: { type: 'stdio', command: 'test2', args: [] }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.getAppToolFunctions(); + + expect(result).toEqual({}); + }); + + it('should collect tool functions from multiple servers', async () => { + const toolFunctions1 = { + tool1_mcp_server1: { + type: 'function' as const, + function: { + name: 'tool1_mcp_server1', + description: 'Tool 1', + parameters: { type: 'object' as const }, + }, + }, + }; + + const toolFunctions2 = { + tool2_mcp_server2: { + type: 'function' as const, + function: { + name: 'tool2_mcp_server2', + description: 'Tool 2', + parameters: { type: 'object' as const }, + }, + }, + }; + + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { + type: 'stdio', + command: 'test', + args: [], + toolFunctions: toolFunctions1, + }, + server2: { + type: 'stdio', + command: 'test2', + args: [], + toolFunctions: toolFunctions2, + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.getAppToolFunctions(); + + expect(result).toEqual({ + ...toolFunctions1, + ...toolFunctions2, + }); + }); + + it('should handle servers with null or undefined toolFunctions', async () => { + const toolFunctions1 = { + tool1_mcp_server1: { + type: 'function' as const, + function: { + name: 'tool1_mcp_server1', + description: 'Tool 1', + parameters: { type: 'object' as const }, + }, + }, + }; + + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { + type: 'stdio', + command: 'test', + args: [], + toolFunctions: toolFunctions1, + }, + server2: { + type: 'stdio', + command: 'test2', + args: [], + toolFunctions: null, + }, + server3: { + type: 'stdio', + command: 'test3', + args: [], + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.getAppToolFunctions(); + + expect(result).toEqual(toolFunctions1); + }); + }); + + describe('formatInstructionsForContext', () => { + it('should return empty string when no servers have instructions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + server1: { type: 'stdio', command: 'test', args: [] }, + server2: { type: 'stdio', command: 'test2', args: [] }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(); + + expect(result).toBe(''); + }); + + it('should format instructions from multiple servers', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + serverInstructions: 'Only read/write files in allowed directories', + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(); + + expect(result).toContain('# MCP Server Instructions'); + expect(result).toContain('## github MCP Server Instructions'); + expect(result).toContain('Use GitHub API with care'); + expect(result).toContain('## files MCP Server Instructions'); + expect(result).toContain('Only read/write files in allowed directories'); + }); + + it('should filter instructions by server names when provided', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + serverInstructions: 'Only read/write files in allowed directories', + }, + database: { + type: 'stdio', + command: 'node', + args: ['db.js'], + serverInstructions: 'Be careful with database operations', + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(['github', 'database']); + + expect(result).toContain('## github MCP Server Instructions'); + expect(result).toContain('Use GitHub API with care'); + expect(result).toContain('## database MCP Server Instructions'); + expect(result).toContain('Be careful with database operations'); + expect(result).not.toContain('files'); + expect(result).not.toContain('Only read/write files in allowed directories'); + }); + + it('should handle servers with null or undefined instructions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + serverInstructions: null, + }, + database: { + type: 'stdio', + command: 'node', + args: ['db.js'], + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(); + + expect(result).toContain('## github MCP Server Instructions'); + expect(result).toContain('Use GitHub API with care'); + expect(result).not.toContain('files'); + expect(result).not.toContain('database'); + }); + + it('should return empty string when filtered servers have no instructions', async () => { + (mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({ + github: { + type: 'sse', + url: 'https://api.github.com', + serverInstructions: 'Use GitHub API with care', + }, + files: { + type: 'stdio', + command: 'node', + args: ['files.js'], + }, + }); + + const manager = await MCPManager.createInstance(newMCPServersConfig()); + const result = await manager.formatInstructionsForContext(['files']); + + expect(result).toBe(''); + }); + }); + describe('getServerToolFunctions', () => { it('should catch and handle errors gracefully', async () => { - mockRegistry({ - getToolFunctions: jest.fn(() => { - throw new Error('Connection failed'); - }), + (MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => { + throw new Error('Connection failed'); }); mockAppConnections({ @@ -90,9 +318,7 @@ describe('MCPManager', () => { }); it('should catch synchronous errors from getUserConnections', async () => { - mockRegistry({ - getToolFunctions: jest.fn().mockResolvedValue({}), - }); + (MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn().mockResolvedValue({}); mockAppConnections({ has: jest.fn().mockReturnValue(false), @@ -126,9 +352,9 @@ describe('MCPManager', () => { }, }; - mockRegistry({ - getToolFunctions: jest.fn().mockResolvedValue(expectedTools), - }); + (MCPServerInspector.getToolFunctions as jest.Mock) = jest + .fn() + .mockResolvedValue(expectedTools); mockAppConnections({ has: jest.fn().mockReturnValue(true), @@ -145,10 +371,8 @@ describe('MCPManager', () => { it('should include specific server name in error messages', async () => { const specificServerName = 'github_mcp_server'; - mockRegistry({ - getToolFunctions: jest.fn(() => { - throw new Error('Server specific error'); - }), + (MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => { + throw new Error('Server specific error'); }); mockAppConnections({ diff --git a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts deleted file mode 100644 index ade8eab32c..0000000000 --- a/packages/api/src/mcp/__tests__/MCPServersRegistry.test.ts +++ /dev/null @@ -1,595 +0,0 @@ -import { join } from 'path'; -import { readFileSync } from 'fs'; -import { load as yamlLoad } from 'js-yaml'; -import { logger } from '@librechat/data-schemas'; -import type { OAuthDetectionResult } from '~/mcp/oauth/detectOAuth'; -import type * as t from '~/mcp/types'; -import { ConnectionsRepository } from '~/mcp/ConnectionsRepository'; -import { MCPServersRegistry } from '~/mcp/MCPServersRegistry'; -import { detectOAuthRequirement } from '~/mcp/oauth'; -import { MCPConnection } from '~/mcp/connection'; - -// Mock external dependencies -jest.mock('../oauth/detectOAuth'); -jest.mock('../ConnectionsRepository'); -jest.mock('../connection'); -jest.mock('@librechat/data-schemas', () => ({ - logger: { - info: jest.fn(), - warn: jest.fn(), - error: jest.fn(), - debug: jest.fn(), - }, -})); - -// Mock processMCPEnv to verify it's called and adds a processed marker -jest.mock('~/utils', () => ({ - ...jest.requireActual('~/utils'), - processMCPEnv: jest.fn(({ options }) => ({ - ...options, - _processed: true, // Simple marker to verify processing occurred - })), -})); - -const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction< - typeof detectOAuthRequirement ->; -const mockLogger = logger as jest.Mocked; - -describe('MCPServersRegistry - Initialize Function', () => { - let rawConfigs: t.MCPServers; - let expectedParsedConfigs: Record; - let mockConnectionsRepo: jest.Mocked; - let mockConnections: Map>; - - beforeEach(() => { - // Load fixtures - const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml'); - const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml'); - - rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers; - expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record< - string, - t.ParsedServerConfig - >; - - // Setup mock connections - mockConnections = new Map(); - const serverNames = Object.keys(rawConfigs); - - serverNames.forEach((serverName) => { - const mockClient = { - listTools: jest.fn(), - getInstructions: jest.fn(), - getServerCapabilities: jest.fn(), - }; - const mockConnection = { - client: mockClient, - } as unknown as jest.Mocked; - - // Setup mock responses based on expected configs - const expectedConfig = expectedParsedConfigs[serverName]; - - // Mock listTools response - if (expectedConfig.tools) { - const toolNames = expectedConfig.tools.split(', '); - const tools = toolNames.map((name: string) => ({ - name, - description: `Description for ${name}`, - inputSchema: { - type: 'object' as const, - properties: { - input: { type: 'string' }, - }, - }, - })); - (mockClient.listTools as jest.Mock).mockResolvedValue({ tools }); - } else { - (mockClient.listTools as jest.Mock).mockResolvedValue({ tools: [] }); - } - - // Mock getInstructions response - if (expectedConfig.serverInstructions) { - (mockClient.getInstructions as jest.Mock).mockReturnValue( - expectedConfig.serverInstructions as string, - ); - } else { - (mockClient.getInstructions as jest.Mock).mockReturnValue(undefined); - } - - // Mock getServerCapabilities response - if (expectedConfig.capabilities) { - const capabilities = JSON.parse(expectedConfig.capabilities) as Record; - (mockClient.getServerCapabilities as jest.Mock).mockReturnValue(capabilities); - } else { - (mockClient.getServerCapabilities as jest.Mock).mockReturnValue(undefined); - } - - mockConnections.set(serverName, mockConnection); - }); - - // Setup ConnectionsRepository mock - mockConnectionsRepo = { - get: jest.fn(), - getLoaded: jest.fn(), - disconnectAll: jest.fn(), - disconnect: jest.fn().mockResolvedValue(undefined), - } as unknown as jest.Mocked; - - mockConnectionsRepo.get.mockImplementation((serverName: string) => { - const connection = mockConnections.get(serverName); - if (!connection) { - throw new Error(`Connection not found for server: ${serverName}`); - } - return Promise.resolve(connection); - }); - - mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections); - - (ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo); - - // Setup OAuth detection mock with deterministic results - mockDetectOAuthRequirement.mockImplementation((url: string) => { - const oauthResults: Record = { - 'https://api.github.com/mcp': { - requiresOAuth: true, - method: 'protected-resource-metadata', - metadata: { - authorization_url: 'https://github.com/login/oauth/authorize', - token_url: 'https://github.com/login/oauth/access_token', - }, - }, - 'https://api.disabled.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - 'https://api.public.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - }; - - return Promise.resolve( - oauthResults[url] || { requiresOAuth: false, method: 'no-metadata-found', metadata: null }, - ); - }); - - // Clear all mocks - jest.clearAllMocks(); - }); - - afterEach(() => { - delete process.env.MCP_INIT_TIMEOUT_MS; - jest.clearAllMocks(); - }); - - describe('initialize() method', () => { - it('should only run initialization once', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - await registry.initialize(); // Second call should not re-run - - // Verify that connections are only requested for servers that need them - // (servers with serverInstructions=true or all servers for capabilities) - expect(mockConnectionsRepo.get).toHaveBeenCalled(); - }); - - it('should set all public properties correctly after initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Verify initial state - expect(registry.oauthServers.size).toBe(0); - expect(registry.serverInstructions).toEqual({}); - expect(registry.toolFunctions).toEqual({}); - expect(registry.appServerConfigs).toEqual({}); - - await registry.initialize(); - - // Test oauthServers Set - expect(registry.oauthServers).toEqual( - new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']), - ); - - // Test serverInstructions - OAuth servers keep their original boolean value, non-OAuth fetch actual strings - expect(registry.serverInstructions).toEqual({ - stdio_server: 'Follow these instructions for stdio server', - oauth_server: true, - non_oauth_server: 'Public API instructions', - }); - - // Test appServerConfigs (startup enabled, non-OAuth servers only) - expect(registry.appServerConfigs).toEqual({ - stdio_server: rawConfigs.stdio_server, - websocket_server: rawConfigs.websocket_server, - non_oauth_server: rawConfigs.non_oauth_server, - }); - - // Test toolFunctions (only non-OAuth servers get their tools fetched during initialization) - const expectedToolFunctions = { - file_read_mcp_stdio_server: { - type: 'function', - function: { - name: 'file_read_mcp_stdio_server', - description: 'Description for file_read', - parameters: { type: 'object', properties: { input: { type: 'string' } } }, - }, - }, - file_write_mcp_stdio_server: { - type: 'function', - function: { - name: 'file_write_mcp_stdio_server', - description: 'Description for file_write', - parameters: { type: 'object', properties: { input: { type: 'string' } } }, - }, - }, - }; - expect(registry.toolFunctions).toEqual(expectedToolFunctions); - }); - - it('should handle errors gracefully and continue initialization of other servers', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Make one specific server throw an error during OAuth detection - mockDetectOAuthRequirement.mockImplementation((url: string) => { - if (url === 'https://api.github.com/mcp') { - return Promise.reject(new Error('OAuth detection failed')); - } - // Return normal responses for other servers - const oauthResults: Record = { - 'https://api.disabled.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - 'https://api.public.com/mcp': { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - }; - return Promise.resolve( - oauthResults[url] ?? { - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }, - ); - }); - - await registry.initialize(); - - // Should still initialize successfully for other servers - expect(registry.oauthServers).toBeInstanceOf(Set); - expect(registry.toolFunctions).toBeDefined(); - - // The failed server should not be in oauthServers (since it failed OAuth detection) - expect(registry.oauthServers.has('oauth_server')).toBe(false); - - // But other servers should still be processed successfully - expect(registry.appServerConfigs).toHaveProperty('stdio_server'); - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - - // Error should be logged as a warning at the higher level - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'), - expect.any(Error), - ); - }); - - it('should disconnect individual connections after each server initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - - // Verify disconnect was called for each server during initialization - // All servers attempt to connect during initialization for metadata gathering - const serverNames = Object.keys(rawConfigs); - expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); - }); - - it('should log configuration updates for each startup-enabled server', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - - const serverNames = Object.keys(rawConfigs); - serverNames.forEach((serverName) => { - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] URL:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] OAuth Required:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] Capabilities:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] Tools:`), - ); - expect(mockLogger.info).toHaveBeenCalledWith( - expect.stringContaining(`[MCP][${serverName}] Server Instructions:`), - ); - }); - }); - - it('should have parsedConfigs matching the expected fixture after initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - await registry.initialize(); - - // Compare the actual parsedConfigs against the expected fixture - expect(registry.parsedConfigs).toEqual(expectedParsedConfigs); - }); - - it('should handle serverInstructions as string "true" correctly and fetch from server', async () => { - // Create test config with serverInstructions as string "true" - const testConfig: t.MCPServers = { - test_server_string_true: { - type: 'stdio', - args: [], - command: 'test-command', - serverInstructions: 'true', // Simulating string "true" from YAML parsing - }, - test_server_custom_string: { - type: 'stdio', - args: [], - command: 'test-command', - serverInstructions: 'Custom instructions here', - }, - test_server_bool_true: { - type: 'stdio', - args: [], - command: 'test-command', - serverInstructions: true, - }, - }; - - const registry = new MCPServersRegistry(testConfig); - - // Setup mock connection for servers that should fetch - const mockClient = { - listTools: jest.fn().mockResolvedValue({ tools: [] }), - getInstructions: jest.fn().mockReturnValue('Fetched instructions from server'), - getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }), - }; - const mockConnection = { - client: mockClient, - } as unknown as jest.Mocked; - - mockConnectionsRepo.get.mockResolvedValue(mockConnection); - mockConnectionsRepo.getLoaded.mockResolvedValue( - new Map([ - ['test_server_string_true', mockConnection], - ['test_server_bool_true', mockConnection], - ]), - ); - mockDetectOAuthRequirement.mockResolvedValue({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - - await registry.initialize(); - - // Verify that string "true" was treated as fetch-from-server - expect(registry.parsedConfigs['test_server_string_true'].serverInstructions).toBe( - 'Fetched instructions from server', - ); - - // Verify that custom string was kept as-is - expect(registry.parsedConfigs['test_server_custom_string'].serverInstructions).toBe( - 'Custom instructions here', - ); - - // Verify that boolean true also fetched from server - expect(registry.parsedConfigs['test_server_bool_true'].serverInstructions).toBe( - 'Fetched instructions from server', - ); - - // Verify getInstructions was called for both "true" cases - expect(mockClient.getInstructions).toHaveBeenCalledTimes(2); - }); - - it('should use Promise.allSettled for individual server initialization', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Spy on Promise.allSettled to verify it's being used - const allSettledSpy = jest.spyOn(Promise, 'allSettled'); - - await registry.initialize(); - - // Verify Promise.allSettled was called with an array of server initialization promises - expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)])); - - // Verify it was called with the correct number of server promises - const serverNames = Object.keys(rawConfigs); - expect(allSettledSpy).toHaveBeenCalledWith( - expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))), - ); - - allSettledSpy.mockRestore(); - }); - - it('should isolate server failures and not affect other servers', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Make multiple servers fail in different ways - mockConnectionsRepo.get.mockImplementation((serverName: string) => { - if (serverName === 'stdio_server') { - // First server fails - throw new Error('Connection failed for stdio_server'); - } - if (serverName === 'websocket_server') { - // Second server fails - throw new Error('Connection failed for websocket_server'); - } - // Other servers succeed - const connection = mockConnections.get(serverName); - if (!connection) { - throw new Error(`Connection not found for server: ${serverName}`); - } - return Promise.resolve(connection); - }); - - await registry.initialize(); - - // Despite failures, initialization should complete - expect(registry.oauthServers).toBeInstanceOf(Set); - expect(registry.toolFunctions).toBeDefined(); - - // Successful servers should still be processed - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - - // Failed servers should not crash the whole initialization - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][stdio_server] Failed to fetch server capabilities:'), - expect.any(Error), - ); - expect(mockLogger.warn).toHaveBeenCalledWith( - expect.stringContaining('[MCP][websocket_server] Failed to fetch server capabilities:'), - expect.any(Error), - ); - }); - - it('should properly clean up connections even when some servers fail', async () => { - const registry = new MCPServersRegistry(rawConfigs); - - // Track disconnect failures but suppress unhandled rejections - const disconnectErrors: Error[] = []; - mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => { - if (serverName === 'stdio_server') { - const error = new Error('Disconnect failed'); - disconnectErrors.push(error); - return Promise.reject(error).catch(() => {}); // Suppress unhandled rejection - } - return Promise.resolve(); - }); - - await registry.initialize(); - - // Should still attempt to disconnect all servers during initialization - const serverNames = Object.keys(rawConfigs); - expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length); - expect(disconnectErrors).toHaveLength(1); - }); - - it('should timeout individual server initialization after configured timeout', async () => { - const timeout = 2000; - // Create registry with a short timeout for testing - process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`; - - const registry = new MCPServersRegistry(rawConfigs); - - // Make one server hang indefinitely during OAuth detection - mockDetectOAuthRequirement.mockImplementation((url: string) => { - if (url === 'https://api.github.com/mcp') { - // Slow init - return new Promise((res) => setTimeout(res, timeout * 2)); - } - // Return normal responses for other servers - return Promise.resolve({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - }); - - const start = Date.now(); - await registry.initialize(); - const duration = Date.now() - start; - - // Should complete within reasonable time despite one server hanging - // Allow some buffer for test execution overhead - expect(duration).toBeLessThan(timeout * 1.5); - - // The timeout should prevent the hanging server from blocking initialization - // Other servers should still be processed successfully - expect(registry.appServerConfigs).toHaveProperty('stdio_server'); - expect(registry.appServerConfigs).toHaveProperty('non_oauth_server'); - }, 10_000); // 10 second Jest timeout - - it('should skip tool function fetching if connection was not established', async () => { - const testConfig: t.MCPServers = { - server_with_connection: { - type: 'stdio', - args: [], - command: 'test-command', - }, - server_without_connection: { - type: 'stdio', - args: [], - command: 'failing-command', - }, - }; - - const registry = new MCPServersRegistry(testConfig); - - const mockClient = { - listTools: jest.fn().mockResolvedValue({ - tools: [ - { - name: 'test_tool', - description: 'Test tool', - inputSchema: { type: 'object', properties: {} }, - }, - ], - }), - getInstructions: jest.fn().mockReturnValue(undefined), - getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }), - }; - const mockConnection = { - client: mockClient, - } as unknown as jest.Mocked; - - mockConnectionsRepo.get.mockImplementation((serverName: string) => { - if (serverName === 'server_with_connection') { - return Promise.resolve(mockConnection); - } - throw new Error('Connection failed'); - }); - - // Mock getLoaded to return connections map - the real implementation returns all loaded connections at once - mockConnectionsRepo.getLoaded.mockResolvedValue( - new Map([['server_with_connection', mockConnection]]), - ); - - mockDetectOAuthRequirement.mockResolvedValue({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - - await registry.initialize(); - - expect(registry.toolFunctions).toHaveProperty('test_tool_mcp_server_with_connection'); - expect(Object.keys(registry.toolFunctions)).toHaveLength(1); - }); - - it('should handle getLoaded returning empty map gracefully', async () => { - const testConfig: t.MCPServers = { - test_server: { - type: 'stdio', - args: [], - command: 'test-command', - }, - }; - - const registry = new MCPServersRegistry(testConfig); - - mockConnectionsRepo.get.mockRejectedValue(new Error('All connections failed')); - mockConnectionsRepo.getLoaded.mockResolvedValue(new Map()); - mockDetectOAuthRequirement.mockResolvedValue({ - requiresOAuth: false, - method: 'no-metadata-found', - metadata: null, - }); - - await registry.initialize(); - - expect(registry.toolFunctions).toEqual({}); - }); - }); -}); diff --git a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml deleted file mode 100644 index 71b3e01d22..0000000000 --- a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.parsedConfigs.yml +++ /dev/null @@ -1,67 +0,0 @@ -# Expected parsed MCP server configurations after running initialize() -# These represent the expected state of parsedConfigs after all fetch functions complete - -oauth_server: - _processed: true - type: "streamable-http" - url: "https://api.github.com/mcp" - headers: - Authorization: "Bearer {{GITHUB_TOKEN}}" - serverInstructions: true - requiresOAuth: true - oauthMetadata: - authorization_url: "https://github.com/login/oauth/authorize" - token_url: "https://github.com/login/oauth/access_token" - -oauth_predefined: - _processed: true - type: "sse" - url: "https://api.example.com/sse" - requiresOAuth: true - oauthMetadata: - authorization_url: "https://example.com/oauth/authorize" - token_url: "https://example.com/oauth/token" - -stdio_server: - _processed: true - command: "node" - args: ["server.js"] - env: - API_KEY: "${TEST_API_KEY}" - startup: true - serverInstructions: "Follow these instructions for stdio server" - requiresOAuth: false - capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}' - tools: "file_read, file_write" - -websocket_server: - _processed: true - type: "websocket" - url: "ws://localhost:3001/mcp" - startup: true - requiresOAuth: false - oauthMetadata: null - capabilities: '{"tools":{},"resources":{},"prompts":{}}' - tools: "" - -disabled_server: - _processed: true - requiresOAuth: false - type: "streamable-http" - url: "https://api.disabled.com/mcp" - startup: false - -non_oauth_server: - _processed: true - type: "streamable-http" - url: "https://api.public.com/mcp" - requiresOAuth: false - serverInstructions: "Public API instructions" - capabilities: '{"tools":{},"resources":{},"prompts":{}}' - tools: "" - -oauth_startup_enabled: - _processed: true - type: "sse" - url: "https://api.oauth-startup.com/sse" - requiresOAuth: true \ No newline at end of file diff --git a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml b/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml deleted file mode 100644 index 907dfaa96b..0000000000 --- a/packages/api/src/mcp/__tests__/fixtures/MCPServersRegistry.rawConfigs.yml +++ /dev/null @@ -1,53 +0,0 @@ -# Raw MCP server configurations used as input to MCPServersRegistry constructor -# These configs test different code paths in the initialization process - -# Test OAuth detection with URL - should trigger fetchOAuthRequirement -oauth_server: - type: "streamable-http" - url: "https://api.github.com/mcp" - headers: - Authorization: "Bearer {{GITHUB_TOKEN}}" - serverInstructions: true - -# Test OAuth already specified - should skip OAuth detection -oauth_predefined: - type: "sse" - url: "https://api.example.com/sse" - requiresOAuth: true - oauthMetadata: - authorization_url: "https://example.com/oauth/authorize" - token_url: "https://example.com/oauth/token" - -# Test stdio server without URL - should set requiresOAuth to false -stdio_server: - command: "node" - args: ["server.js"] - env: - API_KEY: "${TEST_API_KEY}" - startup: true - serverInstructions: "Follow these instructions for stdio server" - -# Test websocket server with capabilities but no tools -websocket_server: - type: "websocket" - url: "ws://localhost:3001/mcp" - startup: true - -# Test server with startup disabled - should not be included in appServerConfigs -disabled_server: - type: "streamable-http" - url: "https://api.disabled.com/mcp" - startup: false - -# Test non-OAuth server - should be included in appServerConfigs -non_oauth_server: - type: "streamable-http" - url: "https://api.public.com/mcp" - requiresOAuth: false - serverInstructions: true - -# Test server with OAuth but startup enabled - should not be in appServerConfigs -oauth_startup_enabled: - type: "sse" - url: "https://api.oauth-startup.com/sse" - requiresOAuth: true \ No newline at end of file diff --git a/packages/api/src/mcp/connection.ts b/packages/api/src/mcp/connection.ts index ad1b3e32aa..7e75acf751 100644 --- a/packages/api/src/mcp/connection.ts +++ b/packages/api/src/mcp/connection.ts @@ -18,6 +18,7 @@ import type { Response as UndiciResponse, } from 'undici'; import type { MCPOAuthTokens } from './oauth/types'; +import { withTimeout } from '~/utils/promise'; import type * as t from './types'; import { sanitizeUrlForLogging } from './utils'; import { mcpConfig } from './mcpConfig'; @@ -457,15 +458,11 @@ export class MCPConnection extends EventEmitter { this.setupTransportDebugHandlers(); const connectTimeout = this.options.initTimeout ?? 120000; - await Promise.race([ + await withTimeout( this.client.connect(this.transport), - new Promise((_resolve, reject) => - setTimeout( - () => reject(new Error(`Connection timeout after ${connectTimeout}ms`)), - connectTimeout, - ), - ), - ]); + connectTimeout, + `Connection timeout after ${connectTimeout}ms`, + ); this.connectionState = 'connected'; this.emit('connectionChange', 'connected'); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts index d2295191cf..f9a3c7ab73 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts @@ -1,6 +1,7 @@ import { TokenMethods } from '@librechat/data-schemas'; import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..'; import { MCPManager } from '../MCPManager'; +import { mcpServersRegistry } from '../../mcp/registry/MCPServersRegistry'; import { OAuthReconnectionManager } from './OAuthReconnectionManager'; import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; @@ -14,6 +15,12 @@ jest.mock('@librechat/data-schemas', () => ({ })); jest.mock('../MCPManager'); +jest.mock('../../mcp/registry/MCPServersRegistry', () => ({ + mcpServersRegistry: { + getServerConfig: jest.fn(), + getOAuthServers: jest.fn(), + }, +})); describe('OAuthReconnectionManager', () => { let flowManager: jest.Mocked>; @@ -51,10 +58,10 @@ describe('OAuthReconnectionManager', () => { getUserConnection: jest.fn(), getUserConnections: jest.fn(), disconnectUserConnection: jest.fn(), - getRawConfig: jest.fn(), } as unknown as jest.Mocked; (MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({}); }); afterEach(() => { @@ -152,7 +159,7 @@ describe('OAuthReconnectionManager', () => { it('should reconnect eligible servers', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1', 'server2', 'server3']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); // server1: has failed reconnection reconnectionTracker.setFailed(userId, 'server1'); @@ -186,7 +193,9 @@ describe('OAuthReconnectionManager', () => { mockMCPManager.getUserConnection.mockResolvedValue( mockNewConnection as unknown as MCPConnection, ); - mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({ + initTimeout: 5000, + } as unknown as MCPOptions); await reconnectionManager.reconnectServers(userId); @@ -215,7 +224,7 @@ describe('OAuthReconnectionManager', () => { it('should handle failed reconnection attempts', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); // server1: has valid token tokenMethods.findToken.mockResolvedValue({ @@ -226,7 +235,9 @@ describe('OAuthReconnectionManager', () => { // Mock failed connection mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed')); - mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); await reconnectionManager.reconnectServers(userId); @@ -242,7 +253,7 @@ describe('OAuthReconnectionManager', () => { it('should not reconnect servers with expired tokens', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); // server1: has expired token tokenMethods.findToken.mockResolvedValue({ @@ -261,7 +272,7 @@ describe('OAuthReconnectionManager', () => { it('should handle connection that returns but is not connected', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); tokenMethods.findToken.mockResolvedValue({ userId, @@ -277,7 +288,9 @@ describe('OAuthReconnectionManager', () => { mockMCPManager.getUserConnection.mockResolvedValue( mockConnection as unknown as MCPConnection, ); - mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); await reconnectionManager.reconnectServers(userId); @@ -359,7 +372,7 @@ describe('OAuthReconnectionManager', () => { it('should not attempt to reconnect servers that have timed out during reconnection', async () => { const userId = 'user-123'; const oauthServers = new Set(['server1', 'server2']); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); const now = Date.now(); jest.setSystemTime(now); @@ -414,7 +427,7 @@ describe('OAuthReconnectionManager', () => { const userId = 'user-123'; const serverName = 'server1'; const oauthServers = new Set([serverName]); - mockMCPManager.getOAuthServers.mockReturnValue(oauthServers); + (mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers); const now = Date.now(); jest.setSystemTime(now); @@ -428,7 +441,9 @@ describe('OAuthReconnectionManager', () => { // First reconnect attempt - will fail mockMCPManager.getUserConnection.mockRejectedValueOnce(new Error('Connection failed')); - mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions); + (mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue( + {} as unknown as MCPOptions, + ); await reconnectionManager.reconnectServers(userId); await jest.runAllTimersAsync(); diff --git a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts index 09abb2b048..25edec7f3a 100644 --- a/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts +++ b/packages/api/src/mcp/oauth/OAuthReconnectionManager.ts @@ -5,6 +5,7 @@ import type { MCPOAuthTokens } from './types'; import { OAuthReconnectionTracker } from './OAuthReconnectionTracker'; import { FlowStateManager } from '~/flow/manager'; import { MCPManager } from '~/mcp/MCPManager'; +import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry'; const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms @@ -72,7 +73,7 @@ export class OAuthReconnectionManager { // 1. derive the servers to reconnect const serversToReconnect = []; - for (const serverName of this.mcpManager.getOAuthServers()) { + for (const serverName of await mcpServersRegistry.getOAuthServers()) { const canReconnect = await this.canReconnect(userId, serverName); if (canReconnect) { serversToReconnect.push(serverName); @@ -104,7 +105,7 @@ export class OAuthReconnectionManager { logger.info(`${logPrefix} Attempting reconnection`); - const config = this.mcpManager.getRawConfig(serverName); + const config = await mcpServersRegistry.getServerConfig(serverName, userId); const cleanupOnFailedReconnect = () => { this.reconnectionsTracker.setFailed(userId, serverName); diff --git a/packages/api/src/mcp/registry/MCPServerInspector.ts b/packages/api/src/mcp/registry/MCPServerInspector.ts new file mode 100644 index 0000000000..3ae51d7b36 --- /dev/null +++ b/packages/api/src/mcp/registry/MCPServerInspector.ts @@ -0,0 +1,123 @@ +import { Constants } from 'librechat-data-provider'; +import type { JsonSchemaType } from '@librechat/data-schemas'; +import type { MCPConnection } from '~/mcp/connection'; +import type * as t from '~/mcp/types'; +import { detectOAuthRequirement } from '~/mcp/oauth'; +import { isEnabled } from '~/utils'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; + +/** + * Inspects MCP servers to discover their metadata, capabilities, and tools. + * Connects to servers and populates configuration with OAuth requirements, + * server instructions, capabilities, and available tools. + */ +export class MCPServerInspector { + private constructor( + private readonly serverName: string, + private readonly config: t.ParsedServerConfig, + private connection: MCPConnection | undefined, + ) {} + + /** + * Inspects a server and returns an enriched configuration with metadata. + * Detects OAuth requirements and fetches server capabilities. + * @param serverName - The name of the server (used for tool function naming) + * @param rawConfig - The raw server configuration + * @param connection - The MCP connection + * @returns A fully processed and enriched configuration with server metadata + */ + public static async inspect( + serverName: string, + rawConfig: t.MCPOptions, + connection?: MCPConnection, + ): Promise { + const start = Date.now(); + const inspector = new MCPServerInspector(serverName, rawConfig, connection); + await inspector.inspectServer(); + inspector.config.initDuration = Date.now() - start; + return inspector.config; + } + + private async inspectServer(): Promise { + await this.detectOAuth(); + + if (this.config.startup !== false && !this.config.requiresOAuth) { + let tempConnection = false; + if (!this.connection) { + tempConnection = true; + this.connection = await MCPConnectionFactory.create({ + serverName: this.serverName, + serverConfig: this.config, + }); + } + + await Promise.allSettled([ + this.fetchServerInstructions(), + this.fetchServerCapabilities(), + this.fetchToolFunctions(), + ]); + + if (tempConnection) await this.connection.disconnect(); + } + } + + private async detectOAuth(): Promise { + if (this.config.requiresOAuth != null) return; + if (this.config.url == null || this.config.startup === false) { + this.config.requiresOAuth = false; + return; + } + + const result = await detectOAuthRequirement(this.config.url); + this.config.requiresOAuth = result.requiresOAuth; + this.config.oauthMetadata = result.metadata; + } + + private async fetchServerInstructions(): Promise { + if (isEnabled(this.config.serverInstructions)) { + this.config.serverInstructions = this.connection!.client.getInstructions(); + } + } + + private async fetchServerCapabilities(): Promise { + const capabilities = this.connection!.client.getServerCapabilities(); + this.config.capabilities = JSON.stringify(capabilities); + const tools = await this.connection!.client.listTools(); + this.config.tools = tools.tools.map((tool) => tool.name).join(', '); + } + + private async fetchToolFunctions(): Promise { + this.config.toolFunctions = await MCPServerInspector.getToolFunctions( + this.serverName, + this.connection!, + ); + } + + /** + * Converts server tools to LibreChat-compatible tool functions format. + * @param serverName - The name of the server + * @param connection - The MCP connection + * @returns Tool functions formatted for LibreChat + */ + public static async getToolFunctions( + serverName: string, + connection: MCPConnection, + ): Promise { + const { tools }: t.MCPToolListResponse = await connection.client.listTools(); + + const toolFunctions: t.LCAvailableTools = {}; + tools.forEach((tool) => { + const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`; + toolFunctions[name] = { + type: 'function', + ['function']: { + name, + description: tool.description, + parameters: tool.inputSchema as JsonSchemaType, + }, + }; + }); + + return toolFunctions; + } +} diff --git a/packages/api/src/mcp/registry/MCPServersInitializer.ts b/packages/api/src/mcp/registry/MCPServersInitializer.ts new file mode 100644 index 0000000000..f29cd6769f --- /dev/null +++ b/packages/api/src/mcp/registry/MCPServersInitializer.ts @@ -0,0 +1,96 @@ +import { registryStatusCache as statusCache } from './cache/RegistryStatusCache'; +import { isLeader } from '~/cluster'; +import { withTimeout } from '~/utils'; +import { logger } from '@librechat/data-schemas'; +import { MCPServerInspector } from './MCPServerInspector'; +import { ParsedServerConfig } from '~/mcp/types'; +import { sanitizeUrlForLogging } from '~/mcp/utils'; +import type * as t from '~/mcp/types'; +import { mcpServersRegistry as registry } from './MCPServersRegistry'; + +const MCP_INIT_TIMEOUT_MS = + process.env.MCP_INIT_TIMEOUT_MS != null ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) : 30_000; + +/** + * Handles initialization of MCP servers at application startup with distributed coordination. + * In cluster environments, ensures only the leader node performs initialization while followers wait. + * Connects to each configured MCP server, inspects capabilities and tools, then caches the results. + * Categorizes servers as either shared app servers (auto-started) or shared user servers (OAuth/on-demand). + * Uses a timeout mechanism to prevent hanging on unresponsive servers during initialization. + */ +export class MCPServersInitializer { + /** + * Initializes MCP servers with distributed leader-follower coordination. + * + * Design rationale: + * - Handles leader crash scenarios: If the leader crashes during initialization, all followers + * will independently attempt initialization after a 3-second delay. The first to become leader + * will complete the initialization. + * - Only the leader performs the actual initialization work (reset caches, inspect servers). + * When complete, the leader signals completion via `statusCache`, allowing followers to proceed. + * - Followers wait and poll `statusCache` until the leader finishes, ensuring only one node + * performs the expensive initialization operations. + */ + public static async initialize(rawConfigs: t.MCPServers): Promise { + if (await statusCache.isInitialized()) return; + + if (await isLeader()) { + // Leader performs initialization + await statusCache.reset(); + await registry.reset(); + const serverNames = Object.keys(rawConfigs); + await Promise.allSettled( + serverNames.map((serverName) => + withTimeout( + MCPServersInitializer.initializeServer(serverName, rawConfigs[serverName]), + MCP_INIT_TIMEOUT_MS, + `${MCPServersInitializer.prefix(serverName)} Server initialization timed out`, + logger.error, + ), + ), + ); + await statusCache.setInitialized(true); + } else { + // Followers try again after a delay if not initialized + await new Promise((resolve) => setTimeout(resolve, 3000)); + await this.initialize(rawConfigs); + } + } + + /** Initializes a single server with all its metadata and adds it to appropriate collections */ + private static async initializeServer( + serverName: string, + rawConfig: t.MCPOptions, + ): Promise { + try { + const config = await MCPServerInspector.inspect(serverName, rawConfig); + + if (config.startup === false || config.requiresOAuth) { + await registry.sharedUserServers.add(serverName, config); + } else { + await registry.sharedAppServers.add(serverName, config); + } + MCPServersInitializer.logParsedConfig(serverName, config); + } catch (error) { + logger.error(`${MCPServersInitializer.prefix(serverName)} Failed to initialize:`, error); + } + } + + // Logs server configuration summary after initialization + private static logParsedConfig(serverName: string, config: ParsedServerConfig): void { + const prefix = MCPServersInitializer.prefix(serverName); + logger.info(`${prefix} -------------------------------------------------┐`); + logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`); + logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`); + logger.info(`${prefix} Capabilities: ${config.capabilities}`); + logger.info(`${prefix} Tools: ${config.tools}`); + logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`); + logger.info(`${prefix} Initialized in: ${config.initDuration ?? 'N/A'}ms`); + logger.info(`${prefix} -------------------------------------------------┘`); + } + + // Returns formatted log prefix for server messages + private static prefix(serverName: string): string { + return `[MCP][${serverName}]`; + } +} diff --git a/packages/api/src/mcp/registry/MCPServersRegistry.ts b/packages/api/src/mcp/registry/MCPServersRegistry.ts new file mode 100644 index 0000000000..8c6ef13e9c --- /dev/null +++ b/packages/api/src/mcp/registry/MCPServersRegistry.ts @@ -0,0 +1,91 @@ +import type * as t from '~/mcp/types'; +import { + ServerConfigsCacheFactory, + type ServerConfigsCache, +} from './cache/ServerConfigsCacheFactory'; + +/** + * Central registry for managing MCP server configurations across different scopes and users. + * Maintains three categories of server configurations: + * - Shared App Servers: Auto-started servers available to all users (initialized at startup) + * - Shared User Servers: User-scope servers that require OAuth or on-demand startup + * - Private User Servers: Per-user configurations dynamically added during runtime + * + * Provides a unified interface for retrieving server configs with proper fallback hierarchy: + * checks shared app servers first, then shared user servers, then private user servers. + * Handles server lifecycle operations including adding, removing, and querying configurations. + */ +class MCPServersRegistry { + public readonly sharedAppServers = ServerConfigsCacheFactory.create('App', true); + public readonly sharedUserServers = ServerConfigsCacheFactory.create('User', true); + private readonly privateUserServers: Map = new Map(); + + public async addPrivateUserServer( + userId: string, + serverName: string, + config: t.ParsedServerConfig, + ): Promise { + if (!this.privateUserServers.has(userId)) { + const cache = ServerConfigsCacheFactory.create(`User(${userId})`, false); + this.privateUserServers.set(userId, cache); + } + await this.privateUserServers.get(userId)!.add(serverName, config); + } + + public async updatePrivateUserServer( + userId: string, + serverName: string, + config: t.ParsedServerConfig, + ): Promise { + const userCache = this.privateUserServers.get(userId); + if (!userCache) throw new Error(`No private servers found for user "${userId}".`); + await userCache.update(serverName, config); + } + + public async removePrivateUserServer(userId: string, serverName: string): Promise { + await this.privateUserServers.get(userId)?.remove(serverName); + } + + public async getServerConfig( + serverName: string, + userId?: string, + ): Promise { + const sharedAppServer = await this.sharedAppServers.get(serverName); + if (sharedAppServer) return sharedAppServer; + + const sharedUserServer = await this.sharedUserServers.get(serverName); + if (sharedUserServer) return sharedUserServer; + + const privateUserServer = await this.privateUserServers.get(userId)?.get(serverName); + if (privateUserServer) return privateUserServer; + + return undefined; + } + + public async getAllServerConfigs(userId?: string): Promise> { + return { + ...(await this.sharedAppServers.getAll()), + ...(await this.sharedUserServers.getAll()), + ...((await this.privateUserServers.get(userId)?.getAll()) ?? {}), + }; + } + + // TODO: This is currently used to determine if a server requires OAuth. However, this info can + // can be determined through config.requiresOAuth. Refactor usages and remove this method. + public async getOAuthServers(userId?: string): Promise> { + const allServers = await this.getAllServerConfigs(userId); + const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth); + return new Set(oauthServers.map(([name]) => name)); + } + + public async reset(): Promise { + await this.sharedAppServers.reset(); + await this.sharedUserServers.reset(); + for (const cache of this.privateUserServers.values()) { + await cache.reset(); + } + this.privateUserServers.clear(); + } +} + +export const mcpServersRegistry = new MCPServersRegistry(); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts new file mode 100644 index 0000000000..0e4a6ebbe9 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServerInspector.test.ts @@ -0,0 +1,338 @@ +import type { MCPConnection } from '~/mcp/connection'; +import type * as t from '~/mcp/types'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; +import { detectOAuthRequirement } from '~/mcp/oauth'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { createMockConnection } from './mcpConnectionsMock.helper'; + +// Mock external dependencies +jest.mock('../../oauth/detectOAuth'); +jest.mock('../../MCPConnectionFactory'); + +const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction< + typeof detectOAuthRequirement +>; + +describe('MCPServerInspector', () => { + let mockConnection: jest.Mocked; + + beforeEach(() => { + mockConnection = createMockConnection('test_server'); + jest.clearAllMocks(); + }); + + describe('inspect()', () => { + it('should process env and fetch all metadata for non-OAuth stdio server with serverInstructions=true', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: true, + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'instructions for test_server', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: { + listFiles_mcp_test_server: expect.objectContaining({ + type: 'function', + function: expect.objectContaining({ + name: 'listFiles_mcp_test_server', + }), + }), + }, + initDuration: expect.any(Number), + }); + }); + + it('should detect OAuth and skip capabilities fetch for streamable-http server', async () => { + const rawConfig: t.MCPOptions = { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: true, + method: 'protected-resource-metadata', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'streamable-http', + url: 'https://api.example.com/mcp', + requiresOAuth: true, + oauthMetadata: undefined, + initDuration: expect.any(Number), + }); + }); + + it('should skip capabilities fetch when startup=false', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + startup: false, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + startup: false, + requiresOAuth: false, + initDuration: expect.any(Number), + }); + }); + + it('should keep custom serverInstructions string and not fetch from server', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'Custom instructions here', + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'Custom instructions here', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: expect.any(Object), + initDuration: expect.any(Number), + }); + }); + + it('should handle serverInstructions as string "true" and fetch from server', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'true', // String "true" from YAML + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'instructions for test_server', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: expect.any(Object), + initDuration: expect.any(Number), + }); + }); + + it('should handle predefined requiresOAuth without detection', async () => { + const rawConfig: t.MCPOptions = { + type: 'sse', + url: 'https://api.example.com/sse', + requiresOAuth: true, + }; + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'sse', + url: 'https://api.example.com/sse', + requiresOAuth: true, + initDuration: expect.any(Number), + }); + }); + + it('should fetch capabilities when server has no tools', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + // Mock server with no tools + mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: '', + toolFunctions: {}, + initDuration: expect.any(Number), + }); + }); + + it('should create temporary connection when no connection is provided', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: true, + }; + + const tempMockConnection = createMockConnection('test_server'); + (MCPConnectionFactory.create as jest.Mock).mockResolvedValue(tempMockConnection); + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + const result = await MCPServerInspector.inspect('test_server', rawConfig); + + // Verify factory was called to create connection + expect(MCPConnectionFactory.create).toHaveBeenCalledWith({ + serverName: 'test_server', + serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }), + }); + + // Verify temporary connection was disconnected + expect(tempMockConnection.disconnect).toHaveBeenCalled(); + + // Verify result is correct + expect(result).toEqual({ + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: 'instructions for test_server', + requiresOAuth: false, + capabilities: + '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}', + tools: 'listFiles', + toolFunctions: expect.any(Object), + initDuration: expect.any(Number), + }); + }); + + it('should not create temporary connection when connection is provided', async () => { + const rawConfig: t.MCPOptions = { + type: 'stdio', + command: 'node', + args: ['server.js'], + serverInstructions: true, + }; + + mockDetectOAuthRequirement.mockResolvedValue({ + requiresOAuth: false, + method: 'no-metadata-found', + }); + + await MCPServerInspector.inspect('test_server', rawConfig, mockConnection); + + // Verify factory was NOT called + expect(MCPConnectionFactory.create).not.toHaveBeenCalled(); + + // Verify provided connection was NOT disconnected + expect(mockConnection.disconnect).not.toHaveBeenCalled(); + }); + }); + + describe('getToolFunctions()', () => { + it('should convert MCP tools to LibreChat tool functions format', async () => { + mockConnection.client.listTools = jest.fn().mockResolvedValue({ + tools: [ + { + name: 'file_read', + description: 'Read a file', + inputSchema: { + type: 'object', + properties: { path: { type: 'string' } }, + }, + }, + { + name: 'file_write', + description: 'Write a file', + inputSchema: { + type: 'object', + properties: { + path: { type: 'string' }, + content: { type: 'string' }, + }, + }, + }, + ], + }); + + const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection); + + expect(result).toEqual({ + file_read_mcp_my_server: { + type: 'function', + function: { + name: 'file_read_mcp_my_server', + description: 'Read a file', + parameters: { + type: 'object', + properties: { path: { type: 'string' } }, + }, + }, + }, + file_write_mcp_my_server: { + type: 'function', + function: { + name: 'file_write_mcp_my_server', + description: 'Write a file', + parameters: { + type: 'object', + properties: { + path: { type: 'string' }, + content: { type: 'string' }, + }, + }, + }, + }, + }); + }); + + it('should handle empty tools list', async () => { + mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] }); + + const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection); + + expect(result).toEqual({}); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts new file mode 100644 index 0000000000..820cdfa54e --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.cache_integration.spec.ts @@ -0,0 +1,301 @@ +import { expect } from '@playwright/test'; +import type * as t from '~/mcp/types'; +import type { MCPConnection } from '~/mcp/connection'; + +// Mock isLeader to always return true to avoid lock contention during parallel operations +jest.mock('~/cluster', () => ({ + ...jest.requireActual('~/cluster'), + isLeader: jest.fn().mockResolvedValue(true), +})); + +describe('MCPServersInitializer Redis Integration Tests', () => { + let MCPServersInitializer: typeof import('../MCPServersInitializer').MCPServersInitializer; + let registry: typeof import('../MCPServersRegistry').mcpServersRegistry; + let registryStatusCache: typeof import('../cache/RegistryStatusCache').registryStatusCache; + let MCPServerInspector: typeof import('../MCPServerInspector').MCPServerInspector; + let MCPConnectionFactory: typeof import('~/mcp/MCPConnectionFactory').MCPConnectionFactory; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let leaderInstance: InstanceType; + + const testConfigs: t.MCPServers = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + }, + }; + + const testParsedConfigs: Record = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + requiresOAuth: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + requiresOAuth: true, + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for search_tools_server', + capabilities: '{"tools":{"listChanged":true}}', + tools: 'search', + toolFunctions: { + search_mcp_search_tools_server: { + type: 'function', + function: { + name: 'search_mcp_search_tools_server', + description: 'Search tool', + parameters: { type: 'object' }, + }, + }, + }, + }, + }; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'MCPServersInitializer-IntegrationTest'; + + // Import modules after setting env vars + const initializerModule = await import('../MCPServersInitializer'); + const registryModule = await import('../MCPServersRegistry'); + const statusCacheModule = await import('../cache/RegistryStatusCache'); + const inspectorModule = await import('../MCPServerInspector'); + const connectionFactoryModule = await import('~/mcp/MCPConnectionFactory'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + + MCPServersInitializer = initializerModule.MCPServersInitializer; + registry = registryModule.mcpServersRegistry; + registryStatusCache = statusCacheModule.registryStatusCache; + MCPServerInspector = inspectorModule.MCPServerInspector; + MCPConnectionFactory = connectionFactoryModule.MCPConnectionFactory; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Become leader so we can perform write operations + leaderInstance = new LeaderElection(); + const isLeader = await leaderInstance.isLeader(); + expect(isLeader).toBe(true); + }); + + beforeEach(async () => { + // Ensure we're still the leader + const isLeader = await leaderInstance.isLeader(); + if (!isLeader) { + throw new Error('Lost leader status before test'); + } + + // Mock MCPServerInspector.inspect to return parsed config + jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => { + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + // Mock MCPConnection + const mockConnection = { + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + + // Mock MCPConnectionFactory + jest.spyOn(MCPConnectionFactory, 'create').mockResolvedValue(mockConnection); + + // Reset caches before each test + await registryStatusCache.reset(); + await registry.reset(); + }); + + afterEach(async () => { + // Clean up: clear all test keys from Redis + if (keyvRedisClient) { + const pattern = '*MCPServersInitializer-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + + jest.restoreAllMocks(); + }); + + afterAll(async () => { + // Resign as leader + if (leaderInstance) await leaderInstance.resign(); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('initialize()', () => { + it('should reset registry and status cache before initialization', async () => { + // Pre-populate registry with some old servers + await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server); + await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server); + + // Initialize with new configs (this should reset first) + await MCPServersInitializer.initialize(testConfigs); + + // Verify old servers are gone + expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined(); + expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined(); + + // Verify new servers are present + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + + it('should skip initialization if already initialized', async () => { + // First initialization + await MCPServersInitializer.initialize(testConfigs); + + // Clear mock calls + jest.clearAllMocks(); + + // Second initialization should skip due to static flag + await MCPServersInitializer.initialize(testConfigs); + + // Verify inspect was not called again + expect(MCPServerInspector.inspect).not.toHaveBeenCalled(); + }); + + it('should add disabled servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + expect(disabledServer).toMatchObject({ + ...testParsedConfigs.disabled_server, + _processedByInspector: true, + }); + }); + + it('should add OAuth servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + expect(oauthServer).toMatchObject({ + ...testParsedConfigs.oauth_server, + _processedByInspector: true, + }); + }); + + it('should add enabled non-OAuth servers to sharedAppServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeDefined(); + expect(fileToolsServer).toMatchObject({ + ...testParsedConfigs.file_tools_server, + _processedByInspector: true, + }); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + expect(searchToolsServer).toMatchObject({ + ...testParsedConfigs.search_tools_server, + _processedByInspector: true, + }); + }); + + it('should successfully initialize all servers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify all servers were added to appropriate registries + expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined(); + }); + + it('should handle inspection failures gracefully', async () => { + // Mock inspection failure for one server + jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => { + if (serverName === 'file_tools_server') { + throw new Error('Inspection failed'); + } + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + await MCPServersInitializer.initialize(testConfigs); + + // Verify other servers were still processed + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + + // Verify file_tools_server was not added (due to inspection failure) + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeUndefined(); + }); + + it('should set initialized status after completion', async () => { + await MCPServersInitializer.initialize(testConfigs); + + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts new file mode 100644 index 0000000000..2ce8d09d93 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersInitializer.test.ts @@ -0,0 +1,292 @@ +import { logger } from '@librechat/data-schemas'; +import * as t from '~/mcp/types'; +import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory'; +import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer'; +import { MCPConnection } from '~/mcp/connection'; +import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache'; +import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector'; +import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry'; + +// Mock external dependencies +jest.mock('../../MCPConnectionFactory'); +jest.mock('../../connection'); +jest.mock('../../registry/MCPServerInspector'); +jest.mock('~/cluster', () => ({ + isLeader: jest.fn().mockResolvedValue(true), +})); +jest.mock('@librechat/data-schemas', () => ({ + logger: { + info: jest.fn(), + warn: jest.fn(), + error: jest.fn(), + debug: jest.fn(), + }, +})); + +const mockLogger = logger as jest.Mocked; +const mockInspect = MCPServerInspector.inspect as jest.MockedFunction< + typeof MCPServerInspector.inspect +>; + +describe('MCPServersInitializer', () => { + let mockConnection: jest.Mocked; + + const testConfigs: t.MCPServers = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + }, + }; + + const testParsedConfigs: Record = { + disabled_server: { + type: 'stdio', + command: 'node', + args: ['disabled.js'], + startup: false, + requiresOAuth: false, + }, + oauth_server: { + type: 'streamable-http', + url: 'https://api.example.com/mcp', + requiresOAuth: true, + }, + file_tools_server: { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }, + search_tools_server: { + type: 'stdio', + command: 'node', + args: ['instructions.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for search_tools_server', + capabilities: '{"tools":{"listChanged":true}}', + tools: 'search', + toolFunctions: { + search_mcp_search_tools_server: { + type: 'function', + function: { + name: 'search_mcp_search_tools_server', + description: 'Search tool', + parameters: { type: 'object' }, + }, + }, + }, + }, + }; + + beforeEach(async () => { + // Setup MCPConnection mock + mockConnection = { + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; + + // Setup MCPConnectionFactory mock + (MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection); + + // Mock MCPServerInspector.inspect to return parsed config + mockInspect.mockImplementation(async (serverName: string) => { + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + // Reset caches before each test + await registryStatusCache.reset(); + await registry.sharedAppServers.reset(); + await registry.sharedUserServers.reset(); + jest.clearAllMocks(); + }); + + afterEach(() => { + jest.restoreAllMocks(); + }); + + describe('initialize()', () => { + it('should reset registry and status cache before initialization', async () => { + // Pre-populate registry with some old servers + await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server); + await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server); + + // Initialize with new configs (this should reset first) + await MCPServersInitializer.initialize(testConfigs); + + // Verify old servers are gone + expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined(); + expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined(); + + // Verify new servers are present + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + + it('should skip initialization if already initialized (Redis flag)', async () => { + // First initialization + await MCPServersInitializer.initialize(testConfigs); + + jest.clearAllMocks(); + + // Second initialization should skip due to Redis cache flag + await MCPServersInitializer.initialize(testConfigs); + + expect(mockInspect).not.toHaveBeenCalled(); + }); + + it('should process all server configs through inspector', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify all configs were processed by inspector (without connection parameter) + expect(mockInspect).toHaveBeenCalledTimes(4); + expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server); + expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server); + expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server); + expect(mockInspect).toHaveBeenCalledWith( + 'search_tools_server', + testConfigs.search_tools_server, + ); + }); + + it('should add disabled servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + expect(disabledServer).toMatchObject({ + ...testParsedConfigs.disabled_server, + _processedByInspector: true, + }); + }); + + it('should add OAuth servers to sharedUserServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + expect(oauthServer).toMatchObject({ + ...testParsedConfigs.oauth_server, + _processedByInspector: true, + }); + }); + + it('should add enabled non-OAuth servers to sharedAppServers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeDefined(); + expect(fileToolsServer).toMatchObject({ + ...testParsedConfigs.file_tools_server, + _processedByInspector: true, + }); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + expect(searchToolsServer).toMatchObject({ + ...testParsedConfigs.search_tools_server, + _processedByInspector: true, + }); + }); + + it('should successfully initialize all servers', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify all servers were added to appropriate registries + expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined(); + expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined(); + expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined(); + }); + + it('should handle inspection failures gracefully', async () => { + // Mock inspection failure for one server + mockInspect.mockImplementation(async (serverName: string) => { + if (serverName === 'file_tools_server') { + throw new Error('Inspection failed'); + } + return { + ...testParsedConfigs[serverName], + _processedByInspector: true, + } as unknown as t.ParsedServerConfig; + }); + + await MCPServersInitializer.initialize(testConfigs); + + // Verify other servers were still processed + const disabledServer = await registry.sharedUserServers.get('disabled_server'); + expect(disabledServer).toBeDefined(); + + const oauthServer = await registry.sharedUserServers.get('oauth_server'); + expect(oauthServer).toBeDefined(); + + const searchToolsServer = await registry.sharedAppServers.get('search_tools_server'); + expect(searchToolsServer).toBeDefined(); + + // Verify file_tools_server was not added (due to inspection failure) + const fileToolsServer = await registry.sharedAppServers.get('file_tools_server'); + expect(fileToolsServer).toBeUndefined(); + }); + + it('should log server configuration after initialization', async () => { + await MCPServersInitializer.initialize(testConfigs); + + // Verify logging occurred for each server + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('[MCP][disabled_server]'), + ); + expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('[MCP][oauth_server]')); + expect(mockLogger.info).toHaveBeenCalledWith( + expect.stringContaining('[MCP][file_tools_server]'), + ); + }); + + it('should use Promise.allSettled for parallel server initialization', async () => { + const allSettledSpy = jest.spyOn(Promise, 'allSettled'); + + await MCPServersInitializer.initialize(testConfigs); + + expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)])); + expect(allSettledSpy).toHaveBeenCalledTimes(1); + + allSettledSpy.mockRestore(); + }); + + it('should set initialized status after completion', async () => { + await MCPServersInitializer.initialize(testConfigs); + + expect(await registryStatusCache.isInitialized()).toBe(true); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts new file mode 100644 index 0000000000..68e9291d46 --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.cache_integration.spec.ts @@ -0,0 +1,227 @@ +import { expect } from '@playwright/test'; +import type * as t from '~/mcp/types'; + +/** + * Integration tests for MCPServersRegistry using Redis-backed cache. + * For unit tests using in-memory cache, see MCPServersRegistry.test.ts + */ +describe('MCPServersRegistry Redis Integration Tests', () => { + let registry: typeof import('../MCPServersRegistry').mcpServersRegistry; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let leaderInstance: InstanceType; + + const testParsedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'MCPServersRegistry-IntegrationTest'; + + // Import modules after setting env vars + const registryModule = await import('../MCPServersRegistry'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + + registry = registryModule.mcpServersRegistry; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Become leader so we can perform write operations + leaderInstance = new LeaderElection(); + const isLeader = await leaderInstance.isLeader(); + expect(isLeader).toBe(true); + }); + + afterEach(async () => { + // Clean up: reset registry to clear all test data + await registry.reset(); + + // Also clean up any remaining test keys from Redis + if (keyvRedisClient) { + const pattern = '*MCPServersRegistry-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + }); + + afterAll(async () => { + // Resign as leader + if (leaderInstance) await leaderInstance.resign(); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('private user servers', () => { + it('should add and remove private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Verify server was added + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(testParsedConfig); + + // Remove private user server + await registry.removePrivateUserServer(userId, serverName); + + // Verify server was removed + const configAfterRemoval = await registry.getServerConfig(serverName, userId); + expect(configAfterRemoval).toBeUndefined(); + }); + + it('should throw error when adding duplicate private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + await expect( + registry.addPrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should update an existing private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + const updatedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'python', + args: ['updated.py'], + requiresOAuth: true, + }; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Update the server config + await registry.updatePrivateUserServer(userId, serverName, updatedConfig); + + // Verify server was updated + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(updatedConfig); + }); + + it('should throw error when updating non-existent server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add a user cache first + await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig); + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should throw error when updating server for non-existent user', async () => { + const userId = 'nonexistent_user'; + const serverName = 'private_server'; + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow('No private servers found for user "nonexistent_user".'); + }); + }); + + describe('getAllServerConfigs', () => { + it('should return correct servers based on userId', async () => { + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig); + await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig); + + // Without userId: should return only shared app + shared user servers + const configsNoUser = await registry.getAllServerConfigs(); + expect(Object.keys(configsNoUser)).toHaveLength(2); + expect(configsNoUser).toHaveProperty('app_server'); + expect(configsNoUser).toHaveProperty('user_server'); + + // With userId 'abc': should return shared app + shared user + abc's private servers + const configsAbc = await registry.getAllServerConfigs('abc'); + expect(Object.keys(configsAbc)).toHaveLength(3); + expect(configsAbc).toHaveProperty('app_server'); + expect(configsAbc).toHaveProperty('user_server'); + expect(configsAbc).toHaveProperty('abc_private_server'); + + // With userId 'xyz': should return shared app + shared user + xyz's private servers + const configsXyz = await registry.getAllServerConfigs('xyz'); + expect(Object.keys(configsXyz)).toHaveLength(3); + expect(configsXyz).toHaveProperty('app_server'); + expect(configsXyz).toHaveProperty('user_server'); + expect(configsXyz).toHaveProperty('xyz_private_server'); + }); + }); + + describe('reset', () => { + it('should clear all servers from all caches (shared app, shared user, and private user)', async () => { + const userId = 'user123'; + + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig); + + // Verify all servers are accessible before reset + const appConfigBefore = await registry.getServerConfig('app_server'); + const userConfigBefore = await registry.getServerConfig('user_server'); + const privateConfigBefore = await registry.getServerConfig('private_server', userId); + const allConfigsBefore = await registry.getAllServerConfigs(userId); + + expect(appConfigBefore).toEqual(testParsedConfig); + expect(userConfigBefore).toEqual(testParsedConfig); + expect(privateConfigBefore).toEqual(testParsedConfig); + expect(Object.keys(allConfigsBefore)).toHaveLength(3); + + // Reset everything + await registry.reset(); + + // Verify all servers are cleared after reset + const appConfigAfter = await registry.getServerConfig('app_server'); + const userConfigAfter = await registry.getServerConfig('user_server'); + const privateConfigAfter = await registry.getServerConfig('private_server', userId); + const allConfigsAfter = await registry.getAllServerConfigs(userId); + + expect(appConfigAfter).toBeUndefined(); + expect(userConfigAfter).toBeUndefined(); + expect(privateConfigAfter).toBeUndefined(); + expect(Object.keys(allConfigsAfter)).toHaveLength(0); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts new file mode 100644 index 0000000000..db4b40a46b --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/MCPServersRegistry.test.ts @@ -0,0 +1,175 @@ +import * as t from '~/mcp/types'; +import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry'; + +/** + * Unit tests for MCPServersRegistry using in-memory cache. + * For integration tests using Redis-backed cache, see MCPServersRegistry.cache_integration.spec.ts + */ +describe('MCPServersRegistry', () => { + const testParsedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'node', + args: ['tools.js'], + requiresOAuth: false, + serverInstructions: 'Instructions for file_tools_server', + tools: 'file_read, file_write', + capabilities: '{"tools":{"listChanged":true}}', + toolFunctions: { + file_read_mcp_file_tools_server: { + type: 'function', + function: { + name: 'file_read_mcp_file_tools_server', + description: 'Read a file', + parameters: { type: 'object' }, + }, + }, + }, + }; + + beforeEach(async () => { + await registry.reset(); + }); + + describe('private user servers', () => { + it('should add and remove private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Verify server was added + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(testParsedConfig); + + // Remove private user server + await registry.removePrivateUserServer(userId, serverName); + + // Verify server was removed + const configAfterRemoval = await registry.getServerConfig(serverName, userId); + expect(configAfterRemoval).toBeUndefined(); + }); + + it('should throw error when adding duplicate private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + await expect( + registry.addPrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should update an existing private user server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + const updatedConfig: t.ParsedServerConfig = { + type: 'stdio', + command: 'python', + args: ['updated.py'], + requiresOAuth: true, + }; + + // Add private user server + await registry.addPrivateUserServer(userId, serverName, testParsedConfig); + + // Update the server config + await registry.updatePrivateUserServer(userId, serverName, updatedConfig); + + // Verify server was updated + const retrievedConfig = await registry.getServerConfig(serverName, userId); + expect(retrievedConfig).toEqual(updatedConfig); + }); + + it('should throw error when updating non-existent server', async () => { + const userId = 'user123'; + const serverName = 'private_server'; + + // Add a user cache first + await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig); + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow( + 'Server "private_server" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should throw error when updating server for non-existent user', async () => { + const userId = 'nonexistent_user'; + const serverName = 'private_server'; + + await expect( + registry.updatePrivateUserServer(userId, serverName, testParsedConfig), + ).rejects.toThrow('No private servers found for user "nonexistent_user".'); + }); + }); + + describe('getAllServerConfigs', () => { + it('should return correct servers based on userId', async () => { + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig); + await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig); + + // Without userId: should return only shared app + shared user servers + const configsNoUser = await registry.getAllServerConfigs(); + expect(Object.keys(configsNoUser)).toHaveLength(2); + expect(configsNoUser).toHaveProperty('app_server'); + expect(configsNoUser).toHaveProperty('user_server'); + + // With userId 'abc': should return shared app + shared user + abc's private servers + const configsAbc = await registry.getAllServerConfigs('abc'); + expect(Object.keys(configsAbc)).toHaveLength(3); + expect(configsAbc).toHaveProperty('app_server'); + expect(configsAbc).toHaveProperty('user_server'); + expect(configsAbc).toHaveProperty('abc_private_server'); + + // With userId 'xyz': should return shared app + shared user + xyz's private servers + const configsXyz = await registry.getAllServerConfigs('xyz'); + expect(Object.keys(configsXyz)).toHaveLength(3); + expect(configsXyz).toHaveProperty('app_server'); + expect(configsXyz).toHaveProperty('user_server'); + expect(configsXyz).toHaveProperty('xyz_private_server'); + }); + }); + + describe('reset', () => { + it('should clear all servers from all caches (shared app, shared user, and private user)', async () => { + const userId = 'user123'; + + // Add servers to all three caches + await registry.sharedAppServers.add('app_server', testParsedConfig); + await registry.sharedUserServers.add('user_server', testParsedConfig); + await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig); + + // Verify all servers are accessible before reset + const appConfigBefore = await registry.getServerConfig('app_server'); + const userConfigBefore = await registry.getServerConfig('user_server'); + const privateConfigBefore = await registry.getServerConfig('private_server', userId); + const allConfigsBefore = await registry.getAllServerConfigs(userId); + + expect(appConfigBefore).toEqual(testParsedConfig); + expect(userConfigBefore).toEqual(testParsedConfig); + expect(privateConfigBefore).toEqual(testParsedConfig); + expect(Object.keys(allConfigsBefore)).toHaveLength(3); + + // Reset everything + await registry.reset(); + + // Verify all servers are cleared after reset + const appConfigAfter = await registry.getServerConfig('app_server'); + const userConfigAfter = await registry.getServerConfig('user_server'); + const privateConfigAfter = await registry.getServerConfig('private_server', userId); + const allConfigsAfter = await registry.getAllServerConfigs(userId); + + expect(appConfigAfter).toBeUndefined(); + expect(userConfigAfter).toBeUndefined(); + expect(privateConfigAfter).toBeUndefined(); + expect(Object.keys(allConfigsAfter)).toHaveLength(0); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts b/packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts new file mode 100644 index 0000000000..74bc83425d --- /dev/null +++ b/packages/api/src/mcp/registry/__tests__/mcpConnectionsMock.helper.ts @@ -0,0 +1,55 @@ +import type { MCPConnection } from '~/mcp/connection'; + +/** + * Creates a single mock MCP connection for testing. + * The connection has a client with mocked methods that return server-specific data. + * @param serverName - Name of the server to create mock connection for + * @returns Mocked MCPConnection instance + */ +export function createMockConnection(serverName: string): jest.Mocked { + const mockClient = { + getInstructions: jest.fn().mockReturnValue(`instructions for ${serverName}`), + getServerCapabilities: jest.fn().mockReturnValue({ + tools: { listChanged: true }, + resources: { listChanged: true }, + prompts: { get: `getPrompts for ${serverName}` }, + }), + listTools: jest.fn().mockResolvedValue({ + tools: [ + { + name: 'listFiles', + description: `Description for ${serverName}'s listFiles tool`, + inputSchema: { + type: 'object', + properties: { + input: { type: 'string' }, + }, + }, + }, + ], + }), + }; + + return { + client: mockClient, + disconnect: jest.fn().mockResolvedValue(undefined), + } as unknown as jest.Mocked; +} + +/** + * Creates mock MCP connections for testing. + * Each connection has a client with mocked methods that return server-specific data. + * @param serverNames - Array of server names to create mock connections for + * @returns Map of server names to mocked MCPConnection instances + */ +export function createMockConnectionsMap( + serverNames: string[], +): Map> { + const mockConnections = new Map>(); + + serverNames.forEach((serverName) => { + mockConnections.set(serverName, createMockConnection(serverName)); + }); + + return mockConnections; +} diff --git a/packages/api/src/mcp/registry/cache/BaseRegistryCache.ts b/packages/api/src/mcp/registry/cache/BaseRegistryCache.ts new file mode 100644 index 0000000000..1d2266fc6d --- /dev/null +++ b/packages/api/src/mcp/registry/cache/BaseRegistryCache.ts @@ -0,0 +1,26 @@ +import type Keyv from 'keyv'; +import { isLeader } from '~/cluster'; + +/** + * Base class for MCP registry caches that require distributed leader coordination. + * Provides helper methods for leader-only operations and success validation. + * All concrete implementations must provide their own Keyv cache instance. + */ +export abstract class BaseRegistryCache { + protected readonly PREFIX = 'MCP::ServersRegistry'; + protected abstract readonly cache: Keyv; + + protected async leaderCheck(action: string): Promise { + if (!(await isLeader())) throw new Error(`Only leader can ${action}.`); + } + + protected successCheck(action: string, success: boolean): true { + if (!success) throw new Error(`Failed to ${action} in cache.`); + return true; + } + + public async reset(): Promise { + await this.leaderCheck(`reset ${this.cache.namespace} cache`); + await this.cache.clear(); + } +} diff --git a/packages/api/src/mcp/registry/cache/RegistryStatusCache.ts b/packages/api/src/mcp/registry/cache/RegistryStatusCache.ts new file mode 100644 index 0000000000..2a8fc72213 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/RegistryStatusCache.ts @@ -0,0 +1,37 @@ +import { standardCache } from '~/cache'; +import { BaseRegistryCache } from './BaseRegistryCache'; + +// Status keys +const INITIALIZED = 'INITIALIZED'; + +/** + * Cache for tracking MCP Servers Registry metadata and status across distributed instances. + * Uses Redis-backed storage to coordinate state between leader and follower nodes. + * Currently, tracks initialization status to ensure only the leader performs initialization + * while followers wait for completion. Designed to be extended with additional registry + * metadata as needed (e.g., last update timestamps, version info, health status). + * This cache is only meant to be used internally by registry management components. + */ +class RegistryStatusCache extends BaseRegistryCache { + protected readonly cache = standardCache(`${this.PREFIX}::Status`); + + public async isInitialized(): Promise { + return (await this.get(INITIALIZED)) === true; + } + + public async setInitialized(value: boolean): Promise { + await this.set(INITIALIZED, value); + } + + private async get(key: string): Promise { + return this.cache.get(key); + } + + private async set(key: string, value: string | number | boolean, ttl?: number): Promise { + await this.leaderCheck('set MCP Servers Registry status'); + const success = await this.cache.set(key, value, ttl); + this.successCheck(`set status key "${key}"`, success); + } +} + +export const registryStatusCache = new RegistryStatusCache(); diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts new file mode 100644 index 0000000000..72c664d844 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts @@ -0,0 +1,31 @@ +import { cacheConfig } from '~/cache'; +import { ServerConfigsCacheInMemory } from './ServerConfigsCacheInMemory'; +import { ServerConfigsCacheRedis } from './ServerConfigsCacheRedis'; + +export type ServerConfigsCache = ServerConfigsCacheInMemory | ServerConfigsCacheRedis; + +/** + * Factory for creating the appropriate ServerConfigsCache implementation based on deployment mode. + * Automatically selects between in-memory and Redis-backed storage depending on USE_REDIS config. + * In single-instance mode (USE_REDIS=false), returns lightweight in-memory cache. + * In cluster mode (USE_REDIS=true), returns Redis-backed cache with distributed coordination. + * Provides a unified interface regardless of the underlying storage mechanism. + */ +export class ServerConfigsCacheFactory { + /** + * Create a ServerConfigsCache instance. + * Returns Redis implementation if Redis is configured, otherwise in-memory implementation. + * + * @param owner - The owner of the cache (e.g., 'user', 'global') - only used for Redis namespacing + * @param leaderOnly - Whether operations should only be performed by the leader (only applies to Redis) + * @returns ServerConfigsCache instance + */ + static create(owner: string, leaderOnly: boolean): ServerConfigsCache { + if (cacheConfig.USE_REDIS) { + return new ServerConfigsCacheRedis(owner, leaderOnly); + } + + // In-memory mode uses a simple Map - doesn't need owner/namespace + return new ServerConfigsCacheInMemory(); + } +} diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts new file mode 100644 index 0000000000..1dd2385053 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts @@ -0,0 +1,46 @@ +import { ParsedServerConfig } from '~/mcp/types'; + +/** + * In-memory implementation of MCP server configurations cache for single-instance deployments. + * Uses a native JavaScript Map for fast, local storage without Redis dependencies. + * Suitable for development environments or single-server production deployments. + * Does not require leader checks or distributed coordination since data is instance-local. + * Data is lost on server restart and not shared across multiple server instances. + */ +export class ServerConfigsCacheInMemory { + private readonly cache: Map = new Map(); + + public async add(serverName: string, config: ParsedServerConfig): Promise { + if (this.cache.has(serverName)) + throw new Error( + `Server "${serverName}" already exists in cache. Use update() to modify existing configs.`, + ); + this.cache.set(serverName, config); + } + + public async update(serverName: string, config: ParsedServerConfig): Promise { + if (!this.cache.has(serverName)) + throw new Error( + `Server "${serverName}" does not exist in cache. Use add() to create new configs.`, + ); + this.cache.set(serverName, config); + } + + public async remove(serverName: string): Promise { + if (!this.cache.delete(serverName)) { + throw new Error(`Failed to remove server "${serverName}" in cache.`); + } + } + + public async get(serverName: string): Promise { + return this.cache.get(serverName); + } + + public async getAll(): Promise> { + return Object.fromEntries(this.cache); + } + + public async reset(): Promise { + this.cache.clear(); + } +} diff --git a/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts new file mode 100644 index 0000000000..a2e025736c --- /dev/null +++ b/packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts @@ -0,0 +1,80 @@ +import type Keyv from 'keyv'; +import { fromPairs } from 'lodash'; +import { standardCache, keyvRedisClient } from '~/cache'; +import { ParsedServerConfig } from '~/mcp/types'; +import { BaseRegistryCache } from './BaseRegistryCache'; + +/** + * Redis-backed implementation of MCP server configurations cache for distributed deployments. + * Stores server configs in Redis with namespace isolation by owner (App, User, or specific user ID). + * Enables data sharing across multiple server instances in a cluster environment. + * Supports optional leader-only write operations to prevent race conditions during initialization. + * Data persists across server restarts and is accessible from any instance in the cluster. + */ +export class ServerConfigsCacheRedis extends BaseRegistryCache { + protected readonly cache: Keyv; + private readonly owner: string; + private readonly leaderOnly: boolean; + + constructor(owner: string, leaderOnly: boolean) { + super(); + this.owner = owner; + this.leaderOnly = leaderOnly; + this.cache = standardCache(`${this.PREFIX}::Servers::${owner}`); + } + + public async add(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck(`add ${this.owner} MCP servers`); + const exists = await this.cache.has(serverName); + if (exists) + throw new Error( + `Server "${serverName}" already exists in cache. Use update() to modify existing configs.`, + ); + const success = await this.cache.set(serverName, config); + this.successCheck(`add ${this.owner} server "${serverName}"`, success); + } + + public async update(serverName: string, config: ParsedServerConfig): Promise { + if (this.leaderOnly) await this.leaderCheck(`update ${this.owner} MCP servers`); + const exists = await this.cache.has(serverName); + if (!exists) + throw new Error( + `Server "${serverName}" does not exist in cache. Use add() to create new configs.`, + ); + const success = await this.cache.set(serverName, config); + this.successCheck(`update ${this.owner} server "${serverName}"`, success); + } + + public async remove(serverName: string): Promise { + if (this.leaderOnly) await this.leaderCheck(`remove ${this.owner} MCP servers`); + const success = await this.cache.delete(serverName); + this.successCheck(`remove ${this.owner} server "${serverName}"`, success); + } + + public async get(serverName: string): Promise { + return this.cache.get(serverName); + } + + public async getAll(): Promise> { + // Use Redis SCAN iterator directly (non-blocking, production-ready) + // Note: Keyv uses a single colon ':' between namespace and key, even if GLOBAL_PREFIX_SEPARATOR is '::' + const pattern = `*${this.cache.namespace}:*`; + const entries: Array<[string, ParsedServerConfig]> = []; + + // Use scanIterator from Redis client + if (keyvRedisClient && 'scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + // Extract the actual key name (last part after final colon) + // Full key format: "prefix::namespace:keyName" + const lastColonIndex = key.lastIndexOf(':'); + const keyName = key.substring(lastColonIndex + 1); + const value = await this.cache.get(keyName); + if (value) { + entries.push([keyName, value as ParsedServerConfig]); + } + } + } + + return fromPairs(entries); + } +} diff --git a/packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts new file mode 100644 index 0000000000..643e7c27df --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts @@ -0,0 +1,73 @@ +import { expect } from '@playwright/test'; + +describe('RegistryStatusCache Integration Tests', () => { + let registryStatusCache: typeof import('../RegistryStatusCache').registryStatusCache; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let leaderInstance: InstanceType; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'RegistryStatusCache-IntegrationTest'; + + // Import modules after setting env vars + const statusCacheModule = await import('../RegistryStatusCache'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + + registryStatusCache = statusCacheModule.registryStatusCache; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Become leader so we can perform write operations + leaderInstance = new LeaderElection(); + const isLeader = await leaderInstance.isLeader(); + expect(isLeader).toBe(true); + }); + + afterEach(async () => { + // Clean up: clear all test keys from Redis + if (keyvRedisClient) { + const pattern = '*RegistryStatusCache-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + }); + + afterAll(async () => { + // Resign as leader + if (leaderInstance) await leaderInstance.resign(); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('Initialization status tracking', () => { + it('should return false for isInitialized when not set', async () => { + const initialized = await registryStatusCache.isInitialized(); + expect(initialized).toBe(false); + }); + + it('should set and get initialized status', async () => { + await registryStatusCache.setInitialized(true); + const initialized = await registryStatusCache.isInitialized(); + expect(initialized).toBe(true); + + await registryStatusCache.setInitialized(false); + const uninitialized = await registryStatusCache.isInitialized(); + expect(uninitialized).toBe(false); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts new file mode 100644 index 0000000000..d1e0a0d486 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts @@ -0,0 +1,70 @@ +import { ServerConfigsCacheFactory } from '../ServerConfigsCacheFactory'; +import { ServerConfigsCacheInMemory } from '../ServerConfigsCacheInMemory'; +import { ServerConfigsCacheRedis } from '../ServerConfigsCacheRedis'; +import { cacheConfig } from '~/cache'; + +// Mock the cache implementations +jest.mock('../ServerConfigsCacheInMemory'); +jest.mock('../ServerConfigsCacheRedis'); + +// Mock the cache config module +jest.mock('~/cache', () => ({ + cacheConfig: { + USE_REDIS: false, + }, +})); + +describe('ServerConfigsCacheFactory', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('create()', () => { + it('should return ServerConfigsCacheRedis when USE_REDIS is true', () => { + // Arrange + cacheConfig.USE_REDIS = true; + + // Act + const cache = ServerConfigsCacheFactory.create('TestOwner', true); + + // Assert + expect(cache).toBeInstanceOf(ServerConfigsCacheRedis); + expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('TestOwner', true); + }); + + it('should return ServerConfigsCacheInMemory when USE_REDIS is false', () => { + // Arrange + cacheConfig.USE_REDIS = false; + + // Act + const cache = ServerConfigsCacheFactory.create('TestOwner', false); + + // Assert + expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory); + expect(ServerConfigsCacheInMemory).toHaveBeenCalled(); + }); + + it('should pass correct parameters to ServerConfigsCacheRedis', () => { + // Arrange + cacheConfig.USE_REDIS = true; + + // Act + ServerConfigsCacheFactory.create('App', true); + + // Assert + expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('App', true); + }); + + it('should create ServerConfigsCacheInMemory without parameters when USE_REDIS is false', () => { + // Arrange + cacheConfig.USE_REDIS = false; + + // Act + ServerConfigsCacheFactory.create('User', false); + + // Assert + // In-memory cache doesn't use owner/leaderOnly parameters + expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith(); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts new file mode 100644 index 0000000000..e2033d0ba8 --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts @@ -0,0 +1,173 @@ +import { expect } from '@playwright/test'; +import { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheInMemory Integration Tests', () => { + let ServerConfigsCacheInMemory: typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory; + let cache: InstanceType< + typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory + >; + + // Test data + const mockConfig1: ParsedServerConfig = { + command: 'node', + args: ['server1.js'], + env: { TEST: 'value1' }, + }; + + const mockConfig2: ParsedServerConfig = { + command: 'python', + args: ['server2.py'], + env: { TEST: 'value2' }, + }; + + const mockConfig3: ParsedServerConfig = { + command: 'node', + args: ['server3.js'], + url: 'http://localhost:3000', + requiresOAuth: true, + }; + + beforeAll(async () => { + // Import modules + const cacheModule = await import('../ServerConfigsCacheInMemory'); + ServerConfigsCacheInMemory = cacheModule.ServerConfigsCacheInMemory; + }); + + beforeEach(() => { + // Create a fresh instance for each test + cache = new ServerConfigsCacheInMemory(); + }); + + describe('add and get operations', () => { + it('should add and retrieve a server config', async () => { + await cache.add('server1', mockConfig1); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig1); + }); + + it('should return undefined for non-existent server', async () => { + const result = await cache.get('non-existent'); + expect(result).toBeUndefined(); + }); + + it('should throw error when adding duplicate server', async () => { + await cache.add('server1', mockConfig1); + await expect(cache.add('server1', mockConfig2)).rejects.toThrow( + 'Server "server1" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should handle multiple server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result1 = await cache.get('server1'); + const result2 = await cache.get('server2'); + const result3 = await cache.get('server3'); + + expect(result1).toEqual(mockConfig1); + expect(result2).toEqual(mockConfig2); + expect(result3).toEqual(mockConfig3); + }); + }); + + describe('getAll operation', () => { + it('should return empty object when no servers exist', async () => { + const result = await cache.getAll(); + expect(result).toEqual({}); + }); + + it('should return all server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result = await cache.getAll(); + expect(result).toEqual({ + server1: mockConfig1, + server2: mockConfig2, + server3: mockConfig3, + }); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.add('server3', mockConfig3); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(3); + expect(result.server3).toEqual(mockConfig3); + }); + }); + + describe('update operation', () => { + it('should update an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.update('server1', mockConfig2); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig2); + }); + + it('should throw error when updating non-existent server', async () => { + await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow( + 'Server "non-existent" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + await cache.update('server1', mockConfig3); + const result = await cache.getAll(); + expect(result.server1).toEqual(mockConfig3); + expect(result.server2).toEqual(mockConfig2); + }); + }); + + describe('remove operation', () => { + it('should remove an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.remove('server1'); + expect(await cache.get('server1')).toBeUndefined(); + }); + + it('should throw error when removing non-existent server', async () => { + await expect(cache.remove('non-existent')).rejects.toThrow( + 'Failed to remove server "non-existent" in cache.', + ); + }); + + it('should remove server from getAll results', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.remove('server1'); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(1); + expect(result.server1).toBeUndefined(); + expect(result.server2).toEqual(mockConfig2); + }); + + it('should allow re-adding a removed server', async () => { + await cache.add('server1', mockConfig1); + await cache.remove('server1'); + await cache.add('server1', mockConfig3); + + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig3); + }); + }); +}); diff --git a/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts new file mode 100644 index 0000000000..7e139dc5be --- /dev/null +++ b/packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts @@ -0,0 +1,278 @@ +import { expect } from '@playwright/test'; +import { ParsedServerConfig } from '~/mcp/types'; + +describe('ServerConfigsCacheRedis Integration Tests', () => { + let ServerConfigsCacheRedis: typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis; + let keyvRedisClient: Awaited['keyvRedisClient']; + let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection; + let checkIsLeader: () => Promise; + let cache: InstanceType; + + // Test data + const mockConfig1: ParsedServerConfig = { + command: 'node', + args: ['server1.js'], + env: { TEST: 'value1' }, + }; + + const mockConfig2: ParsedServerConfig = { + command: 'python', + args: ['server2.py'], + env: { TEST: 'value2' }, + }; + + const mockConfig3: ParsedServerConfig = { + command: 'node', + args: ['server3.js'], + url: 'http://localhost:3000', + requiresOAuth: true, + }; + + beforeAll(async () => { + // Set up environment variables for Redis (only if not already set) + process.env.USE_REDIS = process.env.USE_REDIS ?? 'true'; + process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379'; + process.env.REDIS_KEY_PREFIX = + process.env.REDIS_KEY_PREFIX ?? 'ServerConfigsCacheRedis-IntegrationTest'; + + // Import modules after setting env vars + const cacheModule = await import('../ServerConfigsCacheRedis'); + const redisClients = await import('~/cache/redisClients'); + const leaderElectionModule = await import('~/cluster/LeaderElection'); + const clusterModule = await import('~/cluster'); + + ServerConfigsCacheRedis = cacheModule.ServerConfigsCacheRedis; + keyvRedisClient = redisClients.keyvRedisClient; + LeaderElection = leaderElectionModule.LeaderElection; + checkIsLeader = clusterModule.isLeader; + + // Ensure Redis is connected + if (!keyvRedisClient) throw new Error('Redis client is not initialized'); + + // Wait for Redis to be ready + if (!keyvRedisClient.isOpen) await keyvRedisClient.connect(); + + // Clear any existing leader key to ensure clean state + await keyvRedisClient.del(LeaderElection.LEADER_KEY); + + // Become leader so we can perform write operations (using default election instance) + const isLeader = await checkIsLeader(); + expect(isLeader).toBe(true); + }); + + beforeEach(() => { + // Create a fresh instance for each test with leaderOnly=true + cache = new ServerConfigsCacheRedis('test-user', true); + }); + + afterEach(async () => { + // Clean up: clear all test keys from Redis + if (keyvRedisClient) { + const pattern = '*ServerConfigsCacheRedis-IntegrationTest*'; + if ('scanIterator' in keyvRedisClient) { + for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) { + await keyvRedisClient.del(key); + } + } + } + }); + + afterAll(async () => { + // Clear leader key to allow other tests to become leader + if (keyvRedisClient) await keyvRedisClient.del(LeaderElection.LEADER_KEY); + + // Close Redis connection + if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect(); + }); + + describe('add and get operations', () => { + it('should add and retrieve a server config', async () => { + await cache.add('server1', mockConfig1); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig1); + }); + + it('should return undefined for non-existent server', async () => { + const result = await cache.get('non-existent'); + expect(result).toBeUndefined(); + }); + + it('should throw error when adding duplicate server', async () => { + await cache.add('server1', mockConfig1); + await expect(cache.add('server1', mockConfig2)).rejects.toThrow( + 'Server "server1" already exists in cache. Use update() to modify existing configs.', + ); + }); + + it('should handle multiple server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result1 = await cache.get('server1'); + const result2 = await cache.get('server2'); + const result3 = await cache.get('server3'); + + expect(result1).toEqual(mockConfig1); + expect(result2).toEqual(mockConfig2); + expect(result3).toEqual(mockConfig3); + }); + + it('should isolate caches by owner namespace', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await globalCache.add('server1', mockConfig2); + + const userResult = await userCache.get('server1'); + const globalResult = await globalCache.get('server1'); + + expect(userResult).toEqual(mockConfig1); + expect(globalResult).toEqual(mockConfig2); + }); + }); + + describe('getAll operation', () => { + it('should return empty object when no servers exist', async () => { + const result = await cache.getAll(); + expect(result).toEqual({}); + }); + + it('should return all server configs', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + await cache.add('server3', mockConfig3); + + const result = await cache.getAll(); + expect(result).toEqual({ + server1: mockConfig1, + server2: mockConfig2, + server3: mockConfig3, + }); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.add('server3', mockConfig3); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(3); + expect(result.server3).toEqual(mockConfig3); + }); + + it('should only return configs for the specific owner', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await userCache.add('server2', mockConfig2); + await globalCache.add('server3', mockConfig3); + + const userResult = await userCache.getAll(); + const globalResult = await globalCache.getAll(); + + expect(Object.keys(userResult).length).toBe(2); + expect(Object.keys(globalResult).length).toBe(1); + expect(userResult.server1).toEqual(mockConfig1); + expect(userResult.server3).toBeUndefined(); + expect(globalResult.server3).toEqual(mockConfig3); + }); + }); + + describe('update operation', () => { + it('should update an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.update('server1', mockConfig2); + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig2); + }); + + it('should throw error when updating non-existent server', async () => { + await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow( + 'Server "non-existent" does not exist in cache. Use add() to create new configs.', + ); + }); + + it('should reflect updates in getAll', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + await cache.update('server1', mockConfig3); + const result = await cache.getAll(); + expect(result.server1).toEqual(mockConfig3); + expect(result.server2).toEqual(mockConfig2); + }); + + it('should only update in the specific owner namespace', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await globalCache.add('server1', mockConfig2); + + await userCache.update('server1', mockConfig3); + + expect(await userCache.get('server1')).toEqual(mockConfig3); + expect(await globalCache.get('server1')).toEqual(mockConfig2); + }); + }); + + describe('remove operation', () => { + it('should remove an existing server config', async () => { + await cache.add('server1', mockConfig1); + expect(await cache.get('server1')).toEqual(mockConfig1); + + await cache.remove('server1'); + expect(await cache.get('server1')).toBeUndefined(); + }); + + it('should throw error when removing non-existent server', async () => { + await expect(cache.remove('non-existent')).rejects.toThrow( + 'Failed to remove test-user server "non-existent"', + ); + }); + + it('should remove server from getAll results', async () => { + await cache.add('server1', mockConfig1); + await cache.add('server2', mockConfig2); + + let result = await cache.getAll(); + expect(Object.keys(result).length).toBe(2); + + await cache.remove('server1'); + result = await cache.getAll(); + expect(Object.keys(result).length).toBe(1); + expect(result.server1).toBeUndefined(); + expect(result.server2).toEqual(mockConfig2); + }); + + it('should allow re-adding a removed server', async () => { + await cache.add('server1', mockConfig1); + await cache.remove('server1'); + await cache.add('server1', mockConfig3); + + const result = await cache.get('server1'); + expect(result).toEqual(mockConfig3); + }); + + it('should only remove from the specific owner namespace', async () => { + const userCache = new ServerConfigsCacheRedis('user1', true); + const globalCache = new ServerConfigsCacheRedis('global', true); + + await userCache.add('server1', mockConfig1); + await globalCache.add('server1', mockConfig2); + + await userCache.remove('server1'); + + expect(await userCache.get('server1')).toBeUndefined(); + expect(await globalCache.get('server1')).toEqual(mockConfig2); + }); + }); +}); diff --git a/packages/api/src/mcp/types/index.ts b/packages/api/src/mcp/types/index.ts index 5cf003b9f5..6e445e26ad 100644 --- a/packages/api/src/mcp/types/index.ts +++ b/packages/api/src/mcp/types/index.ts @@ -151,6 +151,8 @@ export type ParsedServerConfig = MCPOptions & { oauthMetadata?: Record | null; capabilities?: string; tools?: string; + toolFunctions?: LCAvailableTools; + initDuration?: number; }; export interface BasicConnectionOptions { diff --git a/packages/api/src/utils/index.ts b/packages/api/src/utils/index.ts index 9fd3b01885..85c99d108f 100644 --- a/packages/api/src/utils/index.ts +++ b/packages/api/src/utils/index.ts @@ -10,6 +10,7 @@ export * from './key'; export * from './llm'; export * from './math'; export * from './openid'; +export * from './promise'; export * from './sanitizeTitle'; export * from './tempChatRetention'; export * from './text'; diff --git a/packages/api/src/utils/promise.spec.ts b/packages/api/src/utils/promise.spec.ts new file mode 100644 index 0000000000..c43c8bf739 --- /dev/null +++ b/packages/api/src/utils/promise.spec.ts @@ -0,0 +1,115 @@ +import { withTimeout } from './promise'; + +describe('withTimeout', () => { + beforeEach(() => { + jest.clearAllTimers(); + }); + + it('should resolve when promise completes before timeout', async () => { + const promise = Promise.resolve('success'); + const result = await withTimeout(promise, 1000); + expect(result).toBe('success'); + }); + + it('should reject when promise rejects before timeout', async () => { + const promise = Promise.reject(new Error('test error')); + await expect(withTimeout(promise, 1000)).rejects.toThrow('test error'); + }); + + it('should timeout when promise takes too long', async () => { + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + await expect(withTimeout(promise, 100, 'Custom timeout message')).rejects.toThrow( + 'Custom timeout message', + ); + }); + + it('should use default error message when none provided', async () => { + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + await expect(withTimeout(promise, 100)).rejects.toThrow('Operation timed out after 100ms'); + }); + + it('should clear timeout when promise resolves', async () => { + const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout'); + const promise = Promise.resolve('fast'); + + await withTimeout(promise, 1000); + + expect(clearTimeoutSpy).toHaveBeenCalled(); + clearTimeoutSpy.mockRestore(); + }); + + it('should clear timeout when promise rejects', async () => { + const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout'); + const promise = Promise.reject(new Error('fail')); + + await expect(withTimeout(promise, 1000)).rejects.toThrow('fail'); + + expect(clearTimeoutSpy).toHaveBeenCalled(); + clearTimeoutSpy.mockRestore(); + }); + + it('should handle multiple concurrent timeouts', async () => { + const promise1 = Promise.resolve('first'); + const promise2 = new Promise((resolve) => setTimeout(() => resolve('second'), 50)); + const promise3 = new Promise((resolve) => setTimeout(() => resolve('third'), 2000)); + + const [result1, result2] = await Promise.all([ + withTimeout(promise1, 1000), + withTimeout(promise2, 1000), + ]); + + expect(result1).toBe('first'); + expect(result2).toBe('second'); + + await expect(withTimeout(promise3, 100)).rejects.toThrow('Operation timed out after 100ms'); + }); + + it('should work with async functions', async () => { + const asyncFunction = async () => { + await new Promise((resolve) => setTimeout(resolve, 10)); + return 'async result'; + }; + + const result = await withTimeout(asyncFunction(), 1000); + expect(result).toBe('async result'); + }); + + it('should work with any return type', async () => { + const numberPromise = Promise.resolve(42); + const objectPromise = Promise.resolve({ key: 'value' }); + const arrayPromise = Promise.resolve([1, 2, 3]); + + expect(await withTimeout(numberPromise, 1000)).toBe(42); + expect(await withTimeout(objectPromise, 1000)).toEqual({ key: 'value' }); + expect(await withTimeout(arrayPromise, 1000)).toEqual([1, 2, 3]); + }); + + it('should call logger when timeout occurs', async () => { + const loggerMock = jest.fn(); + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + const errorMessage = 'Custom timeout with logger'; + + await expect(withTimeout(promise, 100, errorMessage, loggerMock)).rejects.toThrow(errorMessage); + + expect(loggerMock).toHaveBeenCalledTimes(1); + expect(loggerMock).toHaveBeenCalledWith(errorMessage, expect.any(Error)); + }); + + it('should not call logger when promise resolves', async () => { + const loggerMock = jest.fn(); + const promise = Promise.resolve('success'); + + const result = await withTimeout(promise, 1000, 'Should not timeout', loggerMock); + + expect(result).toBe('success'); + expect(loggerMock).not.toHaveBeenCalled(); + }); + + it('should work without logger parameter', async () => { + const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000)); + + await expect(withTimeout(promise, 100, 'No logger provided')).rejects.toThrow( + 'No logger provided', + ); + }); +}); diff --git a/packages/api/src/utils/promise.ts b/packages/api/src/utils/promise.ts new file mode 100644 index 0000000000..72719a3ff0 --- /dev/null +++ b/packages/api/src/utils/promise.ts @@ -0,0 +1,42 @@ +/** + * Wraps a promise with a timeout. If the promise doesn't resolve/reject within + * the specified time, it will be rejected with a timeout error. + * + * @param promise - The promise to wrap with a timeout + * @param timeoutMs - Timeout duration in milliseconds + * @param errorMessage - Custom error message for timeout (optional) + * @param logger - Optional logger function to log timeout errors (e.g., console.warn, logger.warn) + * @returns Promise that resolves/rejects with the original promise or times out + * + * @example + * ```typescript + * const result = await withTimeout( + * fetchData(), + * 5000, + * 'Failed to fetch data within 5 seconds', + * console.warn + * ); + * ``` + */ +export async function withTimeout( + promise: Promise, + timeoutMs: number, + errorMessage?: string, + logger?: (message: string, error: Error) => void, +): Promise { + let timeoutId: NodeJS.Timeout; + + const timeoutPromise = new Promise((_, reject) => { + timeoutId = setTimeout(() => { + const error = new Error(errorMessage ?? `Operation timed out after ${timeoutMs}ms`); + if (logger) logger(error.message, error); + reject(error); + }, timeoutMs); + }); + + try { + return await Promise.race([promise, timeoutPromise]); + } finally { + clearTimeout(timeoutId!); + } +} From 14e494136737f44755ca9942a253da767c3f28c0 Mon Sep 17 00:00:00 2001 From: Max Sanna Date: Tue, 4 Nov 2025 19:40:24 +0100 Subject: [PATCH 007/207] =?UTF-8?q?=F0=9F=93=8E=20fix:=20Document=20Upload?= =?UTF-8?q?s=20for=20Custom=20Endpoints=20(#10336)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fixed upload to provider for custom endpoints + unit tests * fix: add support back for agents to be able to use Upload to Provider with supported providers * ci: add test for agents endpoint still recognizing document supported providers * chore: address ESLint suggestions * Improved unit tests * Linting error on unit tests fixed --------- Co-authored-by: Dustin Healy --- .../Chat/Input/Files/AttachFileMenu.tsx | 6 +- .../Chat/Input/Files/DragDropModal.tsx | 2 +- .../Files/__tests__/AttachFileMenu.spec.tsx | 602 ++++++++++++++++++ .../Files/__tests__/DragDropModal.spec.tsx | 121 ++++ 4 files changed, 728 insertions(+), 3 deletions(-) create mode 100644 client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx create mode 100644 client/src/components/Chat/Input/Files/__tests__/DragDropModal.spec.tsx diff --git a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx index a3e5a8d304..821678cfc8 100644 --- a/client/src/components/Chat/Input/Files/AttachFileMenu.tsx +++ b/client/src/components/Chat/Input/Files/AttachFileMenu.tsx @@ -117,8 +117,10 @@ const AttachFileMenu = ({ const items: MenuItemProps[] = []; const currentProvider = provider || endpoint; - - if (isDocumentSupportedProvider(currentProvider || endpointType)) { + if ( + isDocumentSupportedProvider(endpointType) || + isDocumentSupportedProvider(currentProvider) + ) { items.push({ label: localize('com_ui_upload_provider'), onClick: () => { diff --git a/client/src/components/Chat/Input/Files/DragDropModal.tsx b/client/src/components/Chat/Input/Files/DragDropModal.tsx index d9003de3dc..015a590d55 100644 --- a/client/src/components/Chat/Input/Files/DragDropModal.tsx +++ b/client/src/components/Chat/Input/Files/DragDropModal.tsx @@ -57,7 +57,7 @@ const DragDropModal = ({ onOptionSelect, setShowModal, files, isVisible }: DragD const currentProvider = provider || endpoint; // Check if provider supports document upload - if (isDocumentSupportedProvider(currentProvider || endpointType)) { + if (isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider)) { const isGoogleProvider = currentProvider === EModelEndpoint.google; const validFileTypes = isGoogleProvider ? files.every( diff --git a/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx new file mode 100644 index 0000000000..36c4ee40e7 --- /dev/null +++ b/client/src/components/Chat/Input/Files/__tests__/AttachFileMenu.spec.tsx @@ -0,0 +1,602 @@ +import React from 'react'; +import { render, screen, fireEvent } from '@testing-library/react'; +import '@testing-library/jest-dom'; +import { RecoilRoot } from 'recoil'; +import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; +import { EModelEndpoint } from 'librechat-data-provider'; +import AttachFileMenu from '../AttachFileMenu'; + +// Mock all the hooks +jest.mock('~/hooks', () => ({ + useAgentToolPermissions: jest.fn(), + useAgentCapabilities: jest.fn(), + useGetAgentsConfig: jest.fn(), + useFileHandling: jest.fn(), + useLocalize: jest.fn(), +})); + +jest.mock('~/hooks/Files/useSharePointFileHandling', () => ({ + __esModule: true, + default: jest.fn(), +})); + +jest.mock('~/data-provider', () => ({ + useGetStartupConfig: jest.fn(), +})); + +jest.mock('~/components/SharePoint', () => ({ + SharePointPickerDialog: jest.fn(() => null), +})); + +jest.mock('@librechat/client', () => { + const React = jest.requireActual('react'); + return { + FileUpload: React.forwardRef(({ children, handleFileChange }: any, ref: any) => ( +
+ + {children} +
+ )), + TooltipAnchor: ({ render }: any) => render, + DropdownPopup: ({ trigger, items, isOpen, setIsOpen }: any) => { + const handleTriggerClick = () => { + if (setIsOpen) { + setIsOpen(!isOpen); + } + }; + + return ( +
+
{trigger}
+ {isOpen && ( +
+ {items.map((item: any, idx: number) => ( + + ))} +
+ )} +
+ ); + }, + AttachmentIcon: () => 📎, + SharePointIcon: () => SP, + }; +}); + +jest.mock('@ariakit/react', () => ({ + MenuButton: ({ children, onClick, disabled, ...props }: any) => ( + + ), +})); + +const mockUseAgentToolPermissions = jest.requireMock('~/hooks').useAgentToolPermissions; +const mockUseAgentCapabilities = jest.requireMock('~/hooks').useAgentCapabilities; +const mockUseGetAgentsConfig = jest.requireMock('~/hooks').useGetAgentsConfig; +const mockUseFileHandling = jest.requireMock('~/hooks').useFileHandling; +const mockUseLocalize = jest.requireMock('~/hooks').useLocalize; +const mockUseSharePointFileHandling = jest.requireMock( + '~/hooks/Files/useSharePointFileHandling', +).default; +const mockUseGetStartupConfig = jest.requireMock('~/data-provider').useGetStartupConfig; + +describe('AttachFileMenu', () => { + const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + }, + }); + + const mockHandleFileChange = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Default mock implementations + mockUseLocalize.mockReturnValue((key: string) => { + const translations: Record = { + com_ui_upload_provider: 'Upload to Provider', + com_ui_upload_image_input: 'Upload Image', + com_ui_upload_ocr_text: 'Upload OCR Text', + com_ui_upload_file_search: 'Upload for File Search', + com_ui_upload_code_files: 'Upload Code Files', + com_sidepanel_attach_files: 'Attach Files', + com_files_upload_sharepoint: 'Upload from SharePoint', + }; + return translations[key] || key; + }); + + mockUseAgentCapabilities.mockReturnValue({ + contextEnabled: false, + fileSearchEnabled: false, + codeEnabled: false, + }); + + mockUseGetAgentsConfig.mockReturnValue({ + agentsConfig: { + capabilities: { + contextEnabled: false, + fileSearchEnabled: false, + codeEnabled: false, + }, + }, + }); + + mockUseFileHandling.mockReturnValue({ + handleFileChange: mockHandleFileChange, + }); + + mockUseSharePointFileHandling.mockReturnValue({ + handleSharePointFiles: jest.fn(), + isProcessing: false, + downloadProgress: 0, + }); + + mockUseGetStartupConfig.mockReturnValue({ + data: { + sharePointFilePickerEnabled: false, + }, + }); + + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: undefined, + }); + }); + + const renderAttachFileMenu = (props: any = {}) => { + return render( + + + + + , + ); + }; + + describe('Basic Rendering', () => { + it('should render the attachment button', () => { + renderAttachFileMenu(); + const button = screen.getByRole('button', { name: /attach file options/i }); + expect(button).toBeInTheDocument(); + }); + + it('should be disabled when disabled prop is true', () => { + renderAttachFileMenu({ disabled: true }); + const button = screen.getByRole('button', { name: /attach file options/i }); + expect(button).toBeDisabled(); + }); + + it('should not be disabled when disabled prop is false', () => { + renderAttachFileMenu({ disabled: false }); + const button = screen.getByRole('button', { name: /attach file options/i }); + expect(button).not.toBeDisabled(); + }); + }); + + describe('Provider Detection Fix - endpointType Priority', () => { + it('should prioritize endpointType over currentProvider for LiteLLM gateway', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: 'litellm', // Custom gateway name NOT in documentSupportedProviders + }); + + renderAttachFileMenu({ + endpoint: 'litellm', + endpointType: EModelEndpoint.openAI, // Backend override IS in documentSupportedProviders + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + // With the fix, should show "Upload to Provider" because endpointType is checked first + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + expect(screen.queryByText('Upload Image')).not.toBeInTheDocument(); + }); + + it('should show Upload to Provider for custom endpoints with OpenAI endpointType', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: 'my-custom-gateway', + }); + + renderAttachFileMenu({ + endpoint: 'my-custom-gateway', + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + }); + + it('should show Upload Image when neither endpointType nor provider support documents', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: 'unsupported-provider', + }); + + renderAttachFileMenu({ + endpoint: 'unsupported-provider', + endpointType: 'unsupported-endpoint' as any, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload Image')).toBeInTheDocument(); + expect(screen.queryByText('Upload to Provider')).not.toBeInTheDocument(); + }); + + it('should fallback to currentProvider when endpointType is undefined', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: EModelEndpoint.openAI, + }); + + renderAttachFileMenu({ + endpoint: EModelEndpoint.openAI, + endpointType: undefined, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + }); + + it('should fallback to currentProvider when endpointType is null', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: EModelEndpoint.anthropic, + }); + + renderAttachFileMenu({ + endpoint: EModelEndpoint.anthropic, + endpointType: null, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + }); + }); + + describe('Supported Providers', () => { + const supportedProviders = [ + { name: 'OpenAI', endpoint: EModelEndpoint.openAI }, + { name: 'Anthropic', endpoint: EModelEndpoint.anthropic }, + { name: 'Google', endpoint: EModelEndpoint.google }, + { name: 'Azure OpenAI', endpoint: EModelEndpoint.azureOpenAI }, + { name: 'Custom', endpoint: EModelEndpoint.custom }, + ]; + + supportedProviders.forEach(({ name, endpoint }) => { + it(`should show Upload to Provider for ${name}`, () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: endpoint, + }); + + renderAttachFileMenu({ + endpoint, + endpointType: endpoint, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + }); + }); + }); + + describe('Agent Capabilities', () => { + it('should show OCR Text option when context is enabled', () => { + mockUseAgentCapabilities.mockReturnValue({ + contextEnabled: true, + fileSearchEnabled: false, + codeEnabled: false, + }); + + renderAttachFileMenu({ + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload OCR Text')).toBeInTheDocument(); + }); + + it('should show File Search option when enabled and allowed by agent', () => { + mockUseAgentCapabilities.mockReturnValue({ + contextEnabled: false, + fileSearchEnabled: true, + codeEnabled: false, + }); + + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: true, + codeAllowedByAgent: false, + provider: undefined, + }); + + renderAttachFileMenu({ + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload for File Search')).toBeInTheDocument(); + }); + + it('should NOT show File Search when enabled but not allowed by agent', () => { + mockUseAgentCapabilities.mockReturnValue({ + contextEnabled: false, + fileSearchEnabled: true, + codeEnabled: false, + }); + + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: undefined, + }); + + renderAttachFileMenu({ + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.queryByText('Upload for File Search')).not.toBeInTheDocument(); + }); + + it('should show Code Files option when enabled and allowed by agent', () => { + mockUseAgentCapabilities.mockReturnValue({ + contextEnabled: false, + fileSearchEnabled: false, + codeEnabled: true, + }); + + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: true, + provider: undefined, + }); + + renderAttachFileMenu({ + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload Code Files')).toBeInTheDocument(); + }); + + it('should show all options when all capabilities are enabled', () => { + mockUseAgentCapabilities.mockReturnValue({ + contextEnabled: true, + fileSearchEnabled: true, + codeEnabled: true, + }); + + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: true, + codeAllowedByAgent: true, + provider: undefined, + }); + + renderAttachFileMenu({ + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + expect(screen.getByText('Upload OCR Text')).toBeInTheDocument(); + expect(screen.getByText('Upload for File Search')).toBeInTheDocument(); + expect(screen.getByText('Upload Code Files')).toBeInTheDocument(); + }); + }); + + describe('SharePoint Integration', () => { + it('should show SharePoint option when enabled', () => { + mockUseGetStartupConfig.mockReturnValue({ + data: { + sharePointFilePickerEnabled: true, + }, + }); + + renderAttachFileMenu({ + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload from SharePoint')).toBeInTheDocument(); + }); + + it('should NOT show SharePoint option when disabled', () => { + mockUseGetStartupConfig.mockReturnValue({ + data: { + sharePointFilePickerEnabled: false, + }, + }); + + renderAttachFileMenu({ + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.queryByText('Upload from SharePoint')).not.toBeInTheDocument(); + }); + }); + + describe('Edge Cases', () => { + it('should handle undefined endpoint and provider gracefully', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: undefined, + }); + + renderAttachFileMenu({ + endpoint: undefined, + endpointType: undefined, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + expect(button).toBeInTheDocument(); + fireEvent.click(button); + + // Should show Upload Image as fallback + expect(screen.getByText('Upload Image')).toBeInTheDocument(); + }); + + it('should handle null endpoint and provider gracefully', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: null, + }); + + renderAttachFileMenu({ + endpoint: null, + endpointType: null, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + expect(button).toBeInTheDocument(); + }); + + it('should handle missing agentId gracefully', () => { + renderAttachFileMenu({ + agentId: undefined, + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + expect(button).toBeInTheDocument(); + }); + + it('should handle empty string agentId', () => { + renderAttachFileMenu({ + agentId: '', + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + expect(button).toBeInTheDocument(); + }); + }); + + describe('Google Provider Special Case', () => { + it('should use google_multimodal file type for Google provider', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: EModelEndpoint.google, + }); + + renderAttachFileMenu({ + endpoint: EModelEndpoint.google, + endpointType: EModelEndpoint.google, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + const uploadProviderButton = screen.getByText('Upload to Provider'); + expect(uploadProviderButton).toBeInTheDocument(); + + // Click the upload to provider option + fireEvent.click(uploadProviderButton); + + // The file input should have been clicked (indirectly tested through the implementation) + }); + + it('should use multimodal file type for non-Google providers', () => { + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: EModelEndpoint.openAI, + }); + + renderAttachFileMenu({ + endpoint: EModelEndpoint.openAI, + endpointType: EModelEndpoint.openAI, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + const uploadProviderButton = screen.getByText('Upload to Provider'); + expect(uploadProviderButton).toBeInTheDocument(); + fireEvent.click(uploadProviderButton); + + // Implementation detail - multimodal type is used + }); + }); + + describe('Regression Tests', () => { + it('should not break the previous behavior for direct provider attachments', () => { + // When using a direct supported provider (not through a gateway) + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: EModelEndpoint.anthropic, + }); + + renderAttachFileMenu({ + endpoint: EModelEndpoint.anthropic, + endpointType: EModelEndpoint.anthropic, + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + }); + + it('should maintain correct priority when both are supported', () => { + // Both endpointType and provider are supported, endpointType should be checked first + mockUseAgentToolPermissions.mockReturnValue({ + fileSearchAllowedByAgent: false, + codeAllowedByAgent: false, + provider: EModelEndpoint.google, + }); + + renderAttachFileMenu({ + endpoint: EModelEndpoint.google, + endpointType: EModelEndpoint.openAI, // Different but both supported + }); + + const button = screen.getByRole('button', { name: /attach file options/i }); + fireEvent.click(button); + + // Should still work because endpointType (openAI) is supported + expect(screen.getByText('Upload to Provider')).toBeInTheDocument(); + }); + }); +}); diff --git a/client/src/components/Chat/Input/Files/__tests__/DragDropModal.spec.tsx b/client/src/components/Chat/Input/Files/__tests__/DragDropModal.spec.tsx new file mode 100644 index 0000000000..2adad63b9a --- /dev/null +++ b/client/src/components/Chat/Input/Files/__tests__/DragDropModal.spec.tsx @@ -0,0 +1,121 @@ +import { EModelEndpoint, isDocumentSupportedProvider } from 'librechat-data-provider'; + +describe('DragDropModal - Provider Detection', () => { + describe('endpointType priority over currentProvider', () => { + it('should show upload option for LiteLLM with OpenAI endpointType', () => { + const currentProvider = 'litellm'; // NOT in documentSupportedProviders + const endpointType = EModelEndpoint.openAI; // IS in documentSupportedProviders + + // With fix: endpointType checked + const withFix = + isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider); + expect(withFix).toBe(true); + + // Without fix: only currentProvider checked = false + const withoutFix = isDocumentSupportedProvider(currentProvider || endpointType); + expect(withoutFix).toBe(false); + }); + + it('should show upload option for any custom gateway with OpenAI endpointType', () => { + const currentProvider = 'my-custom-gateway'; + const endpointType = EModelEndpoint.openAI; + + const result = + isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider); + expect(result).toBe(true); + }); + + it('should fallback to currentProvider when endpointType is undefined', () => { + const currentProvider = EModelEndpoint.openAI; + const endpointType = undefined; + + const result = + isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider); + expect(result).toBe(true); + }); + + it('should fallback to currentProvider when endpointType is null', () => { + const currentProvider = EModelEndpoint.anthropic; + const endpointType = null; + + const result = + isDocumentSupportedProvider(endpointType as any) || + isDocumentSupportedProvider(currentProvider); + expect(result).toBe(true); + }); + + it('should return false when neither provider supports documents', () => { + const currentProvider = 'unsupported-provider'; + const endpointType = 'unsupported-endpoint' as any; + + const result = + isDocumentSupportedProvider(endpointType) || isDocumentSupportedProvider(currentProvider); + expect(result).toBe(false); + }); + }); + + describe('supported providers', () => { + const supportedProviders = [ + { name: 'OpenAI', value: EModelEndpoint.openAI }, + { name: 'Anthropic', value: EModelEndpoint.anthropic }, + { name: 'Google', value: EModelEndpoint.google }, + { name: 'Azure OpenAI', value: EModelEndpoint.azureOpenAI }, + { name: 'Custom', value: EModelEndpoint.custom }, + ]; + + supportedProviders.forEach(({ name, value }) => { + it(`should recognize ${name} as supported`, () => { + expect(isDocumentSupportedProvider(value)).toBe(true); + }); + }); + }); + + describe('real-world scenarios', () => { + it('should handle LiteLLM gateway pointing to OpenAI', () => { + const scenario = { + currentProvider: 'litellm', + endpointType: EModelEndpoint.openAI, + }; + + expect( + isDocumentSupportedProvider(scenario.endpointType) || + isDocumentSupportedProvider(scenario.currentProvider), + ).toBe(true); + }); + + it('should handle direct OpenAI connection', () => { + const scenario = { + currentProvider: EModelEndpoint.openAI, + endpointType: EModelEndpoint.openAI, + }; + + expect( + isDocumentSupportedProvider(scenario.endpointType) || + isDocumentSupportedProvider(scenario.currentProvider), + ).toBe(true); + }); + + it('should handle unsupported custom endpoint without override', () => { + const scenario = { + currentProvider: 'my-unsupported-endpoint', + endpointType: undefined, + }; + + expect( + isDocumentSupportedProvider(scenario.endpointType) || + isDocumentSupportedProvider(scenario.currentProvider), + ).toBe(false); + }); + it('should handle agents endpoints with document supported providers', () => { + const scenario = { + currentProvider: EModelEndpoint.google, + endpointType: EModelEndpoint.agents, + }; + + expect( + isDocumentSupportedProvider(scenario.endpointType) || + isDocumentSupportedProvider(scenario.currentProvider), + ).toBe(true); + }); + }); +}); From c9e1127b85c8a47d27e9e74b638a13cf293bb019 Mon Sep 17 00:00:00 2001 From: Eduardo Cruz Guedes Date: Tue, 4 Nov 2025 15:52:47 -0300 Subject: [PATCH 008/207] =?UTF-8?q?=F0=9F=8C=85=20docs:=20Add=20OpenAI=20I?= =?UTF-8?q?mage=20Gen=20Env=20Vars=20(#10335)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.env.example b/.env.example index 24c74487aa..10e299e72b 100644 --- a/.env.example +++ b/.env.example @@ -254,6 +254,10 @@ AZURE_AI_SEARCH_SEARCH_OPTION_SELECT= # OpenAI Image Tools Customization #---------------- +# IMAGE_GEN_OAI_API_KEY= # Create or reuse OpenAI API key for image generation tool +# IMAGE_GEN_OAI_BASEURL= # Custom OpenAI base URL for image generation tool +# IMAGE_GEN_OAI_AZURE_API_VERSION= # Custom Azure OpenAI deployments +# IMAGE_GEN_OAI_DESCRIPTION= # IMAGE_GEN_OAI_DESCRIPTION_WITH_FILES=Custom description for image generation tool when files are present # IMAGE_GEN_OAI_DESCRIPTION_NO_FILES=Custom description for image generation tool when no files are present # IMAGE_EDIT_OAI_DESCRIPTION=Custom description for image editing tool From 06fcf79d56cf3a7a5fa2dff9861969b9ac58f067 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 5 Nov 2025 09:20:35 -0500 Subject: [PATCH 009/207] =?UTF-8?q?=F0=9F=9B=82=20feat:=20Social=20Login?= =?UTF-8?q?=20by=20Provider=20ID=20First=20then=20Email=20(#10358)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/strategies/appleStrategy.test.js | 1 + api/strategies/process.js | 20 +- api/strategies/process.test.js | 72 +++++++ api/strategies/socialLogin.js | 18 +- api/strategies/socialLogin.test.js | 276 +++++++++++++++++++++++++++ 5 files changed, 381 insertions(+), 6 deletions(-) create mode 100644 api/strategies/socialLogin.test.js diff --git a/api/strategies/appleStrategy.test.js b/api/strategies/appleStrategy.test.js index d8ba4616f2..d142d27eac 100644 --- a/api/strategies/appleStrategy.test.js +++ b/api/strategies/appleStrategy.test.js @@ -304,6 +304,7 @@ describe('Apple Login Strategy', () => { fileStrategy: 'local', balance: { enabled: false }, }), + 'jane.doe@example.com', ); }); diff --git a/api/strategies/process.js b/api/strategies/process.js index 8f70cd86ce..c1e0ad0bbc 100644 --- a/api/strategies/process.js +++ b/api/strategies/process.js @@ -5,22 +5,25 @@ const { resizeAvatar } = require('~/server/services/Files/images/avatar'); const { updateUser, createUser, getUserById } = require('~/models'); /** - * Updates the avatar URL of an existing user. If the user's avatar URL does not include the query parameter + * Updates the avatar URL and email of an existing user. If the user's avatar URL does not include the query parameter * '?manual=true', it updates the user's avatar with the provided URL. For local file storage, it directly updates * the avatar URL, while for other storage types, it processes the avatar URL using the specified file strategy. + * Also updates the email if it has changed (e.g., when a Google Workspace email is updated). * * @param {IUser} oldUser - The existing user object that needs to be updated. * @param {string} avatarUrl - The new avatar URL to be set for the user. * @param {AppConfig} appConfig - The application configuration object. + * @param {string} [email] - Optional. The new email address to update if it has changed. * * @returns {Promise} - * The function updates the user's avatar and saves the user object. It does not return any value. + * The function updates the user's avatar and/or email and saves the user object. It does not return any value. * * @throws {Error} Throws an error if there's an issue saving the updated user object. */ -const handleExistingUser = async (oldUser, avatarUrl, appConfig) => { +const handleExistingUser = async (oldUser, avatarUrl, appConfig, email) => { const fileStrategy = appConfig?.fileStrategy ?? process.env.CDN_PROVIDER; const isLocal = fileStrategy === FileSources.local; + const updates = {}; let updatedAvatar = false; const hasManualFlag = @@ -39,7 +42,16 @@ const handleExistingUser = async (oldUser, avatarUrl, appConfig) => { } if (updatedAvatar) { - await updateUser(oldUser._id, { avatar: updatedAvatar }); + updates.avatar = updatedAvatar; + } + + /** Update email if it has changed */ + if (email && email.trim() !== oldUser.email) { + updates.email = email.trim(); + } + + if (Object.keys(updates).length > 0) { + await updateUser(oldUser._id, updates); } }; diff --git a/api/strategies/process.test.js b/api/strategies/process.test.js index ceb7d21a64..ab5fdb651f 100644 --- a/api/strategies/process.test.js +++ b/api/strategies/process.test.js @@ -167,4 +167,76 @@ describe('handleExistingUser', () => { // This should throw an error when trying to access oldUser._id await expect(handleExistingUser(null, avatarUrl)).rejects.toThrow(); }); + + it('should update email when it has changed', async () => { + const oldUser = { + _id: 'user123', + email: 'old@example.com', + avatar: 'https://example.com/avatar.png?manual=true', + }; + const avatarUrl = 'https://example.com/avatar.png'; + const newEmail = 'new@example.com'; + + await handleExistingUser(oldUser, avatarUrl, {}, newEmail); + + expect(updateUser).toHaveBeenCalledWith('user123', { email: 'new@example.com' }); + }); + + it('should update both avatar and email when both have changed', async () => { + const oldUser = { + _id: 'user123', + email: 'old@example.com', + avatar: null, + }; + const avatarUrl = 'https://example.com/new-avatar.png'; + const newEmail = 'new@example.com'; + + await handleExistingUser(oldUser, avatarUrl, {}, newEmail); + + expect(updateUser).toHaveBeenCalledWith('user123', { + avatar: avatarUrl, + email: 'new@example.com', + }); + }); + + it('should not update email when it has not changed', async () => { + const oldUser = { + _id: 'user123', + email: 'same@example.com', + avatar: 'https://example.com/avatar.png?manual=true', + }; + const avatarUrl = 'https://example.com/avatar.png'; + const sameEmail = 'same@example.com'; + + await handleExistingUser(oldUser, avatarUrl, {}, sameEmail); + + expect(updateUser).not.toHaveBeenCalled(); + }); + + it('should trim email before comparison and update', async () => { + const oldUser = { + _id: 'user123', + email: 'test@example.com', + avatar: 'https://example.com/avatar.png?manual=true', + }; + const avatarUrl = 'https://example.com/avatar.png'; + const newEmailWithSpaces = ' newemail@example.com '; + + await handleExistingUser(oldUser, avatarUrl, {}, newEmailWithSpaces); + + expect(updateUser).toHaveBeenCalledWith('user123', { email: 'newemail@example.com' }); + }); + + it('should not update when email parameter is not provided', async () => { + const oldUser = { + _id: 'user123', + email: 'test@example.com', + avatar: 'https://example.com/avatar.png?manual=true', + }; + const avatarUrl = 'https://example.com/avatar.png'; + + await handleExistingUser(oldUser, avatarUrl, {}); + + expect(updateUser).not.toHaveBeenCalled(); + }); }); diff --git a/api/strategies/socialLogin.js b/api/strategies/socialLogin.js index bad70cc040..88fb347042 100644 --- a/api/strategies/socialLogin.js +++ b/api/strategies/socialLogin.js @@ -25,10 +25,24 @@ const socialLogin = return cb(error); } - const existingUser = await findUser({ email: email.trim() }); + const providerKey = `${provider}Id`; + let existingUser = null; + + /** First try to find user by provider ID (e.g., googleId, facebookId) */ + if (id && typeof id === 'string') { + existingUser = await findUser({ [providerKey]: id }); + } + + /** If not found by provider ID, try finding by email */ + if (!existingUser) { + existingUser = await findUser({ email: email?.trim() }); + if (existingUser) { + logger.warn(`[${provider}Login] User found by email: ${email} but not by ${providerKey}`); + } + } if (existingUser?.provider === provider) { - await handleExistingUser(existingUser, avatarUrl, appConfig); + await handleExistingUser(existingUser, avatarUrl, appConfig, email); return cb(null, existingUser); } else if (existingUser) { logger.info( diff --git a/api/strategies/socialLogin.test.js b/api/strategies/socialLogin.test.js new file mode 100644 index 0000000000..11ada17975 --- /dev/null +++ b/api/strategies/socialLogin.test.js @@ -0,0 +1,276 @@ +const { logger } = require('@librechat/data-schemas'); +const { ErrorTypes } = require('librechat-data-provider'); +const { createSocialUser, handleExistingUser } = require('./process'); +const socialLogin = require('./socialLogin'); +const { findUser } = require('~/models'); + +jest.mock('@librechat/data-schemas', () => { + const actualModule = jest.requireActual('@librechat/data-schemas'); + return { + ...actualModule, + logger: { + error: jest.fn(), + info: jest.fn(), + warn: jest.fn(), + }, + }; +}); + +jest.mock('./process', () => ({ + createSocialUser: jest.fn(), + handleExistingUser: jest.fn(), +})); + +jest.mock('@librechat/api', () => ({ + ...jest.requireActual('@librechat/api'), + isEnabled: jest.fn().mockReturnValue(true), + isEmailDomainAllowed: jest.fn().mockReturnValue(true), +})); + +jest.mock('~/models', () => ({ + findUser: jest.fn(), +})); + +jest.mock('~/server/services/Config', () => ({ + getAppConfig: jest.fn().mockResolvedValue({ + fileStrategy: 'local', + balance: { enabled: false }, + }), +})); + +describe('socialLogin', () => { + const mockGetProfileDetails = ({ profile }) => ({ + email: profile.emails[0].value, + id: profile.id, + avatarUrl: profile.photos?.[0]?.value || null, + username: profile.name?.givenName || 'user', + name: `${profile.name?.givenName || ''} ${profile.name?.familyName || ''}`.trim(), + emailVerified: profile.emails[0].verified || false, + }); + + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('Finding users by provider ID', () => { + it('should find user by provider ID (googleId) when email has changed', async () => { + const provider = 'google'; + const googleId = 'google-user-123'; + const oldEmail = 'old@example.com'; + const newEmail = 'new@example.com'; + + const existingUser = { + _id: 'user123', + email: oldEmail, + provider: 'google', + googleId: googleId, + }; + + /** Mock findUser to return user on first call (by googleId), null on second call */ + findUser + .mockResolvedValueOnce(existingUser) // First call: finds by googleId + .mockResolvedValueOnce(null); // Second call would be by email, but won't be reached + + const mockProfile = { + id: googleId, + emails: [{ value: newEmail, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'John', familyName: 'Doe' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + /** Verify it searched by googleId first */ + expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); + + /** Verify it did NOT search by email (because it found user by googleId) */ + expect(findUser).toHaveBeenCalledTimes(1); + + /** Verify handleExistingUser was called with the new email */ + expect(handleExistingUser).toHaveBeenCalledWith( + existingUser, + 'https://example.com/avatar.png', + expect.any(Object), + newEmail, + ); + + /** Verify callback was called with success */ + expect(callback).toHaveBeenCalledWith(null, existingUser); + }); + + it('should find user by provider ID (facebookId) when using Facebook', async () => { + const provider = 'facebook'; + const facebookId = 'fb-user-456'; + const email = 'user@example.com'; + + const existingUser = { + _id: 'user456', + email: email, + provider: 'facebook', + facebookId: facebookId, + }; + + findUser.mockResolvedValue(existingUser); // Always returns user + + const mockProfile = { + id: facebookId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/fb-avatar.png' }], + name: { givenName: 'Jane', familyName: 'Smith' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + /** Verify it searched by facebookId first */ + expect(findUser).toHaveBeenCalledWith({ facebookId: facebookId }); + expect(findUser.mock.calls[0]).toEqual([{ facebookId: facebookId }]); + + expect(handleExistingUser).toHaveBeenCalledWith( + existingUser, + 'https://example.com/fb-avatar.png', + expect.any(Object), + email, + ); + + expect(callback).toHaveBeenCalledWith(null, existingUser); + }); + + it('should fallback to finding user by email if not found by provider ID', async () => { + const provider = 'google'; + const googleId = 'google-user-789'; + const email = 'user@example.com'; + + const existingUser = { + _id: 'user789', + email: email, + provider: 'google', + googleId: 'old-google-id', // Different googleId (edge case) + }; + + /** First call (by googleId) returns null, second call (by email) returns user */ + findUser + .mockResolvedValueOnce(null) // By googleId + .mockResolvedValueOnce(existingUser); // By email + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'Bob', familyName: 'Johnson' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + /** Verify both searches happened */ + expect(findUser).toHaveBeenNthCalledWith(1, { googleId: googleId }); + expect(findUser).toHaveBeenNthCalledWith(2, { email: email }); + expect(findUser).toHaveBeenCalledTimes(2); + + /** Verify warning log */ + expect(logger.warn).toHaveBeenCalledWith( + `[${provider}Login] User found by email: ${email} but not by ${provider}Id`, + ); + + expect(handleExistingUser).toHaveBeenCalled(); + expect(callback).toHaveBeenCalledWith(null, existingUser); + }); + + it('should create new user if not found by provider ID or email', async () => { + const provider = 'google'; + const googleId = 'google-new-user'; + const email = 'newuser@example.com'; + + const newUser = { + _id: 'newuser123', + email: email, + provider: 'google', + googleId: googleId, + }; + + /** Both searches return null */ + findUser.mockResolvedValue(null); + createSocialUser.mockResolvedValue(newUser); + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'New', familyName: 'User' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + /** Verify both searches happened */ + expect(findUser).toHaveBeenCalledTimes(2); + + /** Verify createSocialUser was called */ + expect(createSocialUser).toHaveBeenCalledWith({ + email: email, + avatarUrl: 'https://example.com/avatar.png', + provider: provider, + providerKey: 'googleId', + providerId: googleId, + username: 'New', + name: 'New User', + emailVerified: true, + appConfig: expect.any(Object), + }); + + expect(callback).toHaveBeenCalledWith(null, newUser); + }); + }); + + describe('Error handling', () => { + it('should return error if user exists with different provider', async () => { + const provider = 'google'; + const googleId = 'google-user-123'; + const email = 'user@example.com'; + + const existingUser = { + _id: 'user123', + email: email, + provider: 'local', // Different provider + }; + + findUser + .mockResolvedValueOnce(null) // By googleId + .mockResolvedValueOnce(existingUser); // By email + + const mockProfile = { + id: googleId, + emails: [{ value: email, verified: true }], + photos: [{ value: 'https://example.com/avatar.png' }], + name: { givenName: 'John', familyName: 'Doe' }, + }; + + const loginFn = socialLogin(provider, mockGetProfileDetails); + const callback = jest.fn(); + + await loginFn(null, null, null, mockProfile, callback); + + /** Verify error callback */ + expect(callback).toHaveBeenCalledWith( + expect.objectContaining({ + code: ErrorTypes.AUTH_FAILED, + provider: 'local', + }), + ); + + expect(logger.info).toHaveBeenCalledWith( + `[${provider}Login] User ${email} already exists with provider local`, + ); + }); + }); +}); From 772b706e204e0e239cad999d96e8eabbfc166046 Mon Sep 17 00:00:00 2001 From: Rakshit Date: Wed, 5 Nov 2025 20:57:34 +0530 Subject: [PATCH 010/207] =?UTF-8?q?=F0=9F=8E=99=EF=B8=8F=20fix:=20Azure=20?= =?UTF-8?q?OpenAI=20Speech-to-Text=20400=20Bad=20Request=20Error=20(#10355?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/server/services/Files/Audio/STTService.js | 1 - packages/api/src/utils/azure.ts | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/api/server/services/Files/Audio/STTService.js b/api/server/services/Files/Audio/STTService.js index 60d6a48a14..16f806de4e 100644 --- a/api/server/services/Files/Audio/STTService.js +++ b/api/server/services/Files/Audio/STTService.js @@ -227,7 +227,6 @@ class STTService { } const headers = { - 'Content-Type': 'multipart/form-data', ...(apiKey && { 'api-key': apiKey }), }; diff --git a/packages/api/src/utils/azure.ts b/packages/api/src/utils/azure.ts index b4051d3d80..1bbd0e29b2 100644 --- a/packages/api/src/utils/azure.ts +++ b/packages/api/src/utils/azure.ts @@ -25,6 +25,12 @@ export const genAzureEndpoint = ({ azureOpenAIApiInstanceName: string; azureOpenAIApiDeploymentName: string; }): string => { + // Support both old (.openai.azure.com) and new (.cognitiveservices.azure.com) endpoint formats + // If instanceName already includes a full domain, use it as-is + if (azureOpenAIApiInstanceName.includes('.azure.com')) { + return `https://${azureOpenAIApiInstanceName}/openai/deployments/${azureOpenAIApiDeploymentName}`; + } + // Legacy format for backward compatibility return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`; }; From 0f4222a908af5bfd00ed37879e00b412e37091e2 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 5 Nov 2025 10:28:06 -0500 Subject: [PATCH 011/207] =?UTF-8?q?=F0=9F=AA=9E=20fix:=20Prevent=20Revoked?= =?UTF-8?q?=20Blob=20URLs=20in=20Uploaded=20Images=20(FileRow)=20(#10361)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../components/Chat/Input/Files/FileRow.tsx | 2 +- .../Input/Files/__tests__/FileRow.spec.tsx | 347 ++++++++++++++++++ 2 files changed, 348 insertions(+), 1 deletion(-) create mode 100644 client/src/components/Chat/Input/Files/__tests__/FileRow.spec.tsx diff --git a/client/src/components/Chat/Input/Files/FileRow.tsx b/client/src/components/Chat/Input/Files/FileRow.tsx index babb0aef69..ea0b648015 100644 --- a/client/src/components/Chat/Input/Files/FileRow.tsx +++ b/client/src/components/Chat/Input/Files/FileRow.tsx @@ -133,7 +133,7 @@ export default function FileRow({ > {isImage ? ( ({ + useLocalize: jest.fn(), +})); + +jest.mock('~/data-provider', () => ({ + useDeleteFilesMutation: jest.fn(), +})); + +jest.mock('~/hooks/Files', () => ({ + useFileDeletion: jest.fn(), +})); + +jest.mock('~/utils', () => ({ + logger: { + log: jest.fn(), + }, +})); + +jest.mock('../Image', () => { + return function MockImage({ url, progress, source }: any) { + return ( +
+ {url} + {progress} + {source} +
+ ); + }; +}); + +jest.mock('../FileContainer', () => { + return function MockFileContainer({ file }: any) { + return ( +
+ {file.filename} +
+ ); + }; +}); + +const mockUseLocalize = jest.requireMock('~/hooks').useLocalize; +const mockUseDeleteFilesMutation = jest.requireMock('~/data-provider').useDeleteFilesMutation; +const mockUseFileDeletion = jest.requireMock('~/hooks/Files').useFileDeletion; + +describe('FileRow', () => { + const mockSetFiles = jest.fn(); + const mockSetFilesLoading = jest.fn(); + const mockDeleteFile = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + mockUseLocalize.mockReturnValue((key: string) => { + const translations: Record = { + com_ui_deleting_file: 'Deleting file...', + }; + return translations[key] || key; + }); + + mockUseDeleteFilesMutation.mockReturnValue({ + mutateAsync: jest.fn(), + }); + + mockUseFileDeletion.mockReturnValue({ + deleteFile: mockDeleteFile, + }); + }); + + /** + * Creates a mock ExtendedFile with sensible defaults + */ + const createMockFile = (overrides: Partial = {}): ExtendedFile => ({ + file_id: 'test-file-id', + type: 'image/png', + preview: 'blob:http://localhost:3080/preview-blob-url', + filepath: '/images/user123/test-file-id__image.png', + filename: 'test-image.png', + progress: 1, + size: 1024, + source: FileSources.local, + ...overrides, + }); + + const renderFileRow = (files: Map) => { + return render( + , + ); + }; + + describe('Image URL Selection Logic', () => { + it('should use filepath instead of preview when progress is 1 (upload complete)', () => { + const file = createMockFile({ + file_id: 'uploaded-file', + preview: 'blob:http://localhost:3080/temp-preview', + filepath: '/images/user123/uploaded-file__image.png', + progress: 1, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const imageUrl = screen.getByTestId('image-url').textContent; + expect(imageUrl).toBe('/images/user123/uploaded-file__image.png'); + expect(imageUrl).not.toContain('blob:'); + }); + + it('should use preview when progress is less than 1 (uploading)', () => { + const file = createMockFile({ + file_id: 'uploading-file', + preview: 'blob:http://localhost:3080/temp-preview', + filepath: undefined, + progress: 0.5, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const imageUrl = screen.getByTestId('image-url').textContent; + expect(imageUrl).toBe('blob:http://localhost:3080/temp-preview'); + }); + + it('should fallback to filepath when preview is undefined and progress is less than 1', () => { + const file = createMockFile({ + file_id: 'file-without-preview', + preview: undefined, + filepath: '/images/user123/file-without-preview__image.png', + progress: 0.7, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const imageUrl = screen.getByTestId('image-url').textContent; + expect(imageUrl).toBe('/images/user123/file-without-preview__image.png'); + }); + + it('should use filepath when both preview and filepath exist and progress is exactly 1', () => { + const file = createMockFile({ + file_id: 'complete-file', + preview: 'blob:http://localhost:3080/old-blob', + filepath: '/images/user123/complete-file__image.png', + progress: 1.0, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const imageUrl = screen.getByTestId('image-url').textContent; + expect(imageUrl).toBe('/images/user123/complete-file__image.png'); + }); + }); + + describe('Progress States', () => { + it('should pass correct progress value during upload', () => { + const file = createMockFile({ + progress: 0.65, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const progress = screen.getByTestId('image-progress').textContent; + expect(progress).toBe('0.65'); + }); + + it('should pass progress value of 1 when upload is complete', () => { + const file = createMockFile({ + progress: 1, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const progress = screen.getByTestId('image-progress').textContent; + expect(progress).toBe('1'); + }); + }); + + describe('File Source', () => { + it('should pass local source to Image component', () => { + const file = createMockFile({ + source: FileSources.local, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const source = screen.getByTestId('image-source').textContent; + expect(source).toBe(FileSources.local); + }); + + it('should pass openai source to Image component', () => { + const file = createMockFile({ + source: FileSources.openai, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const source = screen.getByTestId('image-source').textContent; + expect(source).toBe(FileSources.openai); + }); + }); + + describe('File Type Detection', () => { + it('should render Image component for image files', () => { + const file = createMockFile({ + type: 'image/jpeg', + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + expect(screen.getByTestId('mock-image')).toBeInTheDocument(); + expect(screen.queryByTestId('mock-file-container')).not.toBeInTheDocument(); + }); + + it('should render FileContainer for non-image files', () => { + const file = createMockFile({ + type: 'application/pdf', + filename: 'document.pdf', + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + expect(screen.getByTestId('mock-file-container')).toBeInTheDocument(); + expect(screen.queryByTestId('mock-image')).not.toBeInTheDocument(); + }); + }); + + describe('Multiple Files', () => { + it('should render multiple image files with correct URLs based on their progress', () => { + const filesMap = new Map(); + + const uploadingFile = createMockFile({ + file_id: 'file-1', + preview: 'blob:http://localhost:3080/preview-1', + filepath: undefined, + progress: 0.3, + }); + + const completedFile = createMockFile({ + file_id: 'file-2', + preview: 'blob:http://localhost:3080/preview-2', + filepath: '/images/user123/file-2__image.png', + progress: 1, + }); + + filesMap.set(uploadingFile.file_id, uploadingFile); + filesMap.set(completedFile.file_id, completedFile); + + renderFileRow(filesMap); + + const images = screen.getAllByTestId('mock-image'); + expect(images).toHaveLength(2); + + const urls = screen.getAllByTestId('image-url').map((el) => el.textContent); + expect(urls).toContain('blob:http://localhost:3080/preview-1'); + expect(urls).toContain('/images/user123/file-2__image.png'); + }); + + it('should deduplicate files with the same file_id', () => { + const filesMap = new Map(); + + const file1 = createMockFile({ file_id: 'duplicate-id' }); + const file2 = createMockFile({ file_id: 'duplicate-id' }); + + filesMap.set('key-1', file1); + filesMap.set('key-2', file2); + + renderFileRow(filesMap); + + const images = screen.getAllByTestId('mock-image'); + expect(images).toHaveLength(1); + }); + }); + + describe('Empty State', () => { + it('should render nothing when files map is empty', () => { + const filesMap = new Map(); + + const { container } = renderFileRow(filesMap); + + expect(container.firstChild).toBeNull(); + }); + + it('should render nothing when files is undefined', () => { + const { container } = render( + , + ); + + expect(container.firstChild).toBeNull(); + }); + }); + + describe('Regression: Blob URL Bug Fix', () => { + it('should NOT use revoked blob URL after upload completes', () => { + const file = createMockFile({ + file_id: 'regression-test', + preview: 'blob:http://localhost:3080/d25f730c-152d-41f7-8d79-c9fa448f606b', + filepath: + '/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png', + progress: 1, + }); + + const filesMap = new Map(); + filesMap.set(file.file_id, file); + + renderFileRow(filesMap); + + const imageUrl = screen.getByTestId('image-url').textContent; + + expect(imageUrl).not.toContain('blob:'); + expect(imageUrl).toBe( + '/images/68c98b26901ebe2d87c193a2/c0fe1b93-ba3d-456c-80be-9a492bfd9ed0__image.png', + ); + }); + }); +}); From 958a6c787276f7c717c716b69cf8b2fa1206818e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:09:52 -0500 Subject: [PATCH 012/207] =?UTF-8?q?=F0=9F=8C=8D=20i18n:=20Update=20transla?= =?UTF-8?q?tion.json=20with=20latest=20translations=20(#10370)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- client/src/locales/lv/translation.json | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client/src/locales/lv/translation.json b/client/src/locales/lv/translation.json index 508ea0037d..38f6809e02 100644 --- a/client/src/locales/lv/translation.json +++ b/client/src/locales/lv/translation.json @@ -787,6 +787,7 @@ "com_ui_copy_code": "Kopēt kodu", "com_ui_copy_link": "Kopēt saiti", "com_ui_copy_stack_trace": "Kopēt kļūdas informāciju", + "com_ui_copy_thoughts_to_clipboard": "Kopēt domas starpliktuvē", "com_ui_copy_to_clipboard": "Kopēt starpliktuvē", "com_ui_copy_url_to_clipboard": "URL kopēšana uz starpliktuvi", "com_ui_create": "Izveidot", @@ -1121,6 +1122,7 @@ "com_ui_reset_var": "Atiestatīt {{0}}", "com_ui_reset_zoom": "Atiestatīt tālummaiņu", "com_ui_resource": "resurss", + "com_ui_response": "Atbilde", "com_ui_result": "Rezultāts", "com_ui_revoke": "Atcelt", "com_ui_revoke_info": "Atcelt visus lietotāja sniegtos lietotāja datus", From 8a4a5a47903d354daab75bb26ea72b229b5a0571 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Wed, 5 Nov 2025 17:15:17 -0500 Subject: [PATCH 013/207] =?UTF-8?q?=F0=9F=A4=96=20feat:=20Agent=20Handoffs?= =?UTF-8?q?=20(Routing)=20(#10176)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: Add support for agent handoffs with edges in agent forms and schemas chore: Mark `agent_ids` field as deprecated in favor of edges across various schemas and types chore: Update dependencies for @langchain/core and @librechat/agents to latest versions chore: Update peer dependency for @librechat/agents to version 3.0.0-rc2 in package.json chore: Update @librechat/agents dependency to version 3.0.0-rc3 in package.json and package-lock.json feat: first pass, multi-agent handoffs fix: update output type to ToolMessage in memory handling functions fix: improve type checking for graphConfig in createRun function refactor: remove unused content filtering logic in AgentClient chore: update @librechat/agents dependency to version 3.0.0-rc4 in package.json and package-lock.json fix: update @langchain/core peer dependency version to ^0.3.72 in package.json and package-lock.json fix: update @librechat/agents dependency to version 3.0.0-rc6 in package.json and package-lock.json; refactor stream rate handling in various endpoints feat: Agent handoff UI chore: update @librechat/agents dependency to version 3.0.0-rc8 in package.json and package-lock.json fix: improve hasInfo condition and adjust UI element classes in AgentHandoff component refactor: remove current fixed agent display from AgentHandoffs component due to redundancy feat: enhance AgentHandoffs UI with localized beta label and improved layout chore: update @librechat/agents dependency to version 3.0.0-rc10 in package.json and package-lock.json feat: add `createSequentialChainEdges` function to add back agent chaining via multi-agents feat: update `createSequentialChainEdges` call to only provide conversation context between agents feat: deprecate Agent Chain functionality and update related methods for improved clarity * chore: update @librechat/agents dependency to version 3.0.0-rc11 in package.json and package-lock.json * refactor: remove unused addCacheControl function and related imports and import from @librechat/agents * chore: remove unused i18n keys * refactor: remove unused format export from index.ts * chore: update @librechat/agents to v3.0.0-rc13 * chore: remove BEDROCK_LEGACY provider from Providers enum * chore: update @librechat/agents to version 3.0.2 in package.json --- api/app/clients/AnthropicClient.js | 3 +- api/app/clients/prompts/addCacheControl.js | 45 - .../clients/prompts/addCacheControl.spec.js | 227 -- api/app/clients/prompts/index.js | 2 - api/package.json | 4 +- api/server/controllers/agents/callbacks.js | 23 +- api/server/controllers/agents/client.js | 274 +- .../services/Endpoints/agents/initialize.js | 116 +- .../Endpoints/anthropic/initialize.js | 4 +- .../services/Endpoints/bedrock/options.js | 16 +- .../services/Endpoints/custom/initialize.js | 7 +- .../Endpoints/custom/initialize.spec.js | 1 - .../services/Endpoints/openAI/initialize.js | 7 +- client/src/common/agents-types.ts | 7 +- .../Chat/Messages/Content/AgentHandoff.tsx | 92 + .../components/Chat/Messages/Content/Part.tsx | 10 + .../Messages/Content/Parts/AgentUpdate.tsx | 4 +- .../Agents/Advanced/AdvancedPanel.tsx | 7 + .../Agents/Advanced/AgentHandoffs.tsx | 296 ++ .../SidePanel/Agents/AgentPanel.tsx | 3 + .../SidePanel/Agents/AgentSelect.tsx | 5 + client/src/locales/en/translation.json | 14 + package-lock.json | 3057 ++--------------- packages/api/package.json | 4 +- packages/api/src/agents/chain.ts | 47 + packages/api/src/agents/index.ts | 1 + packages/api/src/agents/memory.ts | 6 +- packages/api/src/agents/run.ts | 136 +- packages/api/src/agents/validation.ts | 13 + .../api/src/endpoints/openai/initialize.ts | 18 +- packages/api/src/format/content.spec.ts | 340 -- packages/api/src/format/content.ts | 57 - packages/api/src/format/index.ts | 1 - packages/api/src/index.ts | 1 - packages/api/src/types/openai.ts | 8 +- packages/data-provider/src/config.ts | 4 + packages/data-provider/src/schemas.ts | 2 +- packages/data-provider/src/types/agents.ts | 42 + .../data-provider/src/types/assistants.ts | 6 +- packages/data-schemas/src/schema/agent.ts | 5 + packages/data-schemas/src/types/agent.ts | 3 + 41 files changed, 1108 insertions(+), 3810 deletions(-) delete mode 100644 api/app/clients/prompts/addCacheControl.js delete mode 100644 api/app/clients/prompts/addCacheControl.spec.js create mode 100644 client/src/components/Chat/Messages/Content/AgentHandoff.tsx create mode 100644 client/src/components/SidePanel/Agents/Advanced/AgentHandoffs.tsx create mode 100644 packages/api/src/agents/chain.ts delete mode 100644 packages/api/src/format/content.spec.ts delete mode 100644 packages/api/src/format/content.ts delete mode 100644 packages/api/src/format/index.ts diff --git a/api/app/clients/AnthropicClient.js b/api/app/clients/AnthropicClient.js index 43e546a0a3..cb884f2d54 100644 --- a/api/app/clients/AnthropicClient.js +++ b/api/app/clients/AnthropicClient.js @@ -10,7 +10,7 @@ const { getResponseSender, validateVisionModel, } = require('librechat-data-provider'); -const { sleep, SplitStreamHandler: _Handler } = require('@librechat/agents'); +const { sleep, SplitStreamHandler: _Handler, addCacheControl } = require('@librechat/agents'); const { Tokenizer, createFetch, @@ -25,7 +25,6 @@ const { const { truncateText, formatMessage, - addCacheControl, titleFunctionPrompt, parseParamFromPrompt, createContextHandlers, diff --git a/api/app/clients/prompts/addCacheControl.js b/api/app/clients/prompts/addCacheControl.js deleted file mode 100644 index 6bfd901a65..0000000000 --- a/api/app/clients/prompts/addCacheControl.js +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Anthropic API: Adds cache control to the appropriate user messages in the payload. - * @param {Array} messages - The array of message objects. - * @returns {Array} - The updated array of message objects with cache control added. - */ -function addCacheControl(messages) { - if (!Array.isArray(messages) || messages.length < 2) { - return messages; - } - - const updatedMessages = [...messages]; - let userMessagesModified = 0; - - for (let i = updatedMessages.length - 1; i >= 0 && userMessagesModified < 2; i--) { - const message = updatedMessages[i]; - if (message.getType != null && message.getType() !== 'human') { - continue; - } else if (message.getType == null && message.role !== 'user') { - continue; - } - - if (typeof message.content === 'string') { - message.content = [ - { - type: 'text', - text: message.content, - cache_control: { type: 'ephemeral' }, - }, - ]; - userMessagesModified++; - } else if (Array.isArray(message.content)) { - for (let j = message.content.length - 1; j >= 0; j--) { - if (message.content[j].type === 'text') { - message.content[j].cache_control = { type: 'ephemeral' }; - userMessagesModified++; - break; - } - } - } - } - - return updatedMessages; -} - -module.exports = addCacheControl; diff --git a/api/app/clients/prompts/addCacheControl.spec.js b/api/app/clients/prompts/addCacheControl.spec.js deleted file mode 100644 index c46ffd95e3..0000000000 --- a/api/app/clients/prompts/addCacheControl.spec.js +++ /dev/null @@ -1,227 +0,0 @@ -const addCacheControl = require('./addCacheControl'); - -describe('addCacheControl', () => { - test('should add cache control to the last two user messages with array content', () => { - const messages = [ - { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, - { role: 'assistant', content: [{ type: 'text', text: 'Hi there' }] }, - { role: 'user', content: [{ type: 'text', text: 'How are you?' }] }, - { role: 'assistant', content: [{ type: 'text', text: 'I\'m doing well, thanks!' }] }, - { role: 'user', content: [{ type: 'text', text: 'Great!' }] }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content[0]).not.toHaveProperty('cache_control'); - expect(result[2].content[0].cache_control).toEqual({ type: 'ephemeral' }); - expect(result[4].content[0].cache_control).toEqual({ type: 'ephemeral' }); - }); - - test('should add cache control to the last two user messages with string content', () => { - const messages = [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Hi there' }, - { role: 'user', content: 'How are you?' }, - { role: 'assistant', content: 'I\'m doing well, thanks!' }, - { role: 'user', content: 'Great!' }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content).toBe('Hello'); - expect(result[2].content[0]).toEqual({ - type: 'text', - text: 'How are you?', - cache_control: { type: 'ephemeral' }, - }); - expect(result[4].content[0]).toEqual({ - type: 'text', - text: 'Great!', - cache_control: { type: 'ephemeral' }, - }); - }); - - test('should handle mixed string and array content', () => { - const messages = [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Hi there' }, - { role: 'user', content: [{ type: 'text', text: 'How are you?' }] }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content[0]).toEqual({ - type: 'text', - text: 'Hello', - cache_control: { type: 'ephemeral' }, - }); - expect(result[2].content[0].cache_control).toEqual({ type: 'ephemeral' }); - }); - - test('should handle less than two user messages', () => { - const messages = [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Hi there' }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content[0]).toEqual({ - type: 'text', - text: 'Hello', - cache_control: { type: 'ephemeral' }, - }); - expect(result[1].content).toBe('Hi there'); - }); - - test('should return original array if no user messages', () => { - const messages = [ - { role: 'assistant', content: 'Hi there' }, - { role: 'assistant', content: 'How can I help?' }, - ]; - - const result = addCacheControl(messages); - - expect(result).toEqual(messages); - }); - - test('should handle empty array', () => { - const messages = []; - const result = addCacheControl(messages); - expect(result).toEqual([]); - }); - - test('should handle non-array input', () => { - const messages = 'not an array'; - const result = addCacheControl(messages); - expect(result).toBe('not an array'); - }); - - test('should not modify assistant messages', () => { - const messages = [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Hi there' }, - { role: 'user', content: 'How are you?' }, - ]; - - const result = addCacheControl(messages); - - expect(result[1].content).toBe('Hi there'); - }); - - test('should handle multiple content items in user messages', () => { - const messages = [ - { - role: 'user', - content: [ - { type: 'text', text: 'Hello' }, - { type: 'image', url: 'http://example.com/image.jpg' }, - { type: 'text', text: 'This is an image' }, - ], - }, - { role: 'assistant', content: 'Hi there' }, - { role: 'user', content: 'How are you?' }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content[0]).not.toHaveProperty('cache_control'); - expect(result[0].content[1]).not.toHaveProperty('cache_control'); - expect(result[0].content[2].cache_control).toEqual({ type: 'ephemeral' }); - expect(result[2].content[0]).toEqual({ - type: 'text', - text: 'How are you?', - cache_control: { type: 'ephemeral' }, - }); - }); - - test('should handle an array with mixed content types', () => { - const messages = [ - { role: 'user', content: 'Hello' }, - { role: 'assistant', content: 'Hi there' }, - { role: 'user', content: [{ type: 'text', text: 'How are you?' }] }, - { role: 'assistant', content: 'I\'m doing well, thanks!' }, - { role: 'user', content: 'Great!' }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content).toEqual('Hello'); - expect(result[2].content[0]).toEqual({ - type: 'text', - text: 'How are you?', - cache_control: { type: 'ephemeral' }, - }); - expect(result[4].content).toEqual([ - { - type: 'text', - text: 'Great!', - cache_control: { type: 'ephemeral' }, - }, - ]); - expect(result[1].content).toBe('Hi there'); - expect(result[3].content).toBe('I\'m doing well, thanks!'); - }); - - test('should handle edge case with multiple content types', () => { - const messages = [ - { - role: 'user', - content: [ - { - type: 'image', - source: { type: 'base64', media_type: 'image/png', data: 'some_base64_string' }, - }, - { - type: 'image', - source: { type: 'base64', media_type: 'image/png', data: 'another_base64_string' }, - }, - { type: 'text', text: 'what do all these images have in common' }, - ], - }, - { role: 'assistant', content: 'I see multiple images.' }, - { role: 'user', content: 'Correct!' }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content[0]).not.toHaveProperty('cache_control'); - expect(result[0].content[1]).not.toHaveProperty('cache_control'); - expect(result[0].content[2].cache_control).toEqual({ type: 'ephemeral' }); - expect(result[2].content[0]).toEqual({ - type: 'text', - text: 'Correct!', - cache_control: { type: 'ephemeral' }, - }); - }); - - test('should handle user message with no text block', () => { - const messages = [ - { - role: 'user', - content: [ - { - type: 'image', - source: { type: 'base64', media_type: 'image/png', data: 'some_base64_string' }, - }, - { - type: 'image', - source: { type: 'base64', media_type: 'image/png', data: 'another_base64_string' }, - }, - ], - }, - { role: 'assistant', content: 'I see two images.' }, - { role: 'user', content: 'Correct!' }, - ]; - - const result = addCacheControl(messages); - - expect(result[0].content[0]).not.toHaveProperty('cache_control'); - expect(result[0].content[1]).not.toHaveProperty('cache_control'); - expect(result[2].content[0]).toEqual({ - type: 'text', - text: 'Correct!', - cache_control: { type: 'ephemeral' }, - }); - }); -}); diff --git a/api/app/clients/prompts/index.js b/api/app/clients/prompts/index.js index 2549ccda5c..4749cf0b48 100644 --- a/api/app/clients/prompts/index.js +++ b/api/app/clients/prompts/index.js @@ -1,4 +1,3 @@ -const addCacheControl = require('./addCacheControl'); const formatMessages = require('./formatMessages'); const summaryPrompts = require('./summaryPrompts'); const handleInputs = require('./handleInputs'); @@ -9,7 +8,6 @@ const createVisionPrompt = require('./createVisionPrompt'); const createContextHandlers = require('./createContextHandlers'); module.exports = { - addCacheControl, ...formatMessages, ...summaryPrompts, ...handleInputs, diff --git a/api/package.json b/api/package.json index 977cd13668..f3325ebf12 100644 --- a/api/package.json +++ b/api/package.json @@ -44,11 +44,11 @@ "@googleapis/youtube": "^20.0.0", "@keyv/redis": "^4.3.3", "@langchain/community": "^0.3.47", - "@langchain/core": "^0.3.62", + "@langchain/core": "^0.3.72", "@langchain/google-genai": "^0.2.13", "@langchain/google-vertexai": "^0.2.13", "@langchain/textsplitters": "^0.1.0", - "@librechat/agents": "^2.4.90", + "@librechat/agents": "^3.0.2", "@librechat/api": "*", "@librechat/data-schemas": "*", "@microsoft/microsoft-graph-client": "^3.0.7", diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index d700f0a9cb..a66fe9a053 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -95,6 +95,19 @@ class ModelEndHandler { } } +/** + * @deprecated Agent Chain helper + * @param {string | undefined} [last_agent_id] + * @param {string | undefined} [langgraph_node] + * @returns {boolean} + */ +function checkIfLastAgent(last_agent_id, langgraph_node) { + if (!last_agent_id || !langgraph_node) { + return false; + } + return langgraph_node?.endsWith(last_agent_id); +} + /** * Get default handlers for stream events. * @param {Object} options - The options object. @@ -125,7 +138,7 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU handle: (event, data, metadata) => { if (data?.stepDetails.type === StepTypes.TOOL_CALLS) { sendEvent(res, { event, data }); - } else if (metadata?.last_agent_index === metadata?.agent_index) { + } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { sendEvent(res, { event, data }); } else if (!metadata?.hide_sequential_outputs) { sendEvent(res, { event, data }); @@ -154,7 +167,7 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU handle: (event, data, metadata) => { if (data?.delta.type === StepTypes.TOOL_CALLS) { sendEvent(res, { event, data }); - } else if (metadata?.last_agent_index === metadata?.agent_index) { + } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { sendEvent(res, { event, data }); } else if (!metadata?.hide_sequential_outputs) { sendEvent(res, { event, data }); @@ -172,7 +185,7 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU handle: (event, data, metadata) => { if (data?.result != null) { sendEvent(res, { event, data }); - } else if (metadata?.last_agent_index === metadata?.agent_index) { + } else if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { sendEvent(res, { event, data }); } else if (!metadata?.hide_sequential_outputs) { sendEvent(res, { event, data }); @@ -188,7 +201,7 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. */ handle: (event, data, metadata) => { - if (metadata?.last_agent_index === metadata?.agent_index) { + if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { sendEvent(res, { event, data }); } else if (!metadata?.hide_sequential_outputs) { sendEvent(res, { event, data }); @@ -204,7 +217,7 @@ function getDefaultHandlers({ res, aggregateContent, toolEndCallback, collectedU * @param {GraphRunnableConfig['configurable']} [metadata] The runnable metadata. */ handle: (event, data, metadata) => { - if (metadata?.last_agent_index === metadata?.agent_index) { + if (checkIfLastAgent(metadata?.last_agent_id, metadata?.langgraph_node)) { sendEvent(res, { event, data }); } else if (!metadata?.hide_sequential_outputs) { sendEvent(res, { event, data }); diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 27da7d5cc1..13d779e95a 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -3,7 +3,6 @@ const { logger } = require('@librechat/data-schemas'); const { DynamicStructuredTool } = require('@langchain/core/tools'); const { getBufferString, HumanMessage } = require('@langchain/core/messages'); const { - sendEvent, createRun, Tokenizer, checkAccess, @@ -12,14 +11,12 @@ const { resolveHeaders, getBalanceConfig, memoryInstructions, - formatContentStrings, getTransactionsConfig, createMemoryProcessor, } = require('@librechat/api'); const { Callback, Providers, - GraphEvents, TitleMethod, formatMessage, formatAgentMessages, @@ -38,12 +35,12 @@ const { bedrockInputSchema, removeNullishValues, } = require('librechat-data-provider'); -const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts'); const { initializeAgent } = require('~/server/services/Endpoints/agents/agent'); const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens'); const { getFormattedMemories, deleteMemory, setMemory } = require('~/models'); const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const { getProviderConfig } = require('~/server/services/Endpoints'); +const { createContextHandlers } = require('~/app/clients/prompts'); const { checkCapability } = require('~/server/services/Config'); const BaseClient = require('~/app/clients/BaseClient'); const { getRoleByName } = require('~/models/Role'); @@ -80,8 +77,6 @@ const payloadParser = ({ req, agent, endpoint }) => { return req.body.endpointOption.model_parameters; }; -const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; - function createTokenCounter(encoding) { return function (message) { const countTokens = (text) => Tokenizer.getTokenCount(text, encoding); @@ -803,137 +798,81 @@ class AgentClient extends BaseClient { ); /** - * - * @param {Agent} agent * @param {BaseMessage[]} messages - * @param {number} [i] - * @param {TMessageContentParts[]} [contentData] - * @param {Record} [currentIndexCountMap] */ - const runAgent = async (agent, _messages, i = 0, contentData = [], _currentIndexCountMap) => { - config.configurable.model = agent.model_parameters.model; - const currentIndexCountMap = _currentIndexCountMap ?? indexTokenCountMap; - if (i > 0) { - this.model = agent.model_parameters.model; + const runAgents = async (messages) => { + const agents = [this.options.agent]; + if ( + this.agentConfigs && + this.agentConfigs.size > 0 && + ((this.options.agent.edges?.length ?? 0) > 0 || + (await checkCapability(this.options.req, AgentCapabilities.chain))) + ) { + agents.push(...this.agentConfigs.values()); } - if (i > 0 && config.signal == null) { - config.signal = abortController.signal; - } - if (agent.recursion_limit && typeof agent.recursion_limit === 'number') { - config.recursionLimit = agent.recursion_limit; + + if (agents[0].recursion_limit && typeof agents[0].recursion_limit === 'number') { + config.recursionLimit = agents[0].recursion_limit; } + if ( agentsEConfig?.maxRecursionLimit && config.recursionLimit > agentsEConfig?.maxRecursionLimit ) { config.recursionLimit = agentsEConfig?.maxRecursionLimit; } - config.configurable.agent_id = agent.id; - config.configurable.name = agent.name; - config.configurable.agent_index = i; - const noSystemMessages = noSystemModelRegex.some((regex) => - agent.model_parameters.model.match(regex), - ); - const systemMessage = Object.values(agent.toolContextMap ?? {}) - .join('\n') - .trim(); + // TODO: needs to be added as part of AgentContext initialization + // const noSystemModelRegex = [/\b(o1-preview|o1-mini|amazon\.titan-text)\b/gi]; + // const noSystemMessages = noSystemModelRegex.some((regex) => + // agent.model_parameters.model.match(regex), + // ); + // if (noSystemMessages === true && systemContent?.length) { + // const latestMessageContent = _messages.pop().content; + // if (typeof latestMessageContent !== 'string') { + // latestMessageContent[0].text = [systemContent, latestMessageContent[0].text].join('\n'); + // _messages.push(new HumanMessage({ content: latestMessageContent })); + // } else { + // const text = [systemContent, latestMessageContent].join('\n'); + // _messages.push(new HumanMessage(text)); + // } + // } + // let messages = _messages; + // if (agent.useLegacyContent === true) { + // messages = formatContentStrings(messages); + // } + // if ( + // agent.model_parameters?.clientOptions?.defaultHeaders?.['anthropic-beta']?.includes( + // 'prompt-caching', + // ) + // ) { + // messages = addCacheControl(messages); + // } - let systemContent = [ - systemMessage, - agent.instructions ?? '', - i !== 0 ? (agent.additional_instructions ?? '') : '', - ] - .join('\n') - .trim(); - - if (noSystemMessages === true) { - agent.instructions = undefined; - agent.additional_instructions = undefined; - } else { - agent.instructions = systemContent; - agent.additional_instructions = undefined; - } - - if (noSystemMessages === true && systemContent?.length) { - const latestMessageContent = _messages.pop().content; - if (typeof latestMessageContent !== 'string') { - latestMessageContent[0].text = [systemContent, latestMessageContent[0].text].join('\n'); - _messages.push(new HumanMessage({ content: latestMessageContent })); - } else { - const text = [systemContent, latestMessageContent].join('\n'); - _messages.push(new HumanMessage(text)); - } - } - - let messages = _messages; - if (agent.useLegacyContent === true) { - messages = formatContentStrings(messages); - } - const defaultHeaders = - agent.model_parameters?.clientOptions?.defaultHeaders ?? - agent.model_parameters?.configuration?.defaultHeaders; - if (defaultHeaders?.['anthropic-beta']?.includes('prompt-caching')) { - messages = addCacheControl(messages); - } - - if (i === 0) { - memoryPromise = this.runMemory(messages); - } - - /** Resolve request-based headers for Custom Endpoints. Note: if this is added to - * non-custom endpoints, needs consideration of varying provider header configs. - */ - if (agent.model_parameters?.configuration?.defaultHeaders != null) { - agent.model_parameters.configuration.defaultHeaders = resolveHeaders({ - headers: agent.model_parameters.configuration.defaultHeaders, - body: config.configurable.requestBody, - }); - } + memoryPromise = this.runMemory(messages); run = await createRun({ - agent, - req: this.options.req, + agents, + indexTokenCountMap, runId: this.responseMessageId, signal: abortController.signal, customHandlers: this.options.eventHandlers, + requestBody: config.configurable.requestBody, + tokenCounter: createTokenCounter(this.getEncoding()), }); if (!run) { throw new Error('Failed to create run'); } - if (i === 0) { - this.run = run; - } - - if (contentData.length) { - const agentUpdate = { - type: ContentTypes.AGENT_UPDATE, - [ContentTypes.AGENT_UPDATE]: { - index: contentData.length, - runId: this.responseMessageId, - agentId: agent.id, - }, - }; - const streamData = { - event: GraphEvents.ON_AGENT_UPDATE, - data: agentUpdate, - }; - this.options.aggregateContent(streamData); - sendEvent(this.options.res, streamData); - contentData.push(agentUpdate); - run.Graph.contentData = contentData; - } - + this.run = run; if (userMCPAuthMap != null) { config.configurable.userMCPAuthMap = userMCPAuthMap; } + + /** @deprecated Agent Chain */ + config.configurable.last_agent_id = agents[agents.length - 1].id; await run.processStream({ messages }, config, { - keepContent: i !== 0, - tokenCounter: createTokenCounter(this.getEncoding()), - indexTokenCountMap: currentIndexCountMap, - maxContextTokens: agent.maxContextTokens, callbacks: { [Callback.TOOL_ERROR]: logToolError, }, @@ -942,109 +881,22 @@ class AgentClient extends BaseClient { config.signal = null; }; - await runAgent(this.options.agent, initialMessages); - let finalContentStart = 0; - if ( - this.agentConfigs && - this.agentConfigs.size > 0 && - (await checkCapability(this.options.req, AgentCapabilities.chain)) - ) { - const windowSize = 5; - let latestMessage = initialMessages.pop().content; - if (typeof latestMessage !== 'string') { - latestMessage = latestMessage[0].text; - } - let i = 1; - let runMessages = []; - - const windowIndexCountMap = {}; - const windowMessages = initialMessages.slice(-windowSize); - let currentIndex = 4; - for (let i = initialMessages.length - 1; i >= 0; i--) { - windowIndexCountMap[currentIndex] = indexTokenCountMap[i]; - currentIndex--; - if (currentIndex < 0) { - break; - } - } - const encoding = this.getEncoding(); - const tokenCounter = createTokenCounter(encoding); - for (const [agentId, agent] of this.agentConfigs) { - if (abortController.signal.aborted === true) { - break; - } - const currentRun = await run; - - if ( - i === this.agentConfigs.size && - config.configurable.hide_sequential_outputs === true - ) { - const content = this.contentParts.filter( - (part) => part.type === ContentTypes.TOOL_CALL, - ); - - this.options.res.write( - `event: message\ndata: ${JSON.stringify({ - event: 'on_content_update', - data: { - runId: this.responseMessageId, - content, - }, - })}\n\n`, - ); - } - const _runMessages = currentRun.Graph.getRunMessages(); - finalContentStart = this.contentParts.length; - runMessages = runMessages.concat(_runMessages); - const contentData = currentRun.Graph.contentData.slice(); - const bufferString = getBufferString([new HumanMessage(latestMessage), ...runMessages]); - if (i === this.agentConfigs.size) { - logger.debug(`SEQUENTIAL AGENTS: Last buffer string:\n${bufferString}`); - } - try { - const contextMessages = []; - const runIndexCountMap = {}; - for (let i = 0; i < windowMessages.length; i++) { - const message = windowMessages[i]; - const messageType = message._getType(); - if ( - (!agent.tools || agent.tools.length === 0) && - (messageType === 'tool' || (message.tool_calls?.length ?? 0) > 0) - ) { - continue; - } - runIndexCountMap[contextMessages.length] = windowIndexCountMap[i]; - contextMessages.push(message); - } - const bufferMessage = new HumanMessage(bufferString); - runIndexCountMap[contextMessages.length] = tokenCounter(bufferMessage); - const currentMessages = [...contextMessages, bufferMessage]; - await runAgent(agent, currentMessages, i, contentData, runIndexCountMap); - } catch (err) { - logger.error( - `[api/server/controllers/agents/client.js #chatCompletion] Error running agent ${agentId} (${i})`, - err, - ); - } - i++; - } + await runAgents(initialMessages); + /** @deprecated Agent Chain */ + if (config.configurable.hide_sequential_outputs) { + this.contentParts = this.contentParts.filter((part, index) => { + // Include parts that are either: + // 1. At or after the finalContentStart index + // 2. Of type tool_call + // 3. Have tool_call_ids property + return ( + index >= this.contentParts.length - 1 || + part.type === ContentTypes.TOOL_CALL || + part.tool_call_ids + ); + }); } - /** Note: not implemented */ - if (config.configurable.hide_sequential_outputs !== true) { - finalContentStart = 0; - } - - this.contentParts = this.contentParts.filter((part, index) => { - // Include parts that are either: - // 1. At or after the finalContentStart index - // 2. Of type tool_call - // 3. Have tool_call_ids property - return ( - index >= finalContentStart || part.type === ContentTypes.TOOL_CALL || part.tool_call_ids - ); - }); - try { const attachments = await this.awaitMemoryWithTimeout(memoryPromise); if (attachments && attachments.length > 0) { diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 7cc0a39fba..3064a03662 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -1,6 +1,10 @@ const { logger } = require('@librechat/data-schemas'); const { createContentAggregator } = require('@librechat/agents'); -const { validateAgentModel, getCustomEndpointConfig } = require('@librechat/api'); +const { + validateAgentModel, + getCustomEndpointConfig, + createSequentialChainEdges, +} = require('@librechat/api'); const { Constants, EModelEndpoint, @@ -119,44 +123,90 @@ const initializeClient = async ({ req, res, signal, endpointOption }) => { const agent_ids = primaryConfig.agent_ids; let userMCPAuthMap = primaryConfig.userMCPAuthMap; - if (agent_ids?.length) { - for (const agentId of agent_ids) { - const agent = await getAgent({ id: agentId }); - if (!agent) { - throw new Error(`Agent ${agentId} not found`); + + async function processAgent(agentId) { + const agent = await getAgent({ id: agentId }); + if (!agent) { + throw new Error(`Agent ${agentId} not found`); + } + + const validationResult = await validateAgentModel({ + req, + res, + agent, + modelsConfig, + logViolation, + }); + + if (!validationResult.isValid) { + throw new Error(validationResult.error?.message); + } + + const config = await initializeAgent({ + req, + res, + agent, + loadTools, + requestFiles, + conversationId, + endpointOption, + allowedProviders, + }); + if (userMCPAuthMap != null) { + Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {}); + } else { + userMCPAuthMap = config.userMCPAuthMap; + } + agentConfigs.set(agentId, config); + } + + let edges = primaryConfig.edges; + const checkAgentInit = (agentId) => agentId === primaryConfig.id || agentConfigs.has(agentId); + if ((edges?.length ?? 0) > 0) { + for (const edge of edges) { + if (Array.isArray(edge.to)) { + for (const to of edge.to) { + if (checkAgentInit(to)) { + continue; + } + await processAgent(to); + } + } else if (typeof edge.to === 'string' && checkAgentInit(edge.to)) { + continue; + } else if (typeof edge.to === 'string') { + await processAgent(edge.to); } - const validationResult = await validateAgentModel({ - req, - res, - agent, - modelsConfig, - logViolation, - }); - - if (!validationResult.isValid) { - throw new Error(validationResult.error?.message); + if (Array.isArray(edge.from)) { + for (const from of edge.from) { + if (checkAgentInit(from)) { + continue; + } + await processAgent(from); + } + } else if (typeof edge.from === 'string' && checkAgentInit(edge.from)) { + continue; + } else if (typeof edge.from === 'string') { + await processAgent(edge.from); } - - const config = await initializeAgent({ - req, - res, - agent, - loadTools, - requestFiles, - conversationId, - endpointOption, - allowedProviders, - }); - if (userMCPAuthMap != null) { - Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {}); - } else { - userMCPAuthMap = config.userMCPAuthMap; - } - agentConfigs.set(agentId, config); } } + /** @deprecated Agent Chain */ + if (agent_ids?.length) { + for (const agentId of agent_ids) { + if (checkAgentInit(agentId)) { + continue; + } + await processAgent(agentId); + } + + const chain = await createSequentialChainEdges([primaryConfig.id].concat(agent_ids), '{convo}'); + edges = edges ? edges.concat(chain) : chain; + } + + primaryConfig.edges = edges; + let endpointConfig = appConfig.endpoints?.[primaryConfig.endpoint]; if (!isAgentsEndpoint(primaryConfig.endpoint) && !endpointConfig) { try { diff --git a/api/server/services/Endpoints/anthropic/initialize.js b/api/server/services/Endpoints/anthropic/initialize.js index 6e661da671..88639b3d7c 100644 --- a/api/server/services/Endpoints/anthropic/initialize.js +++ b/api/server/services/Endpoints/anthropic/initialize.js @@ -27,13 +27,13 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio const anthropicConfig = appConfig.endpoints?.[EModelEndpoint.anthropic]; if (anthropicConfig) { - clientOptions.streamRate = anthropicConfig.streamRate; + clientOptions._lc_stream_delay = anthropicConfig.streamRate; clientOptions.titleModel = anthropicConfig.titleModel; } const allConfig = appConfig.endpoints?.all; if (allConfig) { - clientOptions.streamRate = allConfig.streamRate; + clientOptions._lc_stream_delay = allConfig.streamRate; } if (optionsOnly) { diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js index 2bc18f9a76..0d02d09b07 100644 --- a/api/server/services/Endpoints/bedrock/options.js +++ b/api/server/services/Endpoints/bedrock/options.js @@ -1,8 +1,6 @@ const { HttpsProxyAgent } = require('https-proxy-agent'); -const { createHandleLLMNewToken } = require('@librechat/api'); const { AuthType, - Constants, EModelEndpoint, bedrockInputParser, bedrockOutputParser, @@ -11,7 +9,6 @@ const { const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); const getOptions = async ({ req, overrideModel, endpointOption }) => { - const appConfig = req.config; const { BEDROCK_AWS_SECRET_ACCESS_KEY, BEDROCK_AWS_ACCESS_KEY_ID, @@ -47,10 +44,12 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => { checkUserKeyExpiry(expiresAt, EModelEndpoint.bedrock); } - /** @type {number} */ + /* + Callback for stream rate no longer awaits and may end the stream prematurely + /** @type {number} let streamRate = Constants.DEFAULT_STREAM_RATE; - /** @type {undefined | TBaseEndpoint} */ + /** @type {undefined | TBaseEndpoint} const bedrockConfig = appConfig.endpoints?.[EModelEndpoint.bedrock]; if (bedrockConfig && bedrockConfig.streamRate) { @@ -61,6 +60,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => { if (allConfig && allConfig.streamRate) { streamRate = allConfig.streamRate; } + */ /** @type {BedrockClientOptions} */ const requestOptions = { @@ -88,12 +88,6 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => { llmConfig.endpointHost = BEDROCK_REVERSE_PROXY; } - llmConfig.callbacks = [ - { - handleLLMNewToken: createHandleLLMNewToken(streamRate), - }, - ]; - return { /** @type {BedrockClientOptions} */ llmConfig, diff --git a/api/server/services/Endpoints/custom/initialize.js b/api/server/services/Endpoints/custom/initialize.js index 066c9430ce..e6fbf65e77 100644 --- a/api/server/services/Endpoints/custom/initialize.js +++ b/api/server/services/Endpoints/custom/initialize.js @@ -3,7 +3,6 @@ const { isUserProvided, getOpenAIConfig, getCustomEndpointConfig, - createHandleLLMNewToken, } = require('@librechat/api'); const { CacheKeys, @@ -157,11 +156,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid if (!clientOptions.streamRate) { return options; } - options.llmConfig.callbacks = [ - { - handleLLMNewToken: createHandleLLMNewToken(clientOptions.streamRate), - }, - ]; + options.llmConfig._lc_stream_delay = clientOptions.streamRate; return options; } diff --git a/api/server/services/Endpoints/custom/initialize.spec.js b/api/server/services/Endpoints/custom/initialize.spec.js index 8b4a1303ee..a69ff9ef58 100644 --- a/api/server/services/Endpoints/custom/initialize.spec.js +++ b/api/server/services/Endpoints/custom/initialize.spec.js @@ -4,7 +4,6 @@ jest.mock('@librechat/api', () => ({ ...jest.requireActual('@librechat/api'), resolveHeaders: jest.fn(), getOpenAIConfig: jest.fn(), - createHandleLLMNewToken: jest.fn(), getCustomEndpointConfig: jest.fn().mockReturnValue({ apiKey: 'test-key', baseURL: 'https://test.com', diff --git a/api/server/services/Endpoints/openAI/initialize.js b/api/server/services/Endpoints/openAI/initialize.js index ab2e80640a..cd691c6240 100644 --- a/api/server/services/Endpoints/openAI/initialize.js +++ b/api/server/services/Endpoints/openAI/initialize.js @@ -5,7 +5,6 @@ const { isUserProvided, getOpenAIConfig, getAzureCredentials, - createHandleLLMNewToken, } = require('@librechat/api'); const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService'); const OpenAIClient = require('~/app/clients/OpenAIClient'); @@ -151,11 +150,7 @@ const initializeClient = async ({ if (!streamRate) { return options; } - options.llmConfig.callbacks = [ - { - handleLLMNewToken: createHandleLLMNewToken(streamRate), - }, - ]; + options.llmConfig._lc_stream_delay = streamRate; return options; } diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts index a49586b8a0..43448a478f 100644 --- a/client/src/common/agents-types.ts +++ b/client/src/common/agents-types.ts @@ -1,9 +1,10 @@ import { AgentCapabilities, ArtifactModes } from 'librechat-data-provider'; import type { - Agent, - AgentProvider, AgentModelParameters, SupportContact, + AgentProvider, + GraphEdge, + Agent, } from 'librechat-data-provider'; import type { OptionWithIcon, ExtendedFile } from './types'; @@ -33,7 +34,9 @@ export type AgentForm = { model_parameters: AgentModelParameters; tools?: string[]; provider?: AgentProvider | OptionWithIcon; + /** @deprecated Use edges instead */ agent_ids?: string[]; + edges?: GraphEdge[]; [AgentCapabilities.artifacts]?: ArtifactModes | string; recursion_limit?: number; support_contact?: SupportContact; diff --git a/client/src/components/Chat/Messages/Content/AgentHandoff.tsx b/client/src/components/Chat/Messages/Content/AgentHandoff.tsx new file mode 100644 index 0000000000..989cf4d3c4 --- /dev/null +++ b/client/src/components/Chat/Messages/Content/AgentHandoff.tsx @@ -0,0 +1,92 @@ +import React, { useMemo, useState } from 'react'; +import { EModelEndpoint, Constants } from 'librechat-data-provider'; +import { ChevronDown } from 'lucide-react'; +import type { TMessage } from 'librechat-data-provider'; +import MessageIcon from '~/components/Share/MessageIcon'; +import { useAgentsMapContext } from '~/Providers'; +import { useLocalize } from '~/hooks'; +import { cn } from '~/utils'; + +interface AgentHandoffProps { + name: string; + args: string | Record; + output?: string | null; +} + +const AgentHandoff: React.FC = ({ name, args: _args = '' }) => { + const localize = useLocalize(); + const agentsMap = useAgentsMapContext(); + const [showInfo, setShowInfo] = useState(false); + + /** Extracted agent ID from tool name (e.g., "lc_transfer_to_agent_gUV0wMb7zHt3y3Xjz-8_4" -> "agent_gUV0wMb7zHt3y3Xjz-8_4") */ + const targetAgentId = useMemo(() => { + if (typeof name !== 'string' || !name.startsWith(Constants.LC_TRANSFER_TO_)) { + return null; + } + return name.replace(Constants.LC_TRANSFER_TO_, ''); + }, [name]); + + const targetAgent = useMemo(() => { + if (!targetAgentId || !agentsMap) { + return null; + } + return agentsMap[targetAgentId]; + }, [agentsMap, targetAgentId]); + + const args = useMemo(() => { + if (typeof _args === 'string') { + return _args; + } + try { + return JSON.stringify(_args, null, 2); + } catch { + return ''; + } + }, [_args]) as string; + + /** Requires more than 2 characters as can be an empty object: `{}` */ + const hasInfo = useMemo(() => (args?.trim()?.length ?? 0) > 2, [args]); + + return ( +
+
hasInfo && setShowInfo(!showInfo)} + > +
+ +
+ {localize('com_ui_transferred_to')} + + {targetAgent?.name || localize('com_ui_agent')} + + {hasInfo && ( + + )} +
+ {hasInfo && showInfo && ( +
+
+ {localize('com_ui_handoff_instructions')}: +
+
{args}
+
+ )} +
+ ); +}; + +export default AgentHandoff; diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index b8d70f33e4..b37010447d 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -1,5 +1,6 @@ import { Tools, + Constants, ContentTypes, ToolCallTypes, imageGenTools, @@ -10,6 +11,7 @@ import type { TMessageContentParts, TAttachment } from 'librechat-data-provider' import { OpenAIImageGen, EmptyText, Reasoning, ExecuteCode, AgentUpdate, Text } from './Parts'; import { ErrorMessage } from './MessageContent'; import RetrievalCall from './RetrievalCall'; +import AgentHandoff from './AgentHandoff'; import CodeAnalyze from './CodeAnalyze'; import Container from './Container'; import WebSearch from './WebSearch'; @@ -123,6 +125,14 @@ const Part = memo( isLast={isLast} /> ); + } else if (isToolCall && toolCall.name?.startsWith(Constants.LC_TRANSFER_TO_)) { + return ( + + ); } else if (isToolCall) { return ( = ({ currentAgentId }) => { const localize = useLocalize(); - const agentsMap = useAgentsMapContext() || {}; - const currentAgent = useMemo(() => agentsMap[currentAgentId], [agentsMap, currentAgentId]); + const agentsMap = useAgentsMapContext(); + const currentAgent = useMemo(() => agentsMap?.[currentAgentId], [agentsMap, currentAgentId]); if (!currentAgentId) { return null; } diff --git a/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx b/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx index f99bce6f3b..6bc4cf5a0d 100644 --- a/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx +++ b/client/src/components/SidePanel/Agents/Advanced/AdvancedPanel.tsx @@ -5,6 +5,7 @@ import { useFormContext, Controller } from 'react-hook-form'; import type { AgentForm } from '~/common'; import { useAgentPanelContext } from '~/Providers'; import MaxAgentSteps from './MaxAgentSteps'; +import AgentHandoffs from './AgentHandoffs'; import { useLocalize } from '~/hooks'; import AgentChain from './AgentChain'; import { Panel } from '~/common'; @@ -42,6 +43,12 @@ export default function AdvancedPanel() {
+ } + /> {chainEnabled && ( ; + currentAgentId: string; +} + +/** TODO: make configurable */ +const MAX_HANDOFFS = 10; + +const AgentHandoffs: React.FC = ({ field, currentAgentId }) => { + const localize = useLocalize(); + const [newAgentId, setNewAgentId] = useState(''); + const [expandedIndices, setExpandedIndices] = useState>(new Set()); + const agentsMap = useAgentsMapContext(); + const edgesValue = field.value; + const edges = useMemo(() => edgesValue || [], [edgesValue]); + + const agents = useMemo(() => (agentsMap ? Object.values(agentsMap) : []), [agentsMap]); + + const selectableAgents = useMemo( + () => + agents + .filter((agent) => agent?.id !== currentAgentId) + .map( + (agent) => + ({ + label: agent?.name || '', + value: agent?.id || '', + icon: ( + + ), + }) as OptionWithIcon, + ), + [agents, currentAgentId], + ); + + const getAgentDetails = useCallback((id: string) => agentsMap?.[id], [agentsMap]); + + useEffect(() => { + if (newAgentId && edges.length < MAX_HANDOFFS) { + const newEdge: GraphEdge = { + from: currentAgentId, + to: newAgentId, + edgeType: 'handoff', + }; + field.onChange([...edges, newEdge]); + setNewAgentId(''); + } + }, [newAgentId, edges, field, currentAgentId]); + + const removeHandoffAt = (index: number) => { + field.onChange(edges.filter((_, i) => i !== index)); + // Also remove from expanded set + setExpandedIndices((prev) => { + const newSet = new Set(prev); + newSet.delete(index); + return newSet; + }); + }; + + const updateHandoffAt = (index: number, agentId: string) => { + const updated = [...edges]; + updated[index] = { ...updated[index], to: agentId }; + field.onChange(updated); + }; + + const updateHandoffDetailsAt = (index: number, updates: Partial) => { + const updated = [...edges]; + updated[index] = { ...updated[index], ...updates }; + field.onChange(updated); + }; + + const toggleExpanded = (index: number) => { + setExpandedIndices((prev) => { + const newSet = new Set(prev); + if (newSet.has(index)) { + newSet.delete(index); + } else { + newSet.add(index); + } + return newSet; + }); + }; + + const getTargetAgentId = (to: string | string[]): string => { + return Array.isArray(to) ? to[0] : to; + }; + + return ( + +
+
+ + + + +
+
+
+ {localize('com_ui_beta')} +
+
+ {edges.length} / {MAX_HANDOFFS} +
+
+
+
+ {edges.map((edge, idx) => { + const targetAgentId = getTargetAgentId(edge.to); + const isExpanded = expandedIndices.has(idx); + + return ( + +
+
+ updateHandoffAt(idx, id)} + selectPlaceholder={localize('com_ui_agent_var', { + 0: localize('com_ui_select'), + })} + searchPlaceholder={localize('com_ui_agent_var', { + 0: localize('com_ui_search'), + })} + items={selectableAgents} + displayValue={getAgentDetails(targetAgentId)?.name ?? ''} + SelectIcon={ + + } + className="flex-1 border-border-heavy" + containerClassName="px-0" + /> + + +
+ + {isExpanded && ( +
+
+ + + updateHandoffDetailsAt(idx, { description: e.target.value }) + } + className="mt-1 h-8 text-sm" + /> +
+ +
+ +