Merge branch 'dev' into feat/context-window-ui

This commit is contained in:
Marco Beretta 2025-12-29 02:07:54 +01:00
commit cb8322ca85
No known key found for this signature in database
GPG key ID: D918033D8E74CC11
407 changed files with 25479 additions and 19894 deletions

View file

@ -1,6 +1,6 @@
{
"name": "@librechat/api",
"version": "1.7.0",
"version": "1.7.10",
"type": "commonjs",
"description": "MCP services for LibreChat",
"main": "dist/index.js",
@ -23,7 +23,8 @@
"test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
"test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
"test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
"test:cache-integration": "npm run test:cache-integration:core && npm run test:cache-integration:cluster && npm run test:cache-integration:mcp",
"test:cache-integration:stream": "jest --testPathPatterns=\"src/stream/.*\\.stream_integration\\.spec\\.ts$\" --coverage=false --runInBand --forceExit",
"test:cache-integration": "npm run test:cache-integration:core && npm run test:cache-integration:cluster && npm run test:cache-integration:mcp && npm run test:cache-integration:stream",
"verify": "npm run test:ci",
"b:clean": "bun run rimraf dist",
"b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs",
@ -78,15 +79,17 @@
"registry": "https://registry.npmjs.org/"
},
"peerDependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.941.0",
"@aws-sdk/client-s3": "^3.758.0",
"@azure/identity": "^4.7.0",
"@azure/search-documents": "^12.0.0",
"@azure/storage-blob": "^12.27.0",
"@keyv/redis": "^4.3.3",
"@langchain/core": "^0.3.79",
"@librechat/agents": "^3.0.50",
"@langchain/core": "^0.3.80",
"@librechat/agents": "^3.0.61",
"@librechat/data-schemas": "*",
"@modelcontextprotocol/sdk": "^1.24.3",
"@modelcontextprotocol/sdk": "^1.25.1",
"@smithy/node-http-handler": "^4.4.5",
"axios": "^1.12.1",
"connect-redis": "^8.1.0",
"diff": "^7.0.0",
@ -95,12 +98,14 @@
"express-session": "^1.18.2",
"firebase": "^11.0.2",
"form-data": "^4.0.4",
"https-proxy-agent": "^7.0.6",
"ioredis": "^5.3.2",
"js-yaml": "^4.1.1",
"jsonwebtoken": "^9.0.0",
"keyv": "^5.3.2",
"keyv-file": "^5.1.2",
"librechat-data-provider": "*",
"mathjs": "^15.1.0",
"memorystore": "^1.6.7",
"mongoose": "^8.12.1",
"node-fetch": "2.7.0",

View file

@ -17,6 +17,7 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
import type { ObjectId, MemoryMethods } from '@librechat/data-schemas';
import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
import type { Response as ServerResponse } from 'express';
import { GenerationJobManager } from '~/stream/GenerationJobManager';
import { Tokenizer } from '~/utils';
type RequiredMemoryMethods = Pick<
@ -250,6 +251,7 @@ export class BasicToolEndHandler implements EventHandler {
constructor(callback?: ToolEndCallback) {
this.callback = callback;
}
handle(
event: string,
data: StreamEventData | undefined,
@ -282,6 +284,7 @@ export async function processMemory({
llmConfig,
tokenLimit,
totalTokens = 0,
streamId = null,
}: {
res: ServerResponse;
setMemory: MemoryMethods['setMemory'];
@ -296,6 +299,7 @@ export async function processMemory({
tokenLimit?: number;
totalTokens?: number;
llmConfig?: Partial<LLMConfig>;
streamId?: string | null;
}): Promise<(TAttachment | null)[] | undefined> {
try {
const memoryTool = createMemoryTool({
@ -363,7 +367,7 @@ ${memory ?? 'No existing memories'}`;
}
const artifactPromises: Promise<TAttachment | null>[] = [];
const memoryCallback = createMemoryCallback({ res, artifactPromises });
const memoryCallback = createMemoryCallback({ res, artifactPromises, streamId });
const customHandlers = {
[GraphEvents.TOOL_END]: new BasicToolEndHandler(memoryCallback),
};
@ -416,6 +420,7 @@ export async function createMemoryProcessor({
memoryMethods,
conversationId,
config = {},
streamId = null,
}: {
res: ServerResponse;
messageId: string;
@ -423,6 +428,7 @@ export async function createMemoryProcessor({
userId: string | ObjectId;
memoryMethods: RequiredMemoryMethods;
config?: MemoryConfig;
streamId?: string | null;
}): Promise<[string, (messages: BaseMessage[]) => Promise<(TAttachment | null)[] | undefined>]> {
const { validKeys, instructions, llmConfig, tokenLimit } = config;
const finalInstructions = instructions || getDefaultInstructions(validKeys, tokenLimit);
@ -443,6 +449,7 @@ export async function createMemoryProcessor({
llmConfig,
messageId,
tokenLimit,
streamId,
conversationId,
memory: withKeys,
totalTokens: totalTokens || 0,
@ -461,10 +468,12 @@ async function handleMemoryArtifact({
res,
data,
metadata,
streamId = null,
}: {
res: ServerResponse;
data: ToolEndData;
metadata?: ToolEndMetadata;
streamId?: string | null;
}) {
const output = data?.output as ToolMessage | undefined;
if (!output) {
@ -490,7 +499,11 @@ async function handleMemoryArtifact({
if (!res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
if (streamId) {
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
} else {
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
}
return attachment;
}
@ -499,14 +512,17 @@ async function handleMemoryArtifact({
* @param params - The parameters object
* @param params.res - The server response object
* @param params.artifactPromises - Array to collect artifact promises
* @param params.streamId - The stream ID for resumable mode, or null for standard mode
* @returns The memory callback function
*/
export function createMemoryCallback({
res,
artifactPromises,
streamId = null,
}: {
res: ServerResponse;
artifactPromises: Promise<Partial<TAttachment> | null>[];
streamId?: string | null;
}): ToolEndCallback {
return async (data: ToolEndData, metadata?: Record<string, unknown>) => {
const output = data?.output as ToolMessage | undefined;
@ -515,7 +531,7 @@ export function createMemoryCallback({
return;
}
artifactPromises.push(
handleMemoryArtifact({ res, data, metadata }).catch((error) => {
handleMemoryArtifact({ res, data, metadata, streamId }).catch((error) => {
logger.error('Error processing memory artifact content:', error);
return null;
}),

View file

@ -1,5 +1,10 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { isEmailDomainAllowed, isActionDomainAllowed } from './domain';
import {
isEmailDomainAllowed,
isActionDomainAllowed,
extractMCPServerDomain,
isMCPDomainAllowed,
} from './domain';
describe('isEmailDomainAllowed', () => {
afterEach(() => {
@ -213,3 +218,209 @@ describe('isActionDomainAllowed', () => {
});
});
});
describe('extractMCPServerDomain', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('URL extraction', () => {
it('should extract domain from HTTPS URL', () => {
const config = { url: 'https://api.example.com/sse' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
});
it('should extract domain from HTTP URL', () => {
const config = { url: 'http://api.example.com/sse' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
});
it('should extract domain from WebSocket URL', () => {
const config = { url: 'wss://ws.example.com' };
expect(extractMCPServerDomain(config)).toBe('ws.example.com');
});
it('should handle URL with port', () => {
const config = { url: 'https://localhost:3001/sse' };
expect(extractMCPServerDomain(config)).toBe('localhost');
});
it('should strip www prefix', () => {
const config = { url: 'https://www.example.com/api' };
expect(extractMCPServerDomain(config)).toBe('example.com');
});
it('should handle URL with path and query parameters', () => {
const config = { url: 'https://api.example.com/v1/sse?token=abc' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
});
});
describe('stdio transports (no URL)', () => {
it('should return null for stdio transport with command only', () => {
const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null when url is undefined', () => {
const config = { command: 'node', args: ['server.js'] };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for empty object', () => {
const config = {};
expect(extractMCPServerDomain(config)).toBeNull();
});
});
describe('invalid URLs', () => {
it('should return null for invalid URL format', () => {
const config = { url: 'not-a-valid-url' };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for empty URL string', () => {
const config = { url: '' };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for non-string url', () => {
const config = { url: 12345 };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for null url', () => {
const config = { url: null };
expect(extractMCPServerDomain(config)).toBeNull();
});
});
});
describe('isMCPDomainAllowed', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('stdio transports (always allowed)', () => {
it('should allow stdio transport regardless of allowlist', async () => {
const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
it('should allow stdio transport even with empty allowlist', async () => {
const config = { command: 'node', args: ['server.js'] };
expect(await isMCPDomainAllowed(config, [])).toBe(true);
});
it('should allow stdio transport when no URL present', async () => {
const config = {};
expect(await isMCPDomainAllowed(config, ['restricted.com'])).toBe(true);
});
});
describe('permissive defaults (no restrictions)', () => {
it('should allow all domains when allowedDomains is null', async () => {
const config = { url: 'https://any-domain.com/sse' };
expect(await isMCPDomainAllowed(config, null)).toBe(true);
});
it('should allow all domains when allowedDomains is undefined', async () => {
const config = { url: 'https://any-domain.com/sse' };
expect(await isMCPDomainAllowed(config, undefined)).toBe(true);
});
it('should allow all domains when allowedDomains is empty array', async () => {
const config = { url: 'https://any-domain.com/sse' };
expect(await isMCPDomainAllowed(config, [])).toBe(true);
});
});
describe('exact domain matching', () => {
const allowedDomains = ['example.com', 'localhost', 'trusted-mcp.com'];
it('should allow exact domain match', async () => {
const config = { url: 'https://example.com/api' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow localhost', async () => {
const config = { url: 'http://localhost:3001/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should reject non-allowed domain', async () => {
const config = { url: 'https://malicious.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
});
it('should reject subdomain when only parent is allowed', async () => {
const config = { url: 'https://api.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
});
});
describe('wildcard domain matching', () => {
const allowedDomains = ['*.example.com', 'localhost'];
it('should allow subdomain with wildcard', async () => {
const config = { url: 'https://api.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow any subdomain with wildcard', async () => {
const config = { url: 'https://staging.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow base domain with wildcard', async () => {
const config = { url: 'https://example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow nested subdomain with wildcard', async () => {
const config = { url: 'https://deep.nested.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should reject different domain even with wildcard', async () => {
const config = { url: 'https://api.other.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
});
});
describe('case insensitivity', () => {
it('should match domains case-insensitively', async () => {
const config = { url: 'https://EXAMPLE.COM/sse' };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
it('should match with uppercase in allowlist', async () => {
const config = { url: 'https://example.com/sse' };
expect(await isMCPDomainAllowed(config, ['EXAMPLE.COM'])).toBe(true);
});
it('should match with mixed case', async () => {
const config = { url: 'https://Api.Example.Com/sse' };
expect(await isMCPDomainAllowed(config, ['*.example.com'])).toBe(true);
});
});
describe('www prefix handling', () => {
it('should strip www prefix from URL before matching', async () => {
const config = { url: 'https://www.example.com/sse' };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
it('should match www in allowlist to non-www URL', async () => {
const config = { url: 'https://example.com/sse' };
expect(await isMCPDomainAllowed(config, ['www.example.com'])).toBe(true);
});
});
describe('invalid URL handling', () => {
it('should allow config with invalid URL (treated as stdio)', async () => {
const config = { url: 'not-a-valid-url' };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
});
});

View file

@ -96,3 +96,45 @@ export async function isActionDomainAllowed(
return false;
}
/**
* Extracts domain from MCP server config URL.
* Returns null for stdio transports (no URL) or invalid URLs.
* @param config - MCP server configuration (accepts any config with optional url field)
*/
export function extractMCPServerDomain(config: Record<string, unknown>): string | null {
const url = config.url;
// Stdio transports don't have URLs - always allowed
if (!url || typeof url !== 'string') {
return null;
}
try {
const parsedUrl = new URL(url);
return parsedUrl.hostname.replace(/^www\./i, '');
} catch {
return null;
}
}
/**
* Validates MCP server domain against allowedDomains.
* Reuses isActionDomainAllowed for consistent validation logic.
* Stdio transports (no URL) are always allowed.
* @param config - MCP server configuration with optional url field
* @param allowedDomains - List of allowed domains (with wildcard support)
*/
export async function isMCPDomainAllowed(
config: Record<string, unknown>,
allowedDomains?: string[] | null,
): Promise<boolean> {
const domain = extractMCPServerDomain(config);
// Stdio transports don't have domains - always allowed
if (!domain) {
return true;
}
// Reuse existing validation logic (includes wildcard support)
return isActionDomainAllowed(domain, allowedDomains);
}

View file

@ -10,6 +10,7 @@ describe('cacheConfig', () => {
delete process.env.REDIS_KEY_PREFIX_VAR;
delete process.env.REDIS_KEY_PREFIX;
delete process.env.USE_REDIS;
delete process.env.USE_REDIS_STREAMS;
delete process.env.USE_REDIS_CLUSTER;
delete process.env.REDIS_PING_INTERVAL;
delete process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES;
@ -130,6 +131,53 @@ describe('cacheConfig', () => {
});
});
describe('USE_REDIS_STREAMS configuration', () => {
test('should default to USE_REDIS value when USE_REDIS_STREAMS is not set', async () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(true);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(true);
});
test('should default to false when both USE_REDIS and USE_REDIS_STREAMS are not set', async () => {
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(false);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(false);
});
test('should be false when explicitly set to false even if USE_REDIS is true', async () => {
process.env.USE_REDIS = 'true';
process.env.USE_REDIS_STREAMS = 'false';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(true);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(false);
});
test('should be true when explicitly set to true', async () => {
process.env.USE_REDIS = 'true';
process.env.USE_REDIS_STREAMS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS_STREAMS).toBe(true);
});
test('should allow streams without general Redis (explicit override)', async () => {
// Edge case: someone might want streams with Redis but not general caching
// This would require REDIS_URI but not USE_REDIS
process.env.USE_REDIS_STREAMS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(false);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(true);
});
});
describe('REDIS_CA file reading', () => {
test('should be null when REDIS_CA is not set', async () => {
const { cacheConfig } = await import('../cacheConfig');

View file

@ -17,6 +17,14 @@ if (USE_REDIS && !process.env.REDIS_URI) {
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
}
// USE_REDIS_STREAMS controls whether Redis is used for resumable stream job storage.
// Defaults to true if USE_REDIS is enabled but USE_REDIS_STREAMS is not explicitly set.
// Set to 'false' to use in-memory storage for streams while keeping Redis for other caches.
const USE_REDIS_STREAMS =
process.env.USE_REDIS_STREAMS !== undefined
? isEnabled(process.env.USE_REDIS_STREAMS)
: USE_REDIS;
// Comma-separated list of cache namespaces that should be forced to use in-memory storage
// even when Redis is enabled. This allows selective performance optimization for specific caches.
const FORCED_IN_MEMORY_CACHE_NAMESPACES = process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES
@ -60,6 +68,7 @@ const getRedisCA = (): string | null => {
const cacheConfig = {
FORCED_IN_MEMORY_CACHE_NAMESPACES,
USE_REDIS,
USE_REDIS_STREAMS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
@ -112,6 +121,13 @@ const cacheConfig = {
* @default 1000
*/
REDIS_SCAN_COUNT: math(process.env.REDIS_SCAN_COUNT, 1000),
/**
* TTL in milliseconds for MCP registry read-through cache.
* This cache reduces redundant lookups within a single request flow.
* @default 5000 (5 seconds)
*/
MCP_REGISTRY_CACHE_TTL: math(process.env.MCP_REGISTRY_CACHE_TTL, 5000),
};
export { cacheConfig };

View file

@ -130,7 +130,7 @@ export async function fetchModels({
const options: {
headers: Record<string, string>;
timeout: number;
httpsAgent?: HttpsProxyAgent;
httpsAgent?: HttpsProxyAgent<string>;
} = {
headers: {
...(headers ?? {}),

View file

@ -79,6 +79,21 @@ export async function encodeAndFormatAudios(
mimeType: file.type,
data: content,
});
} else if (provider === Providers.OPENROUTER) {
// Extract format from filename extension (e.g., 'audio.mp3' -> 'mp3')
// OpenRouter expects format values like: wav, mp3, aiff, aac, ogg, flac, m4a, pcm16, pcm24
// Note: MIME types don't always match (e.g., 'audio/mpeg' is mp3, not mpeg), so that is why we are using the file extension instead
const format = file.filename.split('.').pop()?.toLowerCase();
if (!format) {
throw new Error(`Could not extract audio format from filename: ${file.filename}`);
}
result.audios.push({
type: 'input_audio',
input_audio: {
data: content,
format,
},
});
}
result.files.push(metadata);

View file

@ -79,6 +79,13 @@ export async function encodeAndFormatVideos(
mimeType: file.type,
data: content,
});
} else if (provider === Providers.OPENROUTER) {
result.videos.push({
type: 'video_url',
video_url: {
url: `data:${file.type};base64,${content}`,
},
});
}
result.files.push(metadata);

View file

@ -129,22 +129,61 @@ export class FlowStateManager<T = unknown> {
return new Promise<T>((resolve, reject) => {
const checkInterval = 2000;
let elapsedTime = 0;
let isCleanedUp = false;
let intervalId: NodeJS.Timeout | null = null;
// Cleanup function to avoid duplicate cleanup
const cleanup = () => {
if (isCleanedUp) return;
isCleanedUp = true;
if (intervalId) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
}
if (signal && abortHandler) {
signal.removeEventListener('abort', abortHandler);
}
};
// Immediate abort handler - responds instantly to abort signal
const abortHandler = async () => {
cleanup();
logger.warn(`[${flowKey}] Flow aborted (immediate)`);
const message = `${type} flow aborted`;
try {
await this.keyv.delete(flowKey);
} catch {
// Ignore delete errors during abort
}
reject(new Error(message));
};
// Register abort handler immediately if signal provided
if (signal) {
if (signal.aborted) {
// Already aborted, reject immediately
cleanup();
reject(new Error(`${type} flow aborted`));
return;
}
signal.addEventListener('abort', abortHandler, { once: true });
}
intervalId = setInterval(async () => {
if (isCleanedUp) return;
const intervalId = setInterval(async () => {
try {
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (!flowState) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.error(`[${flowKey}] Flow state not found`);
reject(new Error(`${type} Flow state not found`));
return;
}
if (signal?.aborted) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.warn(`[${flowKey}] Flow aborted`);
const message = `${type} flow aborted`;
await this.keyv.delete(flowKey);
@ -153,8 +192,7 @@ export class FlowStateManager<T = unknown> {
}
if (flowState.status !== 'PENDING') {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.debug(`[${flowKey}] Flow completed`);
if (flowState.status === 'COMPLETED' && flowState.result !== undefined) {
@ -168,8 +206,7 @@ export class FlowStateManager<T = unknown> {
elapsedTime += checkInterval;
if (elapsedTime >= this.ttl) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.error(
`[${flowKey}] Flow timed out | Elapsed time: ${elapsedTime} | TTL: ${this.ttl}`,
);
@ -179,8 +216,7 @@ export class FlowStateManager<T = unknown> {
logger.debug(`[${flowKey}] Flow state elapsed time: ${elapsedTime}, checking again...`);
} catch (error) {
logger.error(`[${flowKey}] Error checking flow state:`, error);
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
reject(error);
}
}, checkInterval);

View file

@ -9,6 +9,7 @@ export * from './mcp/connection';
export * from './mcp/oauth';
export * from './mcp/auth';
export * from './mcp/zod';
export * from './mcp/errors';
/* Utilities */
export * from './mcp/utils';
export * from './utils';
@ -38,6 +39,8 @@ export * from './tools';
export * from './web';
/* Cache */
export * from './cache';
/* Stream */
export * from './stream';
/* types */
export type * from './mcp/types';
export type * from './flow/types';

View file

@ -1,7 +1,7 @@
import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { MCPOAuthTokens, MCPOAuthFlowMetadata, OAuthMetadata } from '~/mcp/oauth';
import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types';
import type * as t from './types';
@ -173,9 +173,10 @@ export class MCPConnectionFactory {
// Create the flow state so the OAuth callback can find it
// We spawn this in the background without waiting for it
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => {
// Pass signal so the flow can be aborted if the request is cancelled
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
// The OAuth callback will resolve this flow, so we expect it to timeout here
// which is fine - we just need the flow state to exist
// or it will be aborted if the request is cancelled - both are fine
});
if (this.oauthStart) {
@ -354,56 +355,26 @@ export class MCPConnectionFactory {
/** Check if there's already an ongoing OAuth flow for this flowId */
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow && existingFlow.status === 'PENDING') {
// If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
if (existingFlow) {
logger.debug(
`${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`,
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
);
/** Tokens from existing flow to complete */
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth');
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();
}
logger.info(
`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`,
);
/** Client information from the existing flow metadata */
const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata;
const clientInfo = existingMetadata?.clientInfo;
return { tokens, clientInfo };
}
// Clean up old completed/failed flows, but only if they're actually stale
// This prevents race conditions where we delete a flow that's still being processed
if (existingFlow && existingFlow.status !== 'PENDING') {
const STALE_FLOW_THRESHOLD = 2 * 60 * 1000; // 2 minutes
const { isStale, age, status } = await this.flowManager.isFlowStale(
flowId,
'mcp_oauth',
STALE_FLOW_THRESHOLD,
);
if (isStale) {
try {
try {
if (existingFlow.status === 'PENDING') {
await this.flowManager.failFlow(
flowId,
'mcp_oauth',
new Error('Cancelled for new OAuth request'),
);
} else {
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
logger.debug(
`${this.logPrefix} Cleared stale ${status} OAuth flow (age: ${Math.round(age / 1000)}s)`,
);
} catch (error) {
logger.warn(`${this.logPrefix} Failed to clear stale OAuth flow`, error);
}
} else {
logger.debug(
`${this.logPrefix} Skipping cleanup of recent ${status} flow (age: ${Math.round(age / 1000)}s, threshold: ${STALE_FLOW_THRESHOLD / 1000}s)`,
);
// If flow is recent but not pending, something might be wrong
if (status === 'FAILED') {
logger.warn(
`${this.logPrefix} Recent OAuth flow failed, will retry after ${Math.round((STALE_FLOW_THRESHOLD - age) / 1000)}s`,
);
}
} catch (error) {
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
}
// Continue to start a new flow below
}
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);

View file

@ -188,3 +188,344 @@ describe('MCPConnection Error Detection', () => {
});
});
});
/**
* Tests for extractSSEErrorMessage function.
* This function extracts meaningful error messages from SSE transport errors,
* particularly handling the "SSE error: undefined" case from the MCP SDK.
*/
describe('extractSSEErrorMessage', () => {
/**
* Standalone implementation of extractSSEErrorMessage for testing.
* This mirrors the function in connection.ts.
* Keep in sync with the actual implementation.
*/
function extractSSEErrorMessage(error: unknown): {
message: string;
code?: number;
isProxyHint: boolean;
isTransient: boolean;
} {
if (!error || typeof error !== 'object') {
return {
message: 'Unknown SSE transport error',
isProxyHint: true,
isTransient: true,
};
}
const errorObj = error as { message?: string; code?: number; event?: unknown };
const rawMessage = errorObj.message ?? '';
const code = errorObj.code;
// Handle the common "SSE error: undefined" case
if (rawMessage === 'SSE error: undefined' || rawMessage === 'undefined' || !rawMessage) {
return {
message:
'SSE connection closed. This can occur due to: (1) idle connection timeout (normal), ' +
'(2) reverse proxy buffering (check proxy_buffering config), or (3) network interruption.',
code,
isProxyHint: true,
isTransient: true,
};
}
// Check for timeout patterns with case-insensitive matching
const lowerMessage = rawMessage.toLowerCase();
if (
rawMessage.includes('ETIMEDOUT') ||
rawMessage.includes('ESOCKETTIMEDOUT') ||
lowerMessage.includes('timed out') ||
lowerMessage.includes('timeout after') ||
lowerMessage.includes('request timeout')
) {
return {
message: `SSE connection timed out: ${rawMessage}. If behind a reverse proxy, increase proxy_read_timeout.`,
code,
isProxyHint: true,
isTransient: true,
};
}
// Connection reset is often transient
if (rawMessage.includes('ECONNRESET')) {
return {
message: `SSE connection reset: ${rawMessage}. The server or proxy may have restarted.`,
code,
isProxyHint: false,
isTransient: true,
};
}
// Connection refused is more serious
if (rawMessage.includes('ECONNREFUSED')) {
return {
message: `SSE connection refused: ${rawMessage}. Verify the MCP server is running and accessible.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// DNS failure
if (rawMessage.includes('ENOTFOUND') || rawMessage.includes('getaddrinfo')) {
return {
message: `SSE DNS resolution failed: ${rawMessage}. Check the server URL is correct.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// Check for HTTP status codes
const statusMatch = rawMessage.match(/\b(4\d{2}|5\d{2})\b/);
if (statusMatch) {
const statusCode = parseInt(statusMatch[1], 10);
const isServerError = statusCode >= 500 && statusCode < 600;
return {
message: rawMessage,
code: statusCode,
isProxyHint: statusCode === 502 || statusCode === 503 || statusCode === 504,
isTransient: isServerError,
};
}
return {
message: rawMessage,
code,
isProxyHint: false,
isTransient: false,
};
}
describe('undefined/empty error handling', () => {
it('should handle "SSE error: undefined" from MCP SDK', () => {
const error = { message: 'SSE error: undefined', code: undefined };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should handle empty message', () => {
const error = { message: '' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isTransient).toBe(true);
});
it('should handle message "undefined"', () => {
const error = { message: 'undefined' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isTransient).toBe(true);
});
it('should handle null error', () => {
const result = extractSSEErrorMessage(null);
expect(result.message).toBe('Unknown SSE transport error');
expect(result.isTransient).toBe(true);
});
it('should handle undefined error', () => {
const result = extractSSEErrorMessage(undefined);
expect(result.message).toBe('Unknown SSE transport error');
expect(result.isTransient).toBe(true);
});
it('should handle non-object error', () => {
const result = extractSSEErrorMessage('string error');
expect(result.message).toBe('Unknown SSE transport error');
expect(result.isTransient).toBe(true);
});
});
describe('timeout errors', () => {
it('should detect ETIMEDOUT', () => {
const error = { message: 'connect ETIMEDOUT 1.2.3.4:443' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.message).toContain('proxy_read_timeout');
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect ESOCKETTIMEDOUT', () => {
const error = { message: 'ESOCKETTIMEDOUT' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should detect "timed out" (case insensitive)', () => {
const error = { message: 'Connection Timed Out' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should detect "timeout after"', () => {
const error = { message: 'Request timeout after 60000ms' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should detect "request timeout"', () => {
const error = { message: 'Request Timeout' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should NOT match "timeout" in unrelated context', () => {
// URL containing "timeout" should not trigger timeout detection
const error = { message: 'Failed to connect to https://api.example.com/timeout-settings' };
const result = extractSSEErrorMessage(error);
expect(result.message).not.toContain('SSE connection timed out');
expect(result.message).toBe('Failed to connect to https://api.example.com/timeout-settings');
});
});
describe('connection errors', () => {
it('should detect ECONNRESET as transient', () => {
const error = { message: 'read ECONNRESET' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection reset');
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(true);
});
it('should detect ECONNREFUSED as non-transient', () => {
const error = { message: 'connect ECONNREFUSED 127.0.0.1:8080' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection refused');
expect(result.message).toContain('Verify the MCP server is running');
expect(result.isTransient).toBe(false);
});
});
describe('DNS errors', () => {
it('should detect ENOTFOUND', () => {
const error = { message: 'getaddrinfo ENOTFOUND unknown.host.com' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE DNS resolution failed');
expect(result.message).toContain('Check the server URL');
expect(result.isTransient).toBe(false);
});
it('should detect getaddrinfo errors', () => {
const error = { message: 'getaddrinfo EAI_AGAIN example.com' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE DNS resolution failed');
expect(result.isTransient).toBe(false);
});
});
describe('HTTP status code errors', () => {
it('should detect 502 as proxy hint and transient', () => {
const error = { message: 'Non-200 status code (502): Bad Gateway' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(502);
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect 503 as proxy hint and transient', () => {
const error = { message: 'Error: Service Unavailable (503)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(503);
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect 504 as proxy hint and transient', () => {
const error = { message: 'Gateway Timeout 504' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(504);
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect 500 as transient but not proxy hint', () => {
const error = { message: 'Internal Server Error (500)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(500);
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(true);
});
it('should detect 404 as non-transient', () => {
const error = { message: 'Not Found (404)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(404);
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(false);
});
it('should detect 401 as non-transient', () => {
const error = { message: 'Unauthorized (401)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(401);
expect(result.isTransient).toBe(false);
});
});
describe('SseError from MCP SDK', () => {
it('should handle SseError with event property', () => {
const error = {
message: 'SSE error: undefined',
code: undefined,
event: { type: 'error', code: undefined, message: undefined },
};
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isTransient).toBe(true);
});
it('should preserve code from SseError', () => {
const error = {
message: 'SSE error: Server sent HTTP 204, not reconnecting',
code: 204,
};
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(204);
});
});
describe('regular error messages', () => {
it('should pass through regular error messages', () => {
const error = { message: 'Some specific error message', code: 42 };
const result = extractSSEErrorMessage(error);
expect(result.message).toBe('Some specific error message');
expect(result.code).toBe(42);
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(false);
});
});
});

View file

@ -331,12 +331,14 @@ describe('MCPConnectionFactory', () => {
expect(deleteCallOrder).toBeLessThan(createCallOrder);
// Verify createFlow was called with fresh metadata
// 4th arg is the abort signal (undefined in this test since no signal was provided)
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
'user123:test-server',
'mcp_oauth',
expect.objectContaining({
codeVerifier: 'new-code-verifier-xyz',
}),
undefined,
);
});
});

View file

@ -68,6 +68,145 @@ function isStreamableHTTPOptions(options: t.MCPOptions): options is t.Streamable
const FIVE_MINUTES = 5 * 60 * 1000;
const DEFAULT_TIMEOUT = 60000;
/** SSE connections through proxies may need longer initial handshake time */
const SSE_CONNECT_TIMEOUT = 120000;
/**
* Headers for SSE connections.
*
* Headers we intentionally DO NOT include:
* - Accept: text/event-stream - Already set by eventsource library AND MCP SDK
* - X-Accel-Buffering: This is a RESPONSE header for Nginx, not a request header.
* The upstream MCP server must send this header for Nginx to respect it.
* - Connection: keep-alive: Forbidden in HTTP/2 (RFC 7540 §8.1.2.2).
* HTTP/2 manages connection persistence differently.
*/
const SSE_REQUEST_HEADERS = {
'Cache-Control': 'no-cache',
};
/**
* Extracts a meaningful error message from SSE transport errors.
* The MCP SDK's SSEClientTransport can produce "SSE error: undefined" when the
* underlying eventsource library encounters connection issues without a specific message.
*
* @returns Object containing:
* - message: Human-readable error description
* - code: HTTP status code if available
* - isProxyHint: Whether this error suggests proxy misconfiguration
* - isTransient: Whether this is likely a transient error that will auto-reconnect
*/
function extractSSEErrorMessage(error: unknown): {
message: string;
code?: number;
isProxyHint: boolean;
isTransient: boolean;
} {
if (!error || typeof error !== 'object') {
return {
message: 'Unknown SSE transport error',
isProxyHint: true,
isTransient: true,
};
}
const errorObj = error as { message?: string; code?: number; event?: unknown };
const rawMessage = errorObj.message ?? '';
const code = errorObj.code;
/**
* Handle the common "SSE error: undefined" case.
* This typically occurs when:
* 1. A reverse proxy buffers the SSE stream (proxy issue)
* 2. The server closes an idle connection (normal SSE behavior)
* 3. Network interruption without specific error details
*
* In all cases, the eventsource library will attempt to reconnect automatically.
*/
if (rawMessage === 'SSE error: undefined' || rawMessage === 'undefined' || !rawMessage) {
return {
message:
'SSE connection closed. This can occur due to: (1) idle connection timeout (normal), ' +
'(2) reverse proxy buffering (check proxy_buffering config), or (3) network interruption.',
code,
isProxyHint: true,
isTransient: true,
};
}
/**
* Check for timeout patterns. Use case-insensitive matching for common timeout error codes:
* - ETIMEDOUT: TCP connection timeout
* - ESOCKETTIMEDOUT: Socket timeout
* - "timed out" / "timeout": Generic timeout messages
*/
const lowerMessage = rawMessage.toLowerCase();
if (
rawMessage.includes('ETIMEDOUT') ||
rawMessage.includes('ESOCKETTIMEDOUT') ||
lowerMessage.includes('timed out') ||
lowerMessage.includes('timeout after') ||
lowerMessage.includes('request timeout')
) {
return {
message: `SSE connection timed out: ${rawMessage}. If behind a reverse proxy, increase proxy_read_timeout.`,
code,
isProxyHint: true,
isTransient: true,
};
}
// Connection reset is often transient (server restart, proxy reload)
if (rawMessage.includes('ECONNRESET')) {
return {
message: `SSE connection reset: ${rawMessage}. The server or proxy may have restarted.`,
code,
isProxyHint: false,
isTransient: true,
};
}
// Connection refused is more serious - server may be down
if (rawMessage.includes('ECONNREFUSED')) {
return {
message: `SSE connection refused: ${rawMessage}. Verify the MCP server is running and accessible.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// DNS failure is usually a configuration issue, not transient
if (rawMessage.includes('ENOTFOUND') || rawMessage.includes('getaddrinfo')) {
return {
message: `SSE DNS resolution failed: ${rawMessage}. Check the server URL is correct.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// Check for HTTP status codes in the message
const statusMatch = rawMessage.match(/\b(4\d{2}|5\d{2})\b/);
if (statusMatch) {
const statusCode = parseInt(statusMatch[1], 10);
// 5xx errors are often transient, 4xx are usually not
const isServerError = statusCode >= 500 && statusCode < 600;
return {
message: rawMessage,
code: statusCode,
isProxyHint: statusCode === 502 || statusCode === 503 || statusCode === 504,
isTransient: isServerError,
};
}
return {
message: rawMessage,
code,
isProxyHint: false,
isTransient: false,
};
}
interface MCPConnectionParams {
serverName: string;
@ -258,18 +397,29 @@ export class MCPConnection extends EventEmitter {
headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
}
const timeoutValue = this.timeout || DEFAULT_TIMEOUT;
/**
* SSE connections need longer timeouts for reliability.
* The connect timeout is extended because proxies may delay initial response.
*/
const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT;
const transport = new SSEClientTransport(url, {
requestInit: {
headers,
/** User/OAuth headers override SSE defaults */
headers: { ...SSE_REQUEST_HEADERS, ...headers },
signal: abortController.signal,
},
eventSourceInit: {
fetch: (url, init) => {
const fetchHeaders = new Headers(Object.assign({}, init?.headers, headers));
/** Merge headers: SSE defaults < init headers < user headers (user wins) */
const fetchHeaders = new Headers(
Object.assign({}, SSE_REQUEST_HEADERS, init?.headers, headers),
);
const agent = new Agent({
bodyTimeout: timeoutValue,
headersTimeout: timeoutValue,
bodyTimeout: sseTimeout,
headersTimeout: sseTimeout,
/** Extended keep-alive for long-lived SSE connections */
keepAliveTimeout: sseTimeout,
keepAliveMaxTimeout: sseTimeout * 2,
});
return undiciFetch(url, {
...init,
@ -280,7 +430,7 @@ export class MCPConnection extends EventEmitter {
},
fetch: this.createFetchFunction(
this.getRequestHeaders.bind(this),
this.timeout,
sseTimeout,
) as unknown as FetchLike,
});
@ -639,26 +789,70 @@ export class MCPConnection extends EventEmitter {
private setupTransportErrorHandlers(transport: Transport): void {
transport.onerror = (error) => {
if (error && typeof error === 'object' && 'code' in error) {
const errorCode = (error as unknown as { code?: number }).code;
// Extract meaningful error information (handles "SSE error: undefined" cases)
const {
message: errorMessage,
code: errorCode,
isProxyHint,
isTransient,
} = extractSSEErrorMessage(error);
// Ignore SSE 404 errors for servers that don't support SSE
if (
errorCode === 404 &&
String(error?.message).toLowerCase().includes('failed to open sse stream')
) {
logger.warn(`${this.getLogPrefix()} SSE stream not available (404). Ignoring.`);
return;
// Ignore SSE 404 errors for servers that don't support SSE
if (errorCode === 404 && errorMessage.toLowerCase().includes('failed to open sse stream')) {
logger.warn(`${this.getLogPrefix()} SSE stream not available (404). Ignoring.`);
return;
}
// Check if it's an OAuth authentication error
if (this.isOAuthError(error)) {
logger.warn(`${this.getLogPrefix()} OAuth authentication error detected`);
this.emit('oauthError', error);
}
/**
* Log with enhanced context for debugging.
* All transport.onerror events are logged as errors to preserve stack traces.
* isTransient indicates whether auto-reconnection is expected to succeed.
*
* The MCP SDK's SseError extends Error and includes:
* - code: HTTP status code or eventsource error code
* - event: The original eventsource ErrorEvent
* - stack: Full stack trace
*/
const errorContext: Record<string, unknown> = {
code: errorCode,
isTransient,
};
if (isProxyHint) {
errorContext.hint = 'Check Nginx/proxy configuration for SSE endpoints';
}
// Extract additional debug info from SseError if available
if (error && typeof error === 'object') {
const sseError = error as { event?: unknown; stack?: string };
// Include the original eventsource event for debugging
if (sseError.event && typeof sseError.event === 'object') {
const event = sseError.event as { code?: number; message?: string; type?: string };
errorContext.eventDetails = {
type: event.type,
code: event.code,
message: event.message,
};
}
// Check if it's an OAuth authentication error
if (this.isOAuthError(error)) {
logger.warn(`${this.getLogPrefix()} OAuth authentication error detected`);
this.emit('oauthError', error);
// Include stack trace if available
if (sseError.stack) {
errorContext.stack = sseError.stack;
}
}
logger.error(`${this.getLogPrefix()} Transport error:`, error);
const errorLabel = isTransient
? 'Transport error (transient, will reconnect)'
: 'Transport error (may require manual intervention)';
logger.error(`${this.getLogPrefix()} ${errorLabel}: ${errorMessage}`, errorContext);
this.emit('connectionChange', 'error');
};

View file

@ -0,0 +1,61 @@
/**
* MCP-specific error classes
*/
export const MCPErrorCodes = {
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
} as const;
export type MCPErrorCode = (typeof MCPErrorCodes)[keyof typeof MCPErrorCodes];
/**
* Custom error for MCP domain restriction violations.
* Thrown when a user attempts to connect to an MCP server whose domain is not in the allowlist.
*/
export class MCPDomainNotAllowedError extends Error {
public readonly code = MCPErrorCodes.DOMAIN_NOT_ALLOWED;
public readonly statusCode = 403;
public readonly domain: string;
constructor(domain: string) {
super(`Domain "${domain}" is not allowed`);
this.name = 'MCPDomainNotAllowedError';
this.domain = domain;
Object.setPrototypeOf(this, MCPDomainNotAllowedError.prototype);
}
}
/**
* Custom error for MCP server inspection failures.
* Thrown when attempting to connect/inspect an MCP server fails.
*/
export class MCPInspectionFailedError extends Error {
public readonly code = MCPErrorCodes.INSPECTION_FAILED;
public readonly statusCode = 400;
public readonly serverName: string;
constructor(serverName: string, cause?: Error) {
super(`Failed to connect to MCP server "${serverName}"`);
this.name = 'MCPInspectionFailedError';
this.serverName = serverName;
if (cause) {
this.cause = cause;
}
Object.setPrototypeOf(this, MCPInspectionFailedError.prototype);
}
}
/**
* Type guard to check if an error is an MCPDomainNotAllowedError
*/
export function isMCPDomainNotAllowedError(error: unknown): error is MCPDomainNotAllowedError {
return error instanceof MCPDomainNotAllowedError;
}
/**
* Type guard to check if an error is an MCPInspectionFailedError
*/
export function isMCPInspectionFailedError(error: unknown): error is MCPInspectionFailedError {
return error instanceof MCPInspectionFailedError;
}

View file

@ -9,5 +9,7 @@ export const mcpConfig = {
OAUTH_DETECTION_TIMEOUT: math(process.env.MCP_OAUTH_DETECTION_TIMEOUT ?? 5000),
CONNECTION_CHECK_TTL: math(process.env.MCP_CONNECTION_CHECK_TTL ?? 60000),
/** Idle timeout (ms) after which user connections are disconnected. Default: 15 minutes */
USER_CONNECTION_IDLE_TIMEOUT: math(process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000),
USER_CONNECTION_IDLE_TIMEOUT: math(
process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000,
),
};

View file

@ -9,7 +9,7 @@ import {
discoverAuthorizationServerMetadata,
discoverOAuthProtectedResourceMetadata,
} from '@modelcontextprotocol/sdk/client/auth.js';
import type { MCPOptions } from 'librechat-data-provider';
import { TokenExchangeMethodEnum, type MCPOptions } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type {
OAuthClientInformation,
@ -27,15 +27,117 @@ export class MCPOAuthHandler {
private static readonly FLOW_TYPE = 'mcp_oauth';
private static readonly FLOW_TTL = 10 * 60 * 1000; // 10 minutes
private static getForcedTokenEndpointAuthMethod(
tokenExchangeMethod?: TokenExchangeMethodEnum,
): 'client_secret_basic' | 'client_secret_post' | undefined {
if (tokenExchangeMethod === TokenExchangeMethodEnum.DefaultPost) {
return 'client_secret_post';
}
if (tokenExchangeMethod === TokenExchangeMethodEnum.BasicAuthHeader) {
return 'client_secret_basic';
}
return undefined;
}
private static resolveTokenEndpointAuthMethod(options: {
tokenExchangeMethod?: TokenExchangeMethodEnum;
tokenAuthMethods: string[];
preferredMethod?: string;
}): 'client_secret_basic' | 'client_secret_post' | undefined {
const forcedMethod = this.getForcedTokenEndpointAuthMethod(options.tokenExchangeMethod);
const preferredMethod = forcedMethod ?? options.preferredMethod;
if (preferredMethod === 'client_secret_basic' || preferredMethod === 'client_secret_post') {
return preferredMethod;
}
if (options.tokenAuthMethods.includes('client_secret_basic')) {
return 'client_secret_basic';
}
if (options.tokenAuthMethods.includes('client_secret_post')) {
return 'client_secret_post';
}
return undefined;
}
/**
* Creates a fetch function with custom headers injected
*/
private static createOAuthFetch(headers: Record<string, string>): FetchLike {
private static createOAuthFetch(
headers: Record<string, string>,
clientInfo?: OAuthClientInformation,
): FetchLike {
return async (url: string | URL, init?: RequestInit): Promise<Response> => {
const newHeaders = new Headers(init?.headers ?? {});
for (const [key, value] of Object.entries(headers)) {
newHeaders.set(key, value);
}
const method = (init?.method ?? 'GET').toUpperCase();
const initBody = init?.body;
let params: URLSearchParams | undefined;
if (initBody instanceof URLSearchParams) {
params = initBody;
} else if (typeof initBody === 'string') {
const parsed = new URLSearchParams(initBody);
if (parsed.has('grant_type')) {
params = parsed;
}
}
/**
* FastMCP 2.14+/MCP SDK 1.24+ token endpoints can be strict about:
* - Content-Type (must be application/x-www-form-urlencoded)
* - where client_id/client_secret are supplied (default_post vs basic header)
*/
if (method === 'POST' && params?.has('grant_type')) {
newHeaders.set('Content-Type', 'application/x-www-form-urlencoded');
if (clientInfo?.client_id) {
let authMethod = clientInfo.token_endpoint_auth_method;
if (!authMethod) {
if (newHeaders.has('Authorization')) {
authMethod = 'client_secret_basic';
} else if (params.has('client_id') || params.has('client_secret')) {
authMethod = 'client_secret_post';
} else if (clientInfo.client_secret) {
authMethod = 'client_secret_post';
} else {
authMethod = 'none';
}
}
if (!clientInfo.client_secret || authMethod === 'none') {
newHeaders.delete('Authorization');
if (!params.has('client_id')) {
params.set('client_id', clientInfo.client_id);
}
} else if (authMethod === 'client_secret_post') {
newHeaders.delete('Authorization');
if (!params.has('client_id')) {
params.set('client_id', clientInfo.client_id);
}
if (!params.has('client_secret')) {
params.set('client_secret', clientInfo.client_secret);
}
} else if (authMethod === 'client_secret_basic') {
if (!newHeaders.has('Authorization')) {
const clientAuth = Buffer.from(
`${clientInfo.client_id}:${clientInfo.client_secret}`,
).toString('base64');
newHeaders.set('Authorization', `Basic ${clientAuth}`);
}
}
}
return fetch(url, {
...init,
body: params.toString(),
headers: newHeaders,
});
}
return fetch(url, {
...init,
headers: newHeaders,
@ -157,6 +259,7 @@ export class MCPOAuthHandler {
oauthHeaders: Record<string, string>,
resourceMetadata?: OAuthProtectedResourceMetadata,
redirectUri?: string,
tokenExchangeMethod?: TokenExchangeMethodEnum,
): Promise<OAuthClientInformation> {
logger.debug(
`[MCPOAuth] Starting client registration for ${sanitizeUrlForLogging(serverUrl)}, server metadata:`,
@ -197,7 +300,11 @@ export class MCPOAuthHandler {
clientMetadata.response_types = metadata.response_types_supported || ['code'];
if (metadata.token_endpoint_auth_methods_supported) {
const forcedAuthMethod = this.getForcedTokenEndpointAuthMethod(tokenExchangeMethod);
if (forcedAuthMethod) {
clientMetadata.token_endpoint_auth_method = forcedAuthMethod;
} else if (metadata.token_endpoint_auth_methods_supported) {
// Prefer client_secret_basic if supported, otherwise use the first supported method
if (metadata.token_endpoint_auth_methods_supported.includes('client_secret_basic')) {
clientMetadata.token_endpoint_auth_method = 'client_secret_basic';
@ -227,6 +334,12 @@ export class MCPOAuthHandler {
fetchFn: this.createOAuthFetch(oauthHeaders),
});
if (forcedAuthMethod) {
clientInfo.token_endpoint_auth_method = forcedAuthMethod;
} else if (!clientInfo.token_endpoint_auth_method) {
clientInfo.token_endpoint_auth_method = clientMetadata.token_endpoint_auth_method;
}
logger.debug(
`[MCPOAuth] Client registered successfully for ${sanitizeUrlForLogging(serverUrl)}:`,
{
@ -281,6 +394,26 @@ export class MCPOAuthHandler {
}
/** Metadata based on pre-configured settings */
let tokenEndpointAuthMethod: string;
if (!config.client_secret) {
tokenEndpointAuthMethod = 'none';
} else {
// When token_exchange_method is undefined or not DefaultPost, default to using
// client_secret_basic (Basic Auth header) for token endpoint authentication.
tokenEndpointAuthMethod =
this.getForcedTokenEndpointAuthMethod(config.token_exchange_method) ??
'client_secret_basic';
}
let defaultTokenAuthMethods: string[];
if (tokenEndpointAuthMethod === 'none') {
defaultTokenAuthMethods = ['none'];
} else if (tokenEndpointAuthMethod === 'client_secret_post') {
defaultTokenAuthMethods = ['client_secret_post', 'client_secret_basic'];
} else {
defaultTokenAuthMethods = ['client_secret_basic', 'client_secret_post'];
}
const metadata: OAuthMetadata = {
authorization_endpoint: config.authorization_url,
token_endpoint: config.token_url,
@ -290,10 +423,8 @@ export class MCPOAuthHandler {
'authorization_code',
'refresh_token',
],
token_endpoint_auth_methods_supported: config?.token_endpoint_auth_methods_supported ?? [
'client_secret_basic',
'client_secret_post',
],
token_endpoint_auth_methods_supported:
config?.token_endpoint_auth_methods_supported ?? defaultTokenAuthMethods,
response_types_supported: config?.response_types_supported ?? ['code'],
code_challenge_methods_supported: codeChallengeMethodsSupported,
};
@ -303,6 +434,7 @@ export class MCPOAuthHandler {
client_secret: config.client_secret,
redirect_uris: [config.redirect_uri || this.getDefaultRedirectUri(serverName)],
scope: config.scope,
token_endpoint_auth_method: tokenEndpointAuthMethod,
};
logger.debug(`[MCPOAuth] Starting authorization with pre-configured settings`);
@ -359,6 +491,7 @@ export class MCPOAuthHandler {
oauthHeaders,
resourceMetadata,
redirectUri,
config?.token_exchange_method,
);
logger.debug(`[MCPOAuth] Client registered with ID: ${clientInfo.client_id}`);
@ -490,7 +623,7 @@ export class MCPOAuthHandler {
codeVerifier: metadata.codeVerifier,
authorizationCode,
resource,
fetchFn: this.createOAuthFetch(oauthHeaders),
fetchFn: this.createOAuthFetch(oauthHeaders, metadata.clientInfo),
});
logger.debug('[MCPOAuth] Token exchange successful', {
@ -663,8 +796,8 @@ export class MCPOAuthHandler {
}
const headers: HeadersInit = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
...oauthHeaders,
};
@ -672,17 +805,20 @@ export class MCPOAuthHandler {
if (metadata.clientInfo.client_secret) {
/** Default to client_secret_basic if no methods specified (per RFC 8414) */
const tokenAuthMethods = authMethods ?? ['client_secret_basic'];
const usesBasicAuth = tokenAuthMethods.includes('client_secret_basic');
const usesClientSecretPost = tokenAuthMethods.includes('client_secret_post');
const authMethod = this.resolveTokenEndpointAuthMethod({
tokenExchangeMethod: config?.token_exchange_method,
tokenAuthMethods,
preferredMethod: metadata.clientInfo.token_endpoint_auth_method,
});
if (usesBasicAuth) {
if (authMethod === 'client_secret_basic') {
/** Use Basic auth */
logger.debug('[MCPOAuth] Using client_secret_basic authentication method');
const clientAuth = Buffer.from(
`${metadata.clientInfo.client_id}:${metadata.clientInfo.client_secret}`,
).toString('base64');
headers['Authorization'] = `Basic ${clientAuth}`;
} else if (usesClientSecretPost) {
} else if (authMethod === 'client_secret_post') {
/** Use client_secret_post */
logger.debug('[MCPOAuth] Using client_secret_post authentication method');
body.append('client_id', metadata.clientInfo.client_id);
@ -739,8 +875,8 @@ export class MCPOAuthHandler {
}
const headers: HeadersInit = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
...oauthHeaders,
};
@ -750,10 +886,12 @@ export class MCPOAuthHandler {
const tokenAuthMethods = config.token_endpoint_auth_methods_supported ?? [
'client_secret_basic',
];
const usesBasicAuth = tokenAuthMethods.includes('client_secret_basic');
const usesClientSecretPost = tokenAuthMethods.includes('client_secret_post');
const authMethod = this.resolveTokenEndpointAuthMethod({
tokenExchangeMethod: config.token_exchange_method,
tokenAuthMethods,
});
if (usesBasicAuth) {
if (authMethod === 'client_secret_basic') {
/** Use Basic auth */
logger.debug(
'[MCPOAuth] Using client_secret_basic authentication method (pre-configured)',
@ -762,7 +900,7 @@ export class MCPOAuthHandler {
'base64',
);
headers['Authorization'] = `Basic ${clientAuth}`;
} else if (usesClientSecretPost) {
} else if (authMethod === 'client_secret_post') {
/** Use client_secret_post */
logger.debug(
'[MCPOAuth] Using client_secret_post authentication method (pre-configured)',
@ -832,8 +970,8 @@ export class MCPOAuthHandler {
});
const headers: HeadersInit = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
...oauthHeaders,
};

View file

@ -2,7 +2,9 @@ import { Constants } from 'librechat-data-provider';
import type { JsonSchemaType } from '@librechat/data-schemas';
import type { MCPConnection } from '~/mcp/connection';
import type * as t from '~/mcp/types';
import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPDomainNotAllowedError } from '~/mcp/errors';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { isEnabled } from '~/utils';
@ -24,13 +26,22 @@ export class MCPServerInspector {
* @param serverName - The name of the server (used for tool function naming)
* @param rawConfig - The raw server configuration
* @param connection - The MCP connection
* @param allowedDomains - Optional list of allowed domains for remote transports
* @returns A fully processed and enriched configuration with server metadata
*/
public static async inspect(
serverName: string,
rawConfig: t.MCPOptions,
connection?: MCPConnection,
allowedDomains?: string[] | null,
): Promise<t.ParsedServerConfig> {
// Validate domain against allowlist BEFORE attempting connection
const isDomainAllowed = await isMCPDomainAllowed(rawConfig, allowedDomains);
if (!isDomainAllowed) {
const domain = extractMCPServerDomain(rawConfig);
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
}
const start = Date.now();
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
await inspector.inspectServer();

View file

@ -30,11 +30,24 @@ export class MCPServersInitializer {
* - Followers wait and poll `statusCache` until the leader finishes, ensuring only one node
* performs the expensive initialization operations.
*/
private static hasInitializedThisProcess = false;
/** Reset the process-level initialization flag. Only used for testing. */
public static resetProcessFlag(): void {
MCPServersInitializer.hasInitializedThisProcess = false;
}
public static async initialize(rawConfigs: t.MCPServers): Promise<void> {
if (await statusCache.isInitialized()) return;
// On first call in this process, always reset and re-initialize
// This ensures we don't use stale Redis data from previous runs
const isFirstCallThisProcess = !MCPServersInitializer.hasInitializedThisProcess;
// Set flag immediately so recursive calls (from followers) use Redis cache for coordination
MCPServersInitializer.hasInitializedThisProcess = true;
if (!isFirstCallThisProcess && (await statusCache.isInitialized())) return;
if (await isLeader()) {
// Leader performs initialization
// Leader performs initialization - always reset on first call
await statusCache.reset();
await MCPServersRegistry.getInstance().reset();
const serverNames = Object.keys(rawConfigs);

View file

@ -1,9 +1,12 @@
import { Keyv } from 'keyv';
import { logger } from '@librechat/data-schemas';
import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface';
import type * as t from '~/mcp/types';
import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors';
import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory';
import { MCPServerInspector } from './MCPServerInspector';
import { ServerConfigsDB } from './db/ServerConfigsDB';
import { cacheConfig } from '~/cache/cacheConfig';
/**
* Central registry for managing MCP server configurations.
@ -20,14 +23,33 @@ export class MCPServersRegistry {
private readonly dbConfigsRepo: IServerConfigsRepositoryInterface;
private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface;
private readonly allowedDomains?: string[] | null;
private readonly readThroughCache: Keyv<t.ParsedServerConfig>;
private readonly readThroughCacheAll: Keyv<Record<string, t.ParsedServerConfig>>;
constructor(mongoose: typeof import('mongoose')) {
constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) {
this.dbConfigsRepo = new ServerConfigsDB(mongoose);
this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false);
this.allowedDomains = allowedDomains;
const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL;
this.readThroughCache = new Keyv<t.ParsedServerConfig>({
namespace: 'mcp-registry-read-through',
ttl,
});
this.readThroughCacheAll = new Keyv<Record<string, t.ParsedServerConfig>>({
namespace: 'mcp-registry-read-through-all',
ttl,
});
}
/** Creates and initializes the singleton MCPServersRegistry instance */
public static createInstance(mongoose: typeof import('mongoose')): MCPServersRegistry {
public static createInstance(
mongoose: typeof import('mongoose'),
allowedDomains?: string[] | null,
): MCPServersRegistry {
if (!mongoose) {
throw new Error(
'MCPServersRegistry creation failed: mongoose instance is required for database operations. ' +
@ -39,7 +61,7 @@ export class MCPServersRegistry {
return MCPServersRegistry.instance;
}
logger.info('[MCPServersRegistry] Creating new instance');
MCPServersRegistry.instance = new MCPServersRegistry(mongoose);
MCPServersRegistry.instance = new MCPServersRegistry(mongoose, allowedDomains);
return MCPServersRegistry.instance;
}
@ -55,20 +77,40 @@ export class MCPServersRegistry {
serverName: string,
userId?: string,
): Promise<t.ParsedServerConfig | undefined> {
const cacheKey = this.getReadThroughCacheKey(serverName, userId);
if (await this.readThroughCache.has(cacheKey)) {
return await this.readThroughCache.get(cacheKey);
}
// First we check if any config exist with the cache
// Yaml config are pre loaded to the cache
const configFromCache = await this.cacheConfigsRepo.get(serverName);
if (configFromCache) return configFromCache;
if (configFromCache) {
await this.readThroughCache.set(cacheKey, configFromCache);
return configFromCache;
}
const configFromDB = await this.dbConfigsRepo.get(serverName, userId);
if (configFromDB) return configFromDB;
return undefined;
await this.readThroughCache.set(cacheKey, configFromDB);
return configFromDB;
}
public async getAllServerConfigs(userId?: string): Promise<Record<string, t.ParsedServerConfig>> {
return {
const cacheKey = userId ?? '__no_user__';
// Check if key exists in read-through cache
if (await this.readThroughCacheAll.has(cacheKey)) {
return (await this.readThroughCacheAll.get(cacheKey)) ?? {};
}
const result = {
...(await this.cacheConfigsRepo.getAll()),
...(await this.dbConfigsRepo.getAll(userId)),
};
await this.readThroughCacheAll.set(cacheKey, result);
return result;
}
public async addServer(
@ -80,10 +122,19 @@ export class MCPServersRegistry {
const configRepo = this.getConfigRepository(storageLocation);
let parsedConfig: t.ParsedServerConfig;
try {
parsedConfig = await MCPServerInspector.inspect(serverName, config);
parsedConfig = await MCPServerInspector.inspect(
serverName,
config,
undefined,
this.allowedDomains,
);
} catch (error) {
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
// Preserve domain-specific error for better error handling
if (isMCPDomainNotAllowedError(error)) {
throw error;
}
throw new MCPInspectionFailedError(serverName, error as Error);
}
return await configRepo.add(serverName, parsedConfig, userId);
}
@ -113,10 +164,19 @@ export class MCPServersRegistry {
let parsedConfig: t.ParsedServerConfig;
try {
parsedConfig = await MCPServerInspector.inspect(serverName, configForInspection);
parsedConfig = await MCPServerInspector.inspect(
serverName,
configForInspection,
undefined,
this.allowedDomains,
);
} catch (error) {
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
// Preserve domain-specific error for better error handling
if (isMCPDomainNotAllowedError(error)) {
throw error;
}
throw new MCPInspectionFailedError(serverName, error as Error);
}
await configRepo.update(serverName, parsedConfig, userId);
return parsedConfig;
@ -132,6 +192,8 @@ export class MCPServersRegistry {
public async reset(): Promise<void> {
await this.cacheConfigsRepo.reset();
await this.readThroughCache.clear();
await this.readThroughCacheAll.clear();
}
public async removeServer(
@ -155,4 +217,8 @@ export class MCPServersRegistry {
);
}
}
private getReadThroughCacheKey(serverName: string, userId?: string): string {
return userId ? `${serverName}::${userId}` : serverName;
}
}

View file

@ -201,6 +201,11 @@ describe('MCPServersInitializer Redis Integration Tests', () => {
// Mock MCPConnectionFactory
jest.spyOn(MCPConnectionFactory, 'create').mockResolvedValue(mockConnection);
// Reset caches and process flag before each test
await registryStatusCache.reset();
await registry.reset();
MCPServersInitializer.resetProcessFlag();
});
afterEach(async () => {
@ -261,7 +266,7 @@ describe('MCPServersInitializer Redis Integration Tests', () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify inspect was not called again
expect(MCPServerInspector.inspect).not.toHaveBeenCalled();
expect((MCPServerInspector.inspect as jest.Mock).mock.calls.length).toBe(0);
});
it('should initialize all servers to cache repository', async () => {
@ -309,4 +314,118 @@ describe('MCPServersInitializer Redis Integration Tests', () => {
expect(await registryStatusCache.isInitialized()).toBe(true);
});
});
describe('horizontal scaling / app restart behavior', () => {
it('should re-initialize on first call even if Redis says initialized (simulating app restart)', async () => {
// First: run full initialization
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
// Add a stale server directly to Redis to simulate stale data
await registry.addServer(
'stale_server',
{
type: 'stdio',
command: 'node',
args: ['stale.js'],
},
'CACHE',
);
expect(await registry.getServerConfig('stale_server')).toBeDefined();
// Simulate app restart by resetting the process flag (but NOT Redis)
MCPServersInitializer.resetProcessFlag();
// Clear mocks to track new calls
jest.clearAllMocks();
// Re-initialize - should still run initialization because process flag was reset
await MCPServersInitializer.initialize(testConfigs);
// Stale server should be gone because registry.reset() was called
expect(await registry.getServerConfig('stale_server')).toBeUndefined();
// Real servers should be present
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
expect(await registry.getServerConfig('disabled_server')).toBeDefined();
// Inspector should have been called (proving re-initialization happened)
expect((MCPServerInspector.inspect as jest.Mock).mock.calls.length).toBeGreaterThan(0);
});
it('should skip re-initialization on subsequent calls within same process', async () => {
// First initialization
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
// Clear mocks
jest.clearAllMocks();
// Second call in same process should skip
await MCPServersInitializer.initialize(testConfigs);
// Inspector should NOT have been called
expect((MCPServerInspector.inspect as jest.Mock).mock.calls.length).toBe(0);
});
it('should clear stale data from Redis when a new instance becomes leader', async () => {
// Initial setup with testConfigs
await MCPServersInitializer.initialize(testConfigs);
// Add stale data that shouldn't exist after next initialization
await registry.addServer(
'should_be_removed',
{
type: 'stdio',
command: 'node',
args: ['old.js'],
},
'CACHE',
);
// Verify stale data exists
expect(await registry.getServerConfig('should_be_removed')).toBeDefined();
// Simulate new process starting (reset process flag)
MCPServersInitializer.resetProcessFlag();
// Initialize with different configs (fewer servers)
const reducedConfigs: t.MCPServers = {
file_tools_server: testConfigs.file_tools_server,
};
await MCPServersInitializer.initialize(reducedConfigs);
// Stale server from previous config should be gone
expect(await registry.getServerConfig('should_be_removed')).toBeUndefined();
// Server not in new configs should be gone
expect(await registry.getServerConfig('disabled_server')).toBeUndefined();
// Only server in new configs should exist
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
});
it('should work correctly when multiple instances share Redis (leader handles init)', async () => {
// First instance initializes (we are the leader)
await MCPServersInitializer.initialize(testConfigs);
// Verify initialized state is in Redis
expect(await registryStatusCache.isInitialized()).toBe(true);
// Verify servers are in Redis
const fileToolsServer = await registry.getServerConfig('file_tools_server');
expect(fileToolsServer).toBeDefined();
expect(fileToolsServer?.tools).toBe('file_read, file_write');
// Simulate second instance starting (reset process flag but keep Redis data)
MCPServersInitializer.resetProcessFlag();
jest.clearAllMocks();
// Second instance initializes - should still process because isFirstCallThisProcess
await MCPServersInitializer.initialize(testConfigs);
// Redis should still have correct data
expect(await registryStatusCache.isInitialized()).toBe(true);
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
});
});
});

View file

@ -179,9 +179,10 @@ describe('MCPServersInitializer', () => {
} as unknown as t.ParsedServerConfig;
});
// Reset caches before each test
// Reset caches and process flag before each test
await registryStatusCache.reset();
await registry.reset();
MCPServersInitializer.resetProcessFlag();
jest.clearAllMocks();
});
@ -223,18 +224,38 @@ describe('MCPServersInitializer', () => {
it('should process all server configs through inspector', async () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify all configs were processed by inspector (without connection parameter)
// Verify all configs were processed by inspector
// Signature: inspect(serverName, rawConfig, connection?, allowedDomains?)
expect(mockInspect).toHaveBeenCalledTimes(5);
expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server);
expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server);
expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server);
expect(mockInspect).toHaveBeenCalledWith(
'disabled_server',
testConfigs.disabled_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'oauth_server',
testConfigs.oauth_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'file_tools_server',
testConfigs.file_tools_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'search_tools_server',
testConfigs.search_tools_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'remote_no_oauth_server',
testConfigs.remote_no_oauth_server,
undefined,
undefined,
);
});
@ -309,5 +330,51 @@ describe('MCPServersInitializer', () => {
expect(await registryStatusCache.isInitialized()).toBe(true);
});
it('should re-initialize on first call even if Redis cache says initialized (simulating app restart)', async () => {
// First initialization - populates caches
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
// Simulate stale data: add an extra server that shouldn't be there
await registry.addServer('stale_server', testConfigs.file_tools_server, 'CACHE');
expect(await registry.getServerConfig('stale_server')).toBeDefined();
jest.clearAllMocks();
// Simulate app restart by resetting the process flag
// In real scenario, this happens automatically when process restarts
MCPServersInitializer.resetProcessFlag();
// Re-initialize - should reset caches even though Redis says initialized
await MCPServersInitializer.initialize(testConfigs);
// Verify stale server was removed (cache was reset)
expect(await registry.getServerConfig('stale_server')).toBeUndefined();
// Verify new servers are present
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
expect(await registry.getServerConfig('oauth_server')).toBeDefined();
// Verify inspector was called again (re-initialization happened)
expect(mockInspect).toHaveBeenCalled();
});
it('should not re-initialize on subsequent calls within same process', async () => {
// First initialization (5 servers in testConfigs)
await MCPServersInitializer.initialize(testConfigs);
expect(mockInspect).toHaveBeenCalledTimes(5);
jest.clearAllMocks();
// Second call - should skip because process flag is set and Redis says initialized
await MCPServersInitializer.initialize(testConfigs);
expect(mockInspect).not.toHaveBeenCalled();
// Third call - still skips
await MCPServersInitializer.initialize(testConfigs);
expect(mockInspect).not.toHaveBeenCalled();
});
});
});

View file

@ -192,15 +192,14 @@ describe('MCPServersRegistry Redis Integration Tests', () => {
// Add server
await registry.addServer(serverName, testRawConfig, 'CACHE');
// Verify server exists
const configBefore = await registry.getServerConfig(serverName);
expect(configBefore).toBeDefined();
// Verify server exists in underlying cache repository (not via getServerConfig to avoid populating read-through cache)
expect(await registry['cacheConfigsRepo'].get(serverName)).toBeDefined();
// Remove server
await registry.removeServer(serverName, 'CACHE');
// Verify server was removed
const configAfter = await registry.getServerConfig(serverName);
// Verify server was removed from underlying cache repository
const configAfter = await registry['cacheConfigsRepo'].get(serverName);
expect(configAfter).toBeUndefined();
});
});

View file

@ -158,11 +158,13 @@ describe('MCPServersRegistry', () => {
it('should route removeServer to cache repository', async () => {
await registry.addServer('cache_server', testParsedConfig, 'CACHE');
expect(await registry.getServerConfig('cache_server')).toBeDefined();
// Verify server exists in underlying cache repository (not via getServerConfig to avoid populating read-through cache)
expect(await registry['cacheConfigsRepo'].get('cache_server')).toBeDefined();
await registry.removeServer('cache_server', 'CACHE');
const config = await registry.getServerConfig('cache_server');
// Verify server is removed from underlying cache repository
const config = await registry['cacheConfigsRepo'].get('cache_server');
expect(config).toBeUndefined();
});
});
@ -190,4 +192,114 @@ describe('MCPServersRegistry', () => {
});
});
});
describe('Read-through cache', () => {
describe('getServerConfig', () => {
it('should cache repeated calls for the same server', async () => {
// Add a server to the cache repository
await registry['cacheConfigsRepo'].add('test_server', testParsedConfig);
// Spy on the cache repository get method
const cacheRepoGetSpy = jest.spyOn(registry['cacheConfigsRepo'], 'get');
// First call should hit the cache repository
const config1 = await registry.getServerConfig('test_server');
expect(config1).toEqual(testParsedConfig);
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1);
// Second call should hit the read-through cache, not the repository
const config2 = await registry.getServerConfig('test_server');
expect(config2).toEqual(testParsedConfig);
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1); // Still 1, not 2
// Third call should also hit the read-through cache
const config3 = await registry.getServerConfig('test_server');
expect(config3).toEqual(testParsedConfig);
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1); // Still 1
});
it('should cache "not found" results to avoid repeated DB lookups', async () => {
// Spy on the DB repository get method
const dbRepoGetSpy = jest.spyOn(registry['dbConfigsRepo'], 'get');
// First call - server doesn't exist, should hit DB
const config1 = await registry.getServerConfig('nonexistent_server');
expect(config1).toBeUndefined();
expect(dbRepoGetSpy).toHaveBeenCalledTimes(1);
// Second call - should hit read-through cache, not DB
const config2 = await registry.getServerConfig('nonexistent_server');
expect(config2).toBeUndefined();
expect(dbRepoGetSpy).toHaveBeenCalledTimes(1); // Still 1, not 2
});
it('should use different cache keys for different userIds', async () => {
// Spy on the cache repository get method
const cacheRepoGetSpy = jest.spyOn(registry['cacheConfigsRepo'], 'get');
// First call without userId
await registry.getServerConfig('test_server');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1);
// Call with userId - should be a different cache key, so hits repository again
await registry.getServerConfig('test_server', 'user123');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2);
// Repeat call with same userId - should hit read-through cache
await registry.getServerConfig('test_server', 'user123');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); // Still 2
// Call with different userId - should hit repository
await registry.getServerConfig('test_server', 'user456');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(3);
});
});
describe('getAllServerConfigs', () => {
it('should cache repeated calls', async () => {
// Add servers to cache
await registry['cacheConfigsRepo'].add('server1', testParsedConfig);
await registry['cacheConfigsRepo'].add('server2', testParsedConfig);
// Spy on the cache repository getAll method
const cacheRepoGetAllSpy = jest.spyOn(registry['cacheConfigsRepo'], 'getAll');
// First call should hit the repository
const configs1 = await registry.getAllServerConfigs();
expect(Object.keys(configs1)).toHaveLength(2);
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1);
// Second call should hit the read-through cache
const configs2 = await registry.getAllServerConfigs();
expect(Object.keys(configs2)).toHaveLength(2);
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1); // Still 1
// Third call should also hit the read-through cache
const configs3 = await registry.getAllServerConfigs();
expect(Object.keys(configs3)).toHaveLength(2);
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1); // Still 1
});
it('should use different cache keys for different userIds', async () => {
// Spy on the cache repository getAll method
const cacheRepoGetAllSpy = jest.spyOn(registry['cacheConfigsRepo'], 'getAll');
// First call without userId
await registry.getAllServerConfigs();
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1);
// Call with userId - should be a different cache key
await registry.getAllServerConfigs('user123');
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(2);
// Repeat call with same userId - should hit read-through cache
await registry.getAllServerConfigs('user123');
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(2); // Still 2
// Call with different userId - should hit repository
await registry.getAllServerConfigs('user456');
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(3);
});
});
});
});

View file

@ -0,0 +1,940 @@
import { logger } from '@librechat/data-schemas';
import type { StandardGraph } from '@librechat/agents';
import type { Agents } from 'librechat-data-provider';
import type {
SerializableJobData,
IEventTransport,
AbortResult,
IJobStore,
} from './interfaces/IJobStore';
import type * as t from '~/types';
import { InMemoryEventTransport } from './implementations/InMemoryEventTransport';
import { InMemoryJobStore } from './implementations/InMemoryJobStore';
/**
* Configuration options for GenerationJobManager
*/
export interface GenerationJobManagerOptions {
jobStore?: IJobStore;
eventTransport?: IEventTransport;
/**
* If true, cleans up event transport immediately when job completes.
* If false, keeps EventEmitters until periodic cleanup for late reconnections.
* Default: true (immediate cleanup to save memory)
*/
cleanupOnComplete?: boolean;
}
/**
* Runtime state for active jobs - not serializable, kept in-memory per instance.
* Contains AbortController, ready promise, and other non-serializable state.
*
* @property abortController - Controller to abort the generation
* @property readyPromise - Resolves immediately (legacy, kept for API compatibility)
* @property resolveReady - Function to resolve readyPromise
* @property finalEvent - Cached final event for late subscribers
* @property syncSent - Whether sync event was sent (reset when all subscribers leave)
* @property earlyEventBuffer - Buffer for events emitted before first subscriber connects
* @property hasSubscriber - Whether at least one subscriber has connected
* @property allSubscribersLeftHandlers - Internal handlers for disconnect events.
* These are stored separately from eventTransport subscribers to avoid being counted
* in subscriber count. This is critical: if these were registered via subscribe(),
* they would count as subscribers, causing isFirstSubscriber() to return false
* when the real client connects, which would prevent readyPromise from resolving.
*/
interface RuntimeJobState {
abortController: AbortController;
readyPromise: Promise<void>;
resolveReady: () => void;
finalEvent?: t.ServerSentEvent;
syncSent: boolean;
earlyEventBuffer: t.ServerSentEvent[];
hasSubscriber: boolean;
allSubscribersLeftHandlers?: Array<(...args: unknown[]) => void>;
}
/**
* Manages generation jobs for resumable LLM streams.
*
* Architecture: Composes two pluggable services via dependency injection:
* - jobStore: Job metadata + content state (InMemory Redis for horizontal scaling)
* - eventTransport: Pub/sub events (InMemory Redis Pub/Sub for horizontal scaling)
*
* Content state is tied to jobs:
* - In-memory: jobStore holds WeakRef to graph for live content/run steps access
* - Redis: jobStore persists chunks, reconstructs content on demand
*
* All storage methods are async to support both in-memory and external stores (Redis, etc.).
*
* @example Redis injection:
* ```ts
* const manager = new GenerationJobManagerClass({
* jobStore: new RedisJobStore(redisClient),
* eventTransport: new RedisPubSubTransport(redisClient),
* });
* ```
*/
class GenerationJobManagerClass {
/** Job metadata + content state storage - swappable for Redis, etc. */
private jobStore: IJobStore;
/** Event pub/sub transport - swappable for Redis Pub/Sub, etc. */
private eventTransport: IEventTransport;
/** Runtime state - always in-memory, not serializable */
private runtimeState = new Map<string, RuntimeJobState>();
private cleanupInterval: NodeJS.Timeout | null = null;
/** Whether we're using Redis stores */
private _isRedis = false;
/** Whether to cleanup event transport immediately on job completion */
private _cleanupOnComplete = true;
constructor(options?: GenerationJobManagerOptions) {
this.jobStore =
options?.jobStore ?? new InMemoryJobStore({ ttlAfterComplete: 0, maxJobs: 1000 });
this.eventTransport = options?.eventTransport ?? new InMemoryEventTransport();
this._cleanupOnComplete = options?.cleanupOnComplete ?? true;
}
/**
* Initialize the job manager with periodic cleanup.
* Call this once at application startup.
*/
initialize(): void {
if (this.cleanupInterval) {
return;
}
this.jobStore.initialize();
this.cleanupInterval = setInterval(() => {
this.cleanup();
}, 60000);
if (this.cleanupInterval.unref) {
this.cleanupInterval.unref();
}
logger.debug('[GenerationJobManager] Initialized');
}
/**
* Configure the manager with custom stores.
* Call this BEFORE initialize() to use Redis or other stores.
*
* @example Using Redis
* ```ts
* import { createStreamServicesFromCache } from '~/stream/createStreamServices';
* import { cacheConfig, ioredisClient } from '~/cache';
*
* const services = createStreamServicesFromCache({ cacheConfig, ioredisClient });
* GenerationJobManager.configure(services);
* GenerationJobManager.initialize();
* ```
*/
configure(services: {
jobStore: IJobStore;
eventTransport: IEventTransport;
isRedis?: boolean;
cleanupOnComplete?: boolean;
}): void {
if (this.cleanupInterval) {
logger.warn(
'[GenerationJobManager] Reconfiguring after initialization - destroying existing services',
);
this.destroy();
}
this.jobStore = services.jobStore;
this.eventTransport = services.eventTransport;
this._isRedis = services.isRedis ?? false;
this._cleanupOnComplete = services.cleanupOnComplete ?? true;
logger.info(
`[GenerationJobManager] Configured with ${this._isRedis ? 'Redis' : 'in-memory'} stores`,
);
}
/**
* Check if using Redis stores.
*/
get isRedis(): boolean {
return this._isRedis;
}
/**
* Get the job store instance (for advanced use cases).
*/
getJobStore(): IJobStore {
return this.jobStore;
}
/**
* Create a new generation job.
*
* This sets up:
* 1. Serializable job data in the job store
* 2. Runtime state including readyPromise (resolves when first SSE client connects)
* 3. allSubscribersLeft callback for handling client disconnections
*
* The readyPromise mechanism ensures generation doesn't start before the client
* is ready to receive events. The controller awaits this promise (with a short timeout)
* before starting LLM generation.
*
* @param streamId - Unique identifier for this stream
* @param userId - User who initiated the request
* @param conversationId - Optional conversation ID for lookup
* @returns A facade object for the GenerationJob
*/
async createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<t.GenerationJob> {
const jobData = await this.jobStore.createJob(streamId, userId, conversationId);
/**
* Create runtime state with readyPromise.
*
* With the resumable stream architecture, we no longer need to wait for the
* first subscriber before starting generation:
* - Redis mode: Events are persisted and can be replayed via sync
* - In-memory mode: Content is aggregated and sent via sync on connect
*
* We resolve readyPromise immediately to eliminate startup latency.
* The sync mechanism handles late-connecting clients.
*/
let resolveReady: () => void;
const readyPromise = new Promise<void>((resolve) => {
resolveReady = resolve;
});
const runtime: RuntimeJobState = {
abortController: new AbortController(),
readyPromise,
resolveReady: resolveReady!,
syncSent: false,
earlyEventBuffer: [],
hasSubscriber: false,
};
this.runtimeState.set(streamId, runtime);
// Resolve immediately - early event buffer handles late subscribers
resolveReady!();
/**
* Set up all-subscribers-left callback.
* When all SSE clients disconnect, this:
* 1. Resets syncSent so reconnecting clients get sync event
* 2. Calls any registered allSubscribersLeft handlers (e.g., to save partial responses)
*/
this.eventTransport.onAllSubscribersLeft(streamId, () => {
const currentRuntime = this.runtimeState.get(streamId);
if (currentRuntime) {
currentRuntime.syncSent = false;
// Call registered handlers (from job.emitter.on('allSubscribersLeft', ...))
if (currentRuntime.allSubscribersLeftHandlers) {
this.jobStore
.getContentParts(streamId)
.then((result) => {
const parts = result?.content ?? [];
for (const handler of currentRuntime.allSubscribersLeftHandlers ?? []) {
try {
handler(parts);
} catch (err) {
logger.error(`[GenerationJobManager] Error in allSubscribersLeft handler:`, err);
}
}
})
.catch((err) => {
logger.error(
`[GenerationJobManager] Failed to get content parts for allSubscribersLeft handlers:`,
err,
);
});
}
}
});
logger.debug(`[GenerationJobManager] Created job: ${streamId}`);
// Return facade for backwards compatibility
return this.buildJobFacade(streamId, jobData, runtime);
}
/**
* Build a GenerationJob facade from composed services.
*
* This facade provides a unified API (job.emitter, job.abortController, etc.)
* while internally delegating to the injected services (jobStore, eventTransport,
* contentState). This allows swapping implementations (e.g., Redis) without
* changing consumer code.
*
* IMPORTANT: The emitterProxy.on('allSubscribersLeft') handler registration
* does NOT use eventTransport.subscribe(). This is intentional:
*
* If we used subscribe() for internal handlers, those handlers would count
* as subscribers. When the real SSE client connects, isFirstSubscriber()
* would return false (because internal handler was "first"), and readyPromise
* would never resolve - causing a 5-second timeout delay before generation starts.
*
* Instead, allSubscribersLeft handlers are stored in runtime.allSubscribersLeftHandlers
* and called directly from the onAllSubscribersLeft callback in createJob().
*
* @param streamId - The stream identifier
* @param jobData - Serializable job metadata from job store
* @param runtime - Non-serializable runtime state (abort controller, promises, etc.)
* @returns A GenerationJob facade object
*/
private buildJobFacade(
streamId: string,
jobData: SerializableJobData,
runtime: RuntimeJobState,
): t.GenerationJob {
/**
* Proxy emitter that delegates to eventTransport for most operations.
* Exception: allSubscribersLeft handlers are stored separately to avoid
* incrementing subscriber count (see class JSDoc above).
*/
const emitterProxy = {
on: (event: string, handler: (...args: unknown[]) => void) => {
if (event === 'allSubscribersLeft') {
// Store handler for internal callback - don't use subscribe() to avoid counting as a subscriber
if (!runtime.allSubscribersLeftHandlers) {
runtime.allSubscribersLeftHandlers = [];
}
runtime.allSubscribersLeftHandlers.push(handler);
}
},
emit: () => {
/* handled via eventTransport */
},
listenerCount: () => this.eventTransport.getSubscriberCount(streamId),
setMaxListeners: () => {
/* no-op for proxy */
},
removeAllListeners: () => this.eventTransport.cleanup(streamId),
off: () => {
/* handled via unsubscribe */
},
};
return {
streamId,
emitter: emitterProxy as unknown as t.GenerationJob['emitter'],
status: jobData.status as t.GenerationJobStatus,
createdAt: jobData.createdAt,
completedAt: jobData.completedAt,
abortController: runtime.abortController,
error: jobData.error,
metadata: {
userId: jobData.userId,
conversationId: jobData.conversationId,
userMessage: jobData.userMessage,
responseMessageId: jobData.responseMessageId,
sender: jobData.sender,
},
readyPromise: runtime.readyPromise,
resolveReady: runtime.resolveReady,
finalEvent: runtime.finalEvent,
syncSent: runtime.syncSent,
};
}
/**
* Get a job by streamId.
*/
async getJob(streamId: string): Promise<t.GenerationJob | undefined> {
const jobData = await this.jobStore.getJob(streamId);
const runtime = this.runtimeState.get(streamId);
if (!jobData || !runtime) {
return undefined;
}
return this.buildJobFacade(streamId, jobData, runtime);
}
/**
* Check if a job exists.
*/
async hasJob(streamId: string): Promise<boolean> {
return this.jobStore.hasJob(streamId);
}
/**
* Get job status.
*/
async getJobStatus(streamId: string): Promise<t.GenerationJobStatus | undefined> {
const jobData = await this.jobStore.getJob(streamId);
return jobData?.status as t.GenerationJobStatus | undefined;
}
/**
* Mark job as complete.
* If cleanupOnComplete is true (default), immediately cleans up job resources.
* Note: eventTransport is NOT cleaned up here to allow the final event to be
* fully transmitted. It will be cleaned up when subscribers disconnect or
* by the periodic cleanup job.
*/
async completeJob(streamId: string, error?: string): Promise<void> {
const runtime = this.runtimeState.get(streamId);
// Abort the controller to signal all pending operations (e.g., OAuth flow monitors)
// that the job is done and they should clean up
if (runtime) {
runtime.abortController.abort();
}
// Clear content state and run step buffer (Redis only)
this.jobStore.clearContentState(streamId);
this.runStepBuffers?.delete(streamId);
// Immediate cleanup if configured (default: true)
if (this._cleanupOnComplete) {
this.runtimeState.delete(streamId);
// Don't cleanup eventTransport here - let the done event fully transmit first.
// EventTransport will be cleaned up when subscribers disconnect or by periodic cleanup.
await this.jobStore.deleteJob(streamId);
} else {
// Only update status if keeping the job around
await this.jobStore.updateJob(streamId, {
status: error ? 'error' : 'complete',
completedAt: Date.now(),
error,
});
}
logger.debug(`[GenerationJobManager] Job completed: ${streamId}`);
}
/**
* Abort a job (user-initiated).
* Returns all data needed for token spending and message saving.
*/
async abortJob(streamId: string): Promise<AbortResult> {
const jobData = await this.jobStore.getJob(streamId);
const runtime = this.runtimeState.get(streamId);
if (!jobData) {
logger.warn(`[GenerationJobManager] Cannot abort - job not found: ${streamId}`);
return { success: false, jobData: null, content: [], finalEvent: null };
}
if (runtime) {
runtime.abortController.abort();
}
// Get content before clearing state
const result = await this.jobStore.getContentParts(streamId);
const content = result?.content ?? [];
// Detect "early abort" - aborted before any generation happened (e.g., during tool loading)
// In this case, no messages were saved to DB, so frontend shouldn't navigate to conversation
const isEarlyAbort = content.length === 0 && !jobData.responseMessageId;
// Create final event for abort
const userMessageId = jobData.userMessage?.messageId;
const abortFinalEvent: t.ServerSentEvent = {
final: true,
// Don't include conversation for early aborts - it doesn't exist in DB
conversation: isEarlyAbort ? null : { conversationId: jobData.conversationId },
title: 'New Chat',
requestMessage: jobData.userMessage
? {
messageId: userMessageId,
parentMessageId: jobData.userMessage.parentMessageId,
conversationId: jobData.conversationId,
text: jobData.userMessage.text ?? '',
isCreatedByUser: true,
}
: null,
responseMessage: isEarlyAbort
? null
: {
messageId: jobData.responseMessageId ?? `${userMessageId ?? 'aborted'}_`,
parentMessageId: userMessageId,
conversationId: jobData.conversationId,
content,
sender: jobData.sender ?? 'AI',
unfinished: true,
error: false,
isCreatedByUser: false,
},
aborted: true,
// Flag for early abort - no messages saved, frontend should go to new chat
earlyAbort: isEarlyAbort,
} as unknown as t.ServerSentEvent;
if (runtime) {
runtime.finalEvent = abortFinalEvent;
}
this.eventTransport.emitDone(streamId, abortFinalEvent);
this.jobStore.clearContentState(streamId);
this.runStepBuffers?.delete(streamId);
// Immediate cleanup if configured (default: true)
if (this._cleanupOnComplete) {
this.runtimeState.delete(streamId);
// Don't cleanup eventTransport here - let the abort event fully transmit first.
await this.jobStore.deleteJob(streamId);
} else {
// Only update status if keeping the job around
await this.jobStore.updateJob(streamId, {
status: 'aborted',
completedAt: Date.now(),
});
}
logger.debug(`[GenerationJobManager] Job aborted: ${streamId}`);
return {
success: true,
jobData,
content,
finalEvent: abortFinalEvent,
};
}
/**
* Subscribe to a job's event stream.
*
* This is called when an SSE client connects to /chat/stream/:streamId.
* On first subscription:
* - Resolves readyPromise (legacy, for API compatibility)
* - Replays any buffered early events (e.g., 'created' event)
*
* @param streamId - The stream to subscribe to
* @param onChunk - Handler for chunk events (streamed tokens, run steps, etc.)
* @param onDone - Handler for completion event (includes final message)
* @param onError - Handler for error events
* @returns Subscription object with unsubscribe function, or null if job not found
*/
async subscribe(
streamId: string,
onChunk: t.ChunkHandler,
onDone?: t.DoneHandler,
onError?: t.ErrorHandler,
): Promise<{ unsubscribe: t.UnsubscribeFn } | null> {
const runtime = this.runtimeState.get(streamId);
if (!runtime) {
return null;
}
const jobData = await this.jobStore.getJob(streamId);
// If job already complete, send final event
setImmediate(() => {
if (
runtime.finalEvent &&
jobData &&
['complete', 'error', 'aborted'].includes(jobData.status)
) {
onDone?.(runtime.finalEvent);
}
});
const subscription = this.eventTransport.subscribe(streamId, {
onChunk: (event) => {
const e = event as t.ServerSentEvent;
// Filter out internal events
if (!(e as Record<string, unknown>)._internal) {
onChunk(e);
}
},
onDone: (event) => onDone?.(event as t.ServerSentEvent),
onError,
});
// Check if this is the first subscriber
const isFirst = this.eventTransport.isFirstSubscriber(streamId);
// First subscriber: replay buffered events and mark as connected
if (!runtime.hasSubscriber) {
runtime.hasSubscriber = true;
// Replay any events that were emitted before subscriber connected
if (runtime.earlyEventBuffer.length > 0) {
logger.debug(
`[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`,
);
for (const bufferedEvent of runtime.earlyEventBuffer) {
onChunk(bufferedEvent);
}
// Clear buffer after replay
runtime.earlyEventBuffer = [];
}
}
if (isFirst) {
runtime.resolveReady();
logger.debug(
`[GenerationJobManager] First subscriber ready, resolving promise for ${streamId}`,
);
}
return subscription;
}
/**
* Emit a chunk event to all subscribers.
* Uses runtime state check for performance (avoids async job store lookup per token).
*
* If no subscriber has connected yet, buffers the event for replay when they do.
* This ensures early events (like 'created') aren't lost due to race conditions.
*/
emitChunk(streamId: string, event: t.ServerSentEvent): void {
const runtime = this.runtimeState.get(streamId);
if (!runtime || runtime.abortController.signal.aborted) {
return;
}
// Track user message from created event
this.trackUserMessage(streamId, event);
// For Redis mode, persist chunk for later reconstruction
if (this._isRedis) {
// The SSE event structure is { event: string, data: unknown, ... }
// The aggregator expects { event: string, data: unknown } where data is the payload
const eventObj = event as Record<string, unknown>;
const eventType = eventObj.event as string | undefined;
const eventData = eventObj.data;
if (eventType && eventData !== undefined) {
// Store in format expected by aggregateContent: { event, data }
this.jobStore.appendChunk(streamId, { event: eventType, data: eventData }).catch((err) => {
logger.error(`[GenerationJobManager] Failed to append chunk:`, err);
});
// For run step events, also save to run steps key for quick retrieval
if (eventType === 'on_run_step' || eventType === 'on_run_step_completed') {
this.saveRunStepFromEvent(streamId, eventData as Record<string, unknown>);
}
}
}
// Buffer early events if no subscriber yet (replay when first subscriber connects)
if (!runtime.hasSubscriber) {
runtime.earlyEventBuffer.push(event);
// Also emit to transport in case subscriber connects mid-flight
}
this.eventTransport.emitChunk(streamId, event);
}
/**
* Extract and save run step from event data.
* The data is already the run step object from the event payload.
*/
private saveRunStepFromEvent(streamId: string, data: Record<string, unknown>): void {
// The data IS the run step object
const runStep = data as Agents.RunStep;
if (!runStep.id) {
return;
}
// Fire and forget - accumulate run steps
this.accumulateRunStep(streamId, runStep);
}
/**
* Accumulate run steps for a stream (Redis mode only).
* Uses a simple in-memory buffer that gets flushed to Redis.
* Not used in in-memory mode - run steps come from live graph via WeakRef.
*/
private runStepBuffers: Map<string, Agents.RunStep[]> | null = null;
private accumulateRunStep(streamId: string, runStep: Agents.RunStep): void {
// Lazy initialization - only create map when first used (Redis mode)
if (!this.runStepBuffers) {
this.runStepBuffers = new Map();
}
let buffer = this.runStepBuffers.get(streamId);
if (!buffer) {
buffer = [];
this.runStepBuffers.set(streamId, buffer);
}
// Update or add run step
const existingIdx = buffer.findIndex((rs) => rs.id === runStep.id);
if (existingIdx >= 0) {
buffer[existingIdx] = runStep;
} else {
buffer.push(runStep);
}
// Save to Redis
if (this.jobStore.saveRunSteps) {
this.jobStore.saveRunSteps(streamId, buffer).catch((err) => {
logger.error(`[GenerationJobManager] Failed to save run steps:`, err);
});
}
}
/**
* Track user message from created event.
*/
private trackUserMessage(streamId: string, event: t.ServerSentEvent): void {
const data = event as Record<string, unknown>;
if (!data.created || !data.message) {
return;
}
const message = data.message as Record<string, unknown>;
const updates: Partial<SerializableJobData> = {
userMessage: {
messageId: message.messageId as string,
parentMessageId: message.parentMessageId as string | undefined,
conversationId: message.conversationId as string | undefined,
text: message.text as string | undefined,
},
};
if (message.conversationId) {
updates.conversationId = message.conversationId as string;
}
this.jobStore.updateJob(streamId, updates);
}
/**
* Update job metadata.
*/
async updateMetadata(
streamId: string,
metadata: Partial<t.GenerationJobMetadata>,
): Promise<void> {
const updates: Partial<SerializableJobData> = {};
if (metadata.responseMessageId) {
updates.responseMessageId = metadata.responseMessageId;
}
if (metadata.sender) {
updates.sender = metadata.sender;
}
if (metadata.conversationId) {
updates.conversationId = metadata.conversationId;
}
if (metadata.userMessage) {
updates.userMessage = metadata.userMessage;
}
if (metadata.endpoint) {
updates.endpoint = metadata.endpoint;
}
if (metadata.iconURL) {
updates.iconURL = metadata.iconURL;
}
if (metadata.model) {
updates.model = metadata.model;
}
if (metadata.promptTokens !== undefined) {
updates.promptTokens = metadata.promptTokens;
}
await this.jobStore.updateJob(streamId, updates);
}
/**
* Set reference to the graph's contentParts array.
*/
setContentParts(streamId: string, contentParts: Agents.MessageContentComplex[]): void {
// Use runtime state check for performance (sync check)
if (!this.runtimeState.has(streamId)) {
return;
}
this.jobStore.setContentParts(streamId, contentParts);
}
/**
* Set reference to the graph instance.
*/
setGraph(streamId: string, graph: StandardGraph): void {
// Use runtime state check for performance (sync check)
if (!this.runtimeState.has(streamId)) {
return;
}
this.jobStore.setGraph(streamId, graph);
}
/**
* Get resume state for reconnecting clients.
*/
async getResumeState(streamId: string): Promise<t.ResumeState | null> {
const jobData = await this.jobStore.getJob(streamId);
if (!jobData) {
return null;
}
const result = await this.jobStore.getContentParts(streamId);
const aggregatedContent = result?.content ?? [];
const runSteps = await this.jobStore.getRunSteps(streamId);
logger.debug(`[GenerationJobManager] getResumeState:`, {
streamId,
runStepsLength: runSteps.length,
aggregatedContentLength: aggregatedContent.length,
});
return {
runSteps,
aggregatedContent,
userMessage: jobData.userMessage,
responseMessageId: jobData.responseMessageId,
conversationId: jobData.conversationId,
sender: jobData.sender,
};
}
/**
* Mark that sync has been sent.
*/
markSyncSent(streamId: string): void {
const runtime = this.runtimeState.get(streamId);
if (runtime) {
runtime.syncSent = true;
}
}
/**
* Check if sync has been sent.
*/
wasSyncSent(streamId: string): boolean {
return this.runtimeState.get(streamId)?.syncSent ?? false;
}
/**
* Emit a done event.
*/
emitDone(streamId: string, event: t.ServerSentEvent): void {
const runtime = this.runtimeState.get(streamId);
if (runtime) {
runtime.finalEvent = event;
}
this.eventTransport.emitDone(streamId, event);
}
/**
* Emit an error event.
*/
emitError(streamId: string, error: string): void {
this.eventTransport.emitError(streamId, error);
}
/**
* Cleanup expired jobs.
* Also cleans up any orphaned runtime state, buffers, and event transport entries.
*/
private async cleanup(): Promise<void> {
const count = await this.jobStore.cleanup();
// Cleanup runtime state for deleted jobs
for (const streamId of this.runtimeState.keys()) {
if (!(await this.jobStore.hasJob(streamId))) {
this.runtimeState.delete(streamId);
this.runStepBuffers?.delete(streamId);
this.jobStore.clearContentState(streamId);
this.eventTransport.cleanup(streamId);
}
}
// Also check runStepBuffers for any orphaned entries (Redis mode only)
if (this.runStepBuffers) {
for (const streamId of this.runStepBuffers.keys()) {
if (!(await this.jobStore.hasJob(streamId))) {
this.runStepBuffers.delete(streamId);
}
}
}
// Check eventTransport for orphaned streams (e.g., connections dropped without clean close)
// These are streams that exist in eventTransport but have no corresponding job
for (const streamId of this.eventTransport.getTrackedStreamIds()) {
if (!(await this.jobStore.hasJob(streamId)) && !this.runtimeState.has(streamId)) {
this.eventTransport.cleanup(streamId);
}
}
if (count > 0) {
logger.debug(`[GenerationJobManager] Cleaned up ${count} expired jobs`);
}
}
/**
* Get stream info for status endpoint.
*/
async getStreamInfo(streamId: string): Promise<{
active: boolean;
status: t.GenerationJobStatus;
aggregatedContent?: Agents.MessageContentComplex[];
createdAt: number;
} | null> {
const jobData = await this.jobStore.getJob(streamId);
if (!jobData) {
return null;
}
const result = await this.jobStore.getContentParts(streamId);
const aggregatedContent = result?.content ?? [];
return {
active: jobData.status === 'running',
status: jobData.status as t.GenerationJobStatus,
aggregatedContent,
createdAt: jobData.createdAt,
};
}
/**
* Get total job count.
*/
async getJobCount(): Promise<number> {
return this.jobStore.getJobCount();
}
/**
* Get job count by status.
*/
async getJobCountByStatus(): Promise<Record<t.GenerationJobStatus, number>> {
const [running, complete, error, aborted] = await Promise.all([
this.jobStore.getJobCountByStatus('running'),
this.jobStore.getJobCountByStatus('complete'),
this.jobStore.getJobCountByStatus('error'),
this.jobStore.getJobCountByStatus('aborted'),
]);
return { running, complete, error, aborted };
}
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Performs self-healing cleanup of stale entries.
*
* @param userId - The user ID to query
* @returns Array of conversation IDs with active jobs
*/
async getActiveJobIdsForUser(userId: string): Promise<string[]> {
return this.jobStore.getActiveJobIdsByUser(userId);
}
/**
* Destroy the manager.
* Cleans up all resources including runtime state, buffers, and stores.
*/
async destroy(): Promise<void> {
if (this.cleanupInterval) {
clearInterval(this.cleanupInterval);
this.cleanupInterval = null;
}
await this.jobStore.destroy();
this.eventTransport.destroy();
this.runtimeState.clear();
this.runStepBuffers?.clear();
logger.debug('[GenerationJobManager] Destroyed');
}
}
export const GenerationJobManager = new GenerationJobManagerClass();
export { GenerationJobManagerClass };

View file

@ -0,0 +1,415 @@
import type { Redis, Cluster } from 'ioredis';
/**
* Integration tests for GenerationJobManager.
*
* Tests the job manager with both in-memory and Redis backends
* to ensure consistent behavior across deployment modes.
*
* Run with: USE_REDIS=true npx jest GenerationJobManager.stream_integration
*/
describe('GenerationJobManager Integration Tests', () => {
let originalEnv: NodeJS.ProcessEnv;
let ioredisClient: Redis | Cluster | null = null;
const testPrefix = 'JobManager-Integration-Test';
beforeAll(async () => {
originalEnv = { ...process.env };
// Set up test environment
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
jest.resetModules();
const { ioredisClient: client } = await import('../../cache/redisClients');
ioredisClient = client;
});
afterEach(async () => {
// Clean up module state
jest.resetModules();
// Clean up Redis keys (delete individually for cluster compatibility)
if (ioredisClient) {
try {
const keys = await ioredisClient.keys(`${testPrefix}*`);
const streamKeys = await ioredisClient.keys(`stream:*`);
const allKeys = [...keys, ...streamKeys];
await Promise.all(allKeys.map((key) => ioredisClient!.del(key)));
} catch {
// Ignore cleanup errors
}
}
});
afterAll(async () => {
if (ioredisClient) {
try {
// Use quit() to gracefully close - waits for pending commands
await ioredisClient.quit();
} catch {
// Fall back to disconnect if quit fails
try {
ioredisClient.disconnect();
} catch {
// Ignore
}
}
}
process.env = originalEnv;
});
describe('In-Memory Mode', () => {
test('should create and manage jobs', async () => {
const { GenerationJobManager } = await import('../GenerationJobManager');
const { InMemoryJobStore } = await import('../implementations/InMemoryJobStore');
const { InMemoryEventTransport } = await import('../implementations/InMemoryEventTransport');
// Configure with in-memory
// cleanupOnComplete: false so we can verify completed status
GenerationJobManager.configure({
jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }),
eventTransport: new InMemoryEventTransport(),
isRedis: false,
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
const streamId = `inmem-job-${Date.now()}`;
const userId = 'test-user-1';
// Create job (async)
const job = await GenerationJobManager.createJob(streamId, userId);
expect(job.streamId).toBe(streamId);
expect(job.status).toBe('running');
// Check job exists
const hasJob = await GenerationJobManager.hasJob(streamId);
expect(hasJob).toBe(true);
// Get job
const retrieved = await GenerationJobManager.getJob(streamId);
expect(retrieved?.streamId).toBe(streamId);
// Update job
await GenerationJobManager.updateMetadata(streamId, { sender: 'TestAgent' });
const updated = await GenerationJobManager.getJob(streamId);
expect(updated?.metadata?.sender).toBe('TestAgent');
// Complete job
await GenerationJobManager.completeJob(streamId);
const completed = await GenerationJobManager.getJob(streamId);
expect(completed?.status).toBe('complete');
await GenerationJobManager.destroy();
});
test('should handle event streaming', async () => {
const { GenerationJobManager } = await import('../GenerationJobManager');
const { InMemoryJobStore } = await import('../implementations/InMemoryJobStore');
const { InMemoryEventTransport } = await import('../implementations/InMemoryEventTransport');
GenerationJobManager.configure({
jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }),
eventTransport: new InMemoryEventTransport(),
isRedis: false,
});
await GenerationJobManager.initialize();
const streamId = `inmem-events-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
const receivedChunks: unknown[] = [];
// Subscribe to events (subscribe takes separate args, not an object)
const subscription = await GenerationJobManager.subscribe(streamId, (event) =>
receivedChunks.push(event),
);
const { unsubscribe } = subscription!;
// Wait for first subscriber to be registered
await new Promise((resolve) => setTimeout(resolve, 10));
// Emit chunks (emitChunk takes { event, data } format)
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Hello' },
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: ' world' },
});
// Give time for events to propagate
await new Promise((resolve) => setTimeout(resolve, 50));
// Verify chunks were received
expect(receivedChunks.length).toBeGreaterThan(0);
// Complete the job (this cleans up resources)
await GenerationJobManager.completeJob(streamId);
unsubscribe();
await GenerationJobManager.destroy();
});
});
describe('Redis Mode', () => {
test('should create and manage jobs via Redis', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { GenerationJobManager } = await import('../GenerationJobManager');
const { createStreamServices } = await import('../createStreamServices');
// Create Redis services
const services = createStreamServices({
useRedis: true,
redisClient: ioredisClient,
});
expect(services.isRedis).toBe(true);
GenerationJobManager.configure(services);
await GenerationJobManager.initialize();
const streamId = `redis-job-${Date.now()}`;
const userId = 'test-user-redis';
// Create job (async)
const job = await GenerationJobManager.createJob(streamId, userId);
expect(job.streamId).toBe(streamId);
// Verify in Redis
const hasJob = await GenerationJobManager.hasJob(streamId);
expect(hasJob).toBe(true);
// Update and verify
await GenerationJobManager.updateMetadata(streamId, { sender: 'RedisAgent' });
const updated = await GenerationJobManager.getJob(streamId);
expect(updated?.metadata?.sender).toBe('RedisAgent');
await GenerationJobManager.destroy();
});
test('should persist chunks for cross-instance resume', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { GenerationJobManager } = await import('../GenerationJobManager');
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices({
useRedis: true,
redisClient: ioredisClient,
});
GenerationJobManager.configure(services);
await GenerationJobManager.initialize();
const streamId = `redis-chunks-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
// Emit chunks (these should be persisted to Redis)
// emitChunk takes { event, data } format
GenerationJobManager.emitChunk(streamId, {
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: {
id: 'step-1',
delta: { content: { type: 'text', text: 'Persisted ' } },
},
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: {
id: 'step-1',
delta: { content: { type: 'text', text: 'content' } },
},
});
// Wait for async operations
await new Promise((resolve) => setTimeout(resolve, 100));
// Simulate getting resume state (as if from different instance)
const resumeState = await GenerationJobManager.getResumeState(streamId);
expect(resumeState).not.toBeNull();
expect(resumeState!.aggregatedContent?.length).toBeGreaterThan(0);
await GenerationJobManager.destroy();
});
test('should handle abort and return content', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { GenerationJobManager } = await import('../GenerationJobManager');
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices({
useRedis: true,
redisClient: ioredisClient,
});
GenerationJobManager.configure(services);
await GenerationJobManager.initialize();
const streamId = `redis-abort-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
// Emit some content (emitChunk takes { event, data } format)
GenerationJobManager.emitChunk(streamId, {
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: {
id: 'step-1',
delta: { content: { type: 'text', text: 'Partial response...' } },
},
});
await new Promise((resolve) => setTimeout(resolve, 100));
// Abort the job
const abortResult = await GenerationJobManager.abortJob(streamId);
expect(abortResult.success).toBe(true);
expect(abortResult.content.length).toBeGreaterThan(0);
await GenerationJobManager.destroy();
});
});
describe('Cross-Mode Consistency', () => {
test('should have consistent API between in-memory and Redis modes', async () => {
// This test verifies that the same operations work identically
// regardless of backend mode
const runTestWithMode = async (isRedis: boolean) => {
jest.resetModules();
const { GenerationJobManager } = await import('../GenerationJobManager');
if (isRedis && ioredisClient) {
const { createStreamServices } = await import('../createStreamServices');
GenerationJobManager.configure({
...createStreamServices({
useRedis: true,
redisClient: ioredisClient,
}),
cleanupOnComplete: false, // Keep job for verification
});
} else {
const { InMemoryJobStore } = await import('../implementations/InMemoryJobStore');
const { InMemoryEventTransport } = await import(
'../implementations/InMemoryEventTransport'
);
GenerationJobManager.configure({
jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }),
eventTransport: new InMemoryEventTransport(),
isRedis: false,
cleanupOnComplete: false,
});
}
await GenerationJobManager.initialize();
const streamId = `consistency-${isRedis ? 'redis' : 'inmem'}-${Date.now()}`;
// Test sequence
const job = await GenerationJobManager.createJob(streamId, 'user-1');
expect(job.streamId).toBe(streamId);
expect(job.status).toBe('running');
const hasJob = await GenerationJobManager.hasJob(streamId);
expect(hasJob).toBe(true);
await GenerationJobManager.updateMetadata(streamId, {
sender: 'ConsistencyAgent',
responseMessageId: 'resp-123',
});
const updated = await GenerationJobManager.getJob(streamId);
expect(updated?.metadata?.sender).toBe('ConsistencyAgent');
expect(updated?.metadata?.responseMessageId).toBe('resp-123');
await GenerationJobManager.completeJob(streamId);
const completed = await GenerationJobManager.getJob(streamId);
expect(completed?.status).toBe('complete');
await GenerationJobManager.destroy();
};
// Test in-memory mode
await runTestWithMode(false);
// Test Redis mode if available
if (ioredisClient) {
await runTestWithMode(true);
}
});
});
describe('createStreamServices Auto-Detection', () => {
test('should auto-detect Redis when USE_REDIS is true', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
// Force USE_REDIS to true
process.env.USE_REDIS = 'true';
jest.resetModules();
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices();
// Should detect Redis
expect(services.isRedis).toBe(true);
});
test('should fall back to in-memory when USE_REDIS is false', async () => {
process.env.USE_REDIS = 'false';
jest.resetModules();
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices();
expect(services.isRedis).toBe(false);
});
test('should allow forcing in-memory via config override', async () => {
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices({ useRedis: false });
expect(services.isRedis).toBe(false);
});
});
});

View file

@ -0,0 +1,326 @@
import type { Redis, Cluster } from 'ioredis';
/**
* Integration tests for RedisEventTransport.
*
* Tests Redis Pub/Sub functionality:
* - Cross-instance event delivery
* - Subscriber management
* - Error handling
*
* Run with: USE_REDIS=true npx jest RedisEventTransport.stream_integration
*/
describe('RedisEventTransport Integration Tests', () => {
let originalEnv: NodeJS.ProcessEnv;
let ioredisClient: Redis | Cluster | null = null;
const testPrefix = 'EventTransport-Integration-Test';
beforeAll(async () => {
originalEnv = { ...process.env };
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
jest.resetModules();
const { ioredisClient: client } = await import('../../cache/redisClients');
ioredisClient = client;
});
afterAll(async () => {
if (ioredisClient) {
try {
// Use quit() to gracefully close - waits for pending commands
await ioredisClient.quit();
} catch {
// Fall back to disconnect if quit fails
try {
ioredisClient.disconnect();
} catch {
// Ignore
}
}
}
process.env = originalEnv;
});
describe('Pub/Sub Event Delivery', () => {
test('should deliver events to subscribers on same instance', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
// Create subscriber client (Redis pub/sub requires dedicated connection)
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `pubsub-same-${Date.now()}`;
const receivedChunks: unknown[] = [];
let doneEvent: unknown = null;
// Subscribe
const { unsubscribe } = transport.subscribe(streamId, {
onChunk: (event) => receivedChunks.push(event),
onDone: (event) => {
doneEvent = event;
},
});
// Wait for subscription to be established
await new Promise((resolve) => setTimeout(resolve, 100));
// Emit events
transport.emitChunk(streamId, { type: 'text', text: 'Hello' });
transport.emitChunk(streamId, { type: 'text', text: ' World' });
transport.emitDone(streamId, { finished: true });
// Wait for events to propagate
await new Promise((resolve) => setTimeout(resolve, 200));
expect(receivedChunks.length).toBe(2);
expect(doneEvent).toEqual({ finished: true });
unsubscribe();
transport.destroy();
subscriber.disconnect();
});
test('should deliver events across transport instances (simulating different servers)', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
// Create two separate transport instances (simulating two servers)
const subscriber1 = (ioredisClient as Redis).duplicate();
const subscriber2 = (ioredisClient as Redis).duplicate();
const transport1 = new RedisEventTransport(ioredisClient, subscriber1);
const transport2 = new RedisEventTransport(ioredisClient, subscriber2);
const streamId = `pubsub-cross-${Date.now()}`;
const instance2Chunks: unknown[] = [];
// Subscribe on transport 2 (consumer)
const sub2 = transport2.subscribe(streamId, {
onChunk: (event) => instance2Chunks.push(event),
});
// Wait for subscription
await new Promise((resolve) => setTimeout(resolve, 100));
// Emit from transport 1 (producer on different instance)
transport1.emitChunk(streamId, { data: 'from-instance-1' });
// Wait for cross-instance delivery
await new Promise((resolve) => setTimeout(resolve, 200));
// Transport 2 should receive the event
expect(instance2Chunks.length).toBe(1);
expect(instance2Chunks[0]).toEqual({ data: 'from-instance-1' });
sub2.unsubscribe();
transport1.destroy();
transport2.destroy();
subscriber1.disconnect();
subscriber2.disconnect();
});
test('should handle multiple subscribers to same stream', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `pubsub-multi-${Date.now()}`;
const subscriber1Chunks: unknown[] = [];
const subscriber2Chunks: unknown[] = [];
// Two subscribers
const sub1 = transport.subscribe(streamId, {
onChunk: (event) => subscriber1Chunks.push(event),
});
const sub2 = transport.subscribe(streamId, {
onChunk: (event) => subscriber2Chunks.push(event),
});
await new Promise((resolve) => setTimeout(resolve, 100));
transport.emitChunk(streamId, { data: 'broadcast' });
await new Promise((resolve) => setTimeout(resolve, 200));
// Both should receive
expect(subscriber1Chunks.length).toBe(1);
expect(subscriber2Chunks.length).toBe(1);
sub1.unsubscribe();
sub2.unsubscribe();
transport.destroy();
subscriber.disconnect();
});
});
describe('Subscriber Management', () => {
test('should track first subscriber correctly', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `first-sub-${Date.now()}`;
// Before any subscribers - count is 0, not "first" since no one subscribed
expect(transport.getSubscriberCount(streamId)).toBe(0);
// First subscriber
const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
// Now there's a subscriber - isFirstSubscriber returns true when count is 1
expect(transport.getSubscriberCount(streamId)).toBe(1);
expect(transport.isFirstSubscriber(streamId)).toBe(true);
// Second subscriber - not first anymore
const sub2temp = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
expect(transport.isFirstSubscriber(streamId)).toBe(false);
sub2temp.unsubscribe();
const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
expect(transport.getSubscriberCount(streamId)).toBe(2);
sub1.unsubscribe();
sub2.unsubscribe();
transport.destroy();
subscriber.disconnect();
});
test('should fire onAllSubscribersLeft when last subscriber leaves', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `all-left-${Date.now()}`;
let allLeftCalled = false;
transport.onAllSubscribersLeft(streamId, () => {
allLeftCalled = true;
});
const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
// Unsubscribe first
sub1.unsubscribe();
await new Promise((resolve) => setTimeout(resolve, 50));
// Still have one subscriber
expect(allLeftCalled).toBe(false);
// Unsubscribe last
sub2.unsubscribe();
await new Promise((resolve) => setTimeout(resolve, 50));
// Now all left
expect(allLeftCalled).toBe(true);
transport.destroy();
subscriber.disconnect();
});
});
describe('Error Handling', () => {
test('should deliver error events to subscribers', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `error-${Date.now()}`;
let receivedError: string | null = null;
transport.subscribe(streamId, {
onChunk: () => {},
onError: (err) => {
receivedError = err;
},
});
await new Promise((resolve) => setTimeout(resolve, 100));
transport.emitError(streamId, 'Test error message');
await new Promise((resolve) => setTimeout(resolve, 200));
expect(receivedError).toBe('Test error message');
transport.destroy();
subscriber.disconnect();
});
});
describe('Cleanup', () => {
test('should clean up stream resources', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `cleanup-${Date.now()}`;
transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
expect(transport.getSubscriberCount(streamId)).toBe(1);
// Cleanup the stream
transport.cleanup(streamId);
// Subscriber count should be 0
expect(transport.getSubscriberCount(streamId)).toBe(0);
transport.destroy();
subscriber.disconnect();
});
});
});

View file

@ -0,0 +1,975 @@
import { StepTypes } from 'librechat-data-provider';
import type { Agents } from 'librechat-data-provider';
import type { Redis, Cluster } from 'ioredis';
import { StandardGraph } from '@librechat/agents';
/**
* Integration tests for RedisJobStore.
*
* Tests horizontal scaling scenarios:
* - Multi-instance job access
* - Content reconstruction from chunks
* - Consumer groups for resumable streams
* - TTL and cleanup behavior
*
* Run with: USE_REDIS=true npx jest RedisJobStore.stream_integration
*/
describe('RedisJobStore Integration Tests', () => {
let originalEnv: NodeJS.ProcessEnv;
let ioredisClient: Redis | Cluster | null = null;
const testPrefix = 'Stream-Integration-Test';
beforeAll(async () => {
originalEnv = { ...process.env };
// Set up test environment
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
jest.resetModules();
// Import Redis client
const { ioredisClient: client } = await import('../../cache/redisClients');
ioredisClient = client;
if (!ioredisClient) {
console.warn('Redis not available, skipping integration tests');
}
});
afterEach(async () => {
if (!ioredisClient) {
return;
}
// Clean up all test keys (delete individually for cluster compatibility)
try {
const keys = await ioredisClient.keys(`${testPrefix}*`);
// Also clean up stream keys which use hash tags
const streamKeys = await ioredisClient.keys(`stream:*`);
const allKeys = [...keys, ...streamKeys];
// Delete individually to avoid CROSSSLOT errors in cluster mode
await Promise.all(allKeys.map((key) => ioredisClient!.del(key)));
} catch (error) {
console.warn('Error cleaning up test keys:', error);
}
});
afterAll(async () => {
if (ioredisClient) {
try {
// Use quit() to gracefully close - waits for pending commands
await ioredisClient.quit();
} catch {
// Fall back to disconnect if quit fails
try {
ioredisClient.disconnect();
} catch {
// Ignore
}
}
}
process.env = originalEnv;
});
describe('Job CRUD Operations', () => {
test('should create and retrieve a job', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `test-stream-${Date.now()}`;
const userId = 'test-user-123';
const job = await store.createJob(streamId, userId, streamId);
expect(job).toMatchObject({
streamId,
userId,
status: 'running',
conversationId: streamId,
syncSent: false,
});
const retrieved = await store.getJob(streamId);
expect(retrieved).toMatchObject({
streamId,
userId,
status: 'running',
});
await store.destroy();
});
test('should update job status', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `test-stream-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
await store.updateJob(streamId, { status: 'complete', completedAt: Date.now() });
const job = await store.getJob(streamId);
expect(job?.status).toBe('complete');
expect(job?.completedAt).toBeDefined();
await store.destroy();
});
test('should delete job and related data', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `test-stream-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Add some chunks
await store.appendChunk(streamId, { event: 'on_message_delta', data: { text: 'Hello' } });
await store.deleteJob(streamId);
const job = await store.getJob(streamId);
expect(job).toBeNull();
await store.destroy();
});
});
describe('Horizontal Scaling - Multi-Instance Simulation', () => {
test('should share job state between two store instances', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Simulate two server instances with separate store instances
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const streamId = `multi-instance-${Date.now()}`;
// Instance 1 creates job
await instance1.createJob(streamId, 'user-1', streamId);
// Instance 2 should see the job
const jobFromInstance2 = await instance2.getJob(streamId);
expect(jobFromInstance2).not.toBeNull();
expect(jobFromInstance2?.streamId).toBe(streamId);
// Instance 1 updates job
await instance1.updateJob(streamId, { sender: 'TestAgent', syncSent: true });
// Instance 2 should see the update
const updatedJob = await instance2.getJob(streamId);
expect(updatedJob?.sender).toBe('TestAgent');
expect(updatedJob?.syncSent).toBe(true);
await instance1.destroy();
await instance2.destroy();
});
test('should share chunks between instances for content reconstruction', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const streamId = `chunk-sharing-${Date.now()}`;
await instance1.createJob(streamId, 'user-1', streamId);
// Instance 1 emits chunks (simulating stream generation)
// Format must match what aggregateContent expects:
// - on_run_step: { id, index, stepDetails: { type } }
// - on_message_delta: { id, delta: { content: { type, text } } }
const chunks = [
{
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'Hello, ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'world!' } } },
},
];
for (const chunk of chunks) {
await instance1.appendChunk(streamId, chunk);
}
// Instance 2 reconstructs content (simulating reconnect to different instance)
const result = await instance2.getContentParts(streamId);
// Should have reconstructed content
expect(result).not.toBeNull();
expect(result!.content.length).toBeGreaterThan(0);
await instance1.destroy();
await instance2.destroy();
});
test('should share run steps between instances', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const streamId = `runsteps-sharing-${Date.now()}`;
await instance1.createJob(streamId, 'user-1', streamId);
// Instance 1 saves run steps
const runSteps: Partial<Agents.RunStep>[] = [
{ id: 'step-1', runId: 'run-1', type: StepTypes.MESSAGE_CREATION, index: 0 },
{ id: 'step-2', runId: 'run-1', type: StepTypes.TOOL_CALLS, index: 1 },
];
await instance1.saveRunSteps!(streamId, runSteps as Agents.RunStep[]);
// Instance 2 retrieves run steps
const retrievedSteps = await instance2.getRunSteps(streamId);
expect(retrievedSteps).toHaveLength(2);
expect(retrievedSteps[0].id).toBe('step-1');
expect(retrievedSteps[1].id).toBe('step-2');
await instance1.destroy();
await instance2.destroy();
});
});
describe('Content Reconstruction', () => {
test('should reconstruct text content from message deltas', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `text-reconstruction-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Simulate a streaming response with correct event format
const chunks = [
{
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'The ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'quick ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'brown ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'fox.' } } },
},
];
for (const chunk of chunks) {
await store.appendChunk(streamId, chunk);
}
const result = await store.getContentParts(streamId);
expect(result).not.toBeNull();
// Content aggregator combines text deltas
const textPart = result!.content.find((p) => p.type === 'text');
expect(textPart).toBeDefined();
await store.destroy();
});
test('should reconstruct thinking content from reasoning deltas', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `think-reconstruction-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// on_reasoning_delta events need id and delta.content format
const chunks = [
{
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_reasoning_delta',
data: { id: 'step-1', delta: { content: { type: 'think', think: 'Let me think...' } } },
},
{
event: 'on_reasoning_delta',
data: {
id: 'step-1',
delta: { content: { type: 'think', think: ' about this problem.' } },
},
},
{
event: 'on_run_step',
data: {
id: 'step-2',
runId: 'run-1',
index: 1,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_message_delta',
data: { id: 'step-2', delta: { content: { type: 'text', text: 'The answer is 42.' } } },
},
];
for (const chunk of chunks) {
await store.appendChunk(streamId, chunk);
}
const result = await store.getContentParts(streamId);
expect(result).not.toBeNull();
// Should have both think and text parts
const thinkPart = result!.content.find((p) => p.type === 'think');
const textPart = result!.content.find((p) => p.type === 'text');
expect(thinkPart).toBeDefined();
expect(textPart).toBeDefined();
await store.destroy();
});
test('should return null for empty chunks', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `empty-chunks-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// No chunks appended
const content = await store.getContentParts(streamId);
expect(content).toBeNull();
await store.destroy();
});
});
describe('Consumer Groups', () => {
test('should create consumer group and read chunks', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `consumer-group-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Add some chunks
const chunks = [
{ event: 'on_message_delta', data: { type: 'text', text: 'Chunk 1' } },
{ event: 'on_message_delta', data: { type: 'text', text: 'Chunk 2' } },
{ event: 'on_message_delta', data: { type: 'text', text: 'Chunk 3' } },
];
for (const chunk of chunks) {
await store.appendChunk(streamId, chunk);
}
// Wait for Redis to sync
await new Promise((resolve) => setTimeout(resolve, 50));
// Create consumer group starting from beginning
const groupName = `client-${Date.now()}`;
await store.createConsumerGroup(streamId, groupName, '0');
// Read chunks from group
// Note: With '0' as lastId, we need to use getPendingChunks or read with '0' instead of '>'
// The '>' only gives new messages after group creation
const readChunks = await store.getPendingChunks(streamId, groupName, 'consumer-1');
// If pending is empty, the messages haven't been delivered yet
// Let's read from '0' using regular read
if (readChunks.length === 0) {
// Consumer groups created at '0' should have access to all messages
// but they need to be "claimed" first. Skip this test as consumer groups
// require more complex setup for historical messages.
console.log(
'Skipping consumer group test - requires claim mechanism for historical messages',
);
await store.deleteConsumerGroup(streamId, groupName);
await store.destroy();
return;
}
expect(readChunks.length).toBe(3);
// Acknowledge chunks
const ids = readChunks.map((c) => c.id);
await store.acknowledgeChunks(streamId, groupName, ids);
// Reading again should return empty (all acknowledged)
const moreChunks = await store.readChunksFromGroup(streamId, groupName, 'consumer-1');
expect(moreChunks.length).toBe(0);
// Cleanup
await store.deleteConsumerGroup(streamId, groupName);
await store.destroy();
});
// TODO: Debug consumer group timing with Redis Streams
test.skip('should resume from where client left off', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `resume-test-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Create consumer group FIRST (before adding chunks) to track delivery
const groupName = `client-resume-${Date.now()}`;
await store.createConsumerGroup(streamId, groupName, '$'); // Start from end (only new messages)
// Add initial chunks (these will be "new" to the consumer group)
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 1' },
});
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 2' },
});
// Wait for Redis to sync
await new Promise((resolve) => setTimeout(resolve, 50));
// Client reads first batch
const firstRead = await store.readChunksFromGroup(streamId, groupName, 'consumer-1');
expect(firstRead.length).toBe(2);
// ACK the chunks
await store.acknowledgeChunks(
streamId,
groupName,
firstRead.map((c) => c.id),
);
// More chunks arrive while client is away
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 3' },
});
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 4' },
});
// Wait for Redis to sync
await new Promise((resolve) => setTimeout(resolve, 50));
// Client reconnects - should only get new chunks
const secondRead = await store.readChunksFromGroup(streamId, groupName, 'consumer-1');
expect(secondRead.length).toBe(2);
await store.deleteConsumerGroup(streamId, groupName);
await store.destroy();
});
});
describe('TTL and Cleanup', () => {
test('should set running TTL on chunk stream', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient, { runningTtl: 60 });
await store.initialize();
const streamId = `ttl-test-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { id: 'step-1', type: 'text', text: 'test' },
});
// Check that TTL was set on the stream key
// Note: ioredis client has keyPrefix, so we use the key WITHOUT the prefix
// Key uses hash tag format: stream:{streamId}:chunks
const ttl = await ioredisClient.ttl(`stream:{${streamId}}:chunks`);
expect(ttl).toBeGreaterThan(0);
expect(ttl).toBeLessThanOrEqual(60);
await store.destroy();
});
test('should clean up stale jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Very short TTL for testing
const store = new RedisJobStore(ioredisClient, { runningTtl: 1 });
await store.initialize();
const streamId = `stale-job-${Date.now()}`;
// Manually create a job that looks old
// Note: ioredis client has keyPrefix, so we use the key WITHOUT the prefix
// Key uses hash tag format: stream:{streamId}:job
const jobKey = `stream:{${streamId}}:job`;
const veryOldTimestamp = Date.now() - 10000; // 10 seconds ago
await ioredisClient.hmset(jobKey, {
streamId,
userId: 'user-1',
status: 'running',
createdAt: veryOldTimestamp.toString(),
syncSent: '0',
});
await ioredisClient.sadd(`stream:running`, streamId);
// Run cleanup
const cleaned = await store.cleanup();
// Should have cleaned the stale job
expect(cleaned).toBeGreaterThanOrEqual(1);
await store.destroy();
});
});
describe('Active Jobs by User', () => {
test('should return active job IDs for a user', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId1 = `stream-1-${Date.now()}`;
const streamId2 = `stream-2-${Date.now()}`;
// Create two jobs for the same user
await store.createJob(streamId1, userId, streamId1);
await store.createJob(streamId2, userId, streamId2);
// Get active jobs for user
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(2);
expect(activeJobs).toContain(streamId1);
expect(activeJobs).toContain(streamId2);
await store.destroy();
});
test('should return empty array for user with no jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `nonexistent-user-${Date.now()}`;
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(0);
await store.destroy();
});
test('should not return completed jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId1 = `stream-1-${Date.now()}`;
const streamId2 = `stream-2-${Date.now()}`;
// Create two jobs
await store.createJob(streamId1, userId, streamId1);
await store.createJob(streamId2, userId, streamId2);
// Complete one job
await store.updateJob(streamId1, { status: 'complete', completedAt: Date.now() });
// Get active jobs - should only return the running one
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(1);
expect(activeJobs).toContain(streamId2);
expect(activeJobs).not.toContain(streamId1);
await store.destroy();
});
test('should not return aborted jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Create a job and abort it
await store.createJob(streamId, userId, streamId);
await store.updateJob(streamId, { status: 'aborted', completedAt: Date.now() });
// Get active jobs - should be empty
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(0);
await store.destroy();
});
test('should not return error jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Create a job with error status
await store.createJob(streamId, userId, streamId);
await store.updateJob(streamId, {
status: 'error',
error: 'Test error',
completedAt: Date.now(),
});
// Get active jobs - should be empty
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(0);
await store.destroy();
});
test('should perform self-healing cleanup of stale entries', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
const staleStreamId = `stale-stream-${Date.now()}`;
// Create a real job
await store.createJob(streamId, userId, streamId);
// Manually add a stale entry to the user's job set (simulating orphaned data)
const userJobsKey = `stream:user:{${userId}}:jobs`;
await ioredisClient.sadd(userJobsKey, staleStreamId);
// Verify both entries exist in the set
const beforeCleanup = await ioredisClient.smembers(userJobsKey);
expect(beforeCleanup).toContain(streamId);
expect(beforeCleanup).toContain(staleStreamId);
// Get active jobs - should trigger self-healing
const activeJobs = await store.getActiveJobIdsByUser(userId);
// Should only return the real job
expect(activeJobs).toHaveLength(1);
expect(activeJobs).toContain(streamId);
// Verify stale entry was removed
const afterCleanup = await ioredisClient.smembers(userJobsKey);
expect(afterCleanup).toContain(streamId);
expect(afterCleanup).not.toContain(staleStreamId);
await store.destroy();
});
test('should isolate jobs between different users', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId1 = `user-1-${Date.now()}`;
const userId2 = `user-2-${Date.now()}`;
const streamId1 = `stream-1-${Date.now()}`;
const streamId2 = `stream-2-${Date.now()}`;
// Create jobs for different users
await store.createJob(streamId1, userId1, streamId1);
await store.createJob(streamId2, userId2, streamId2);
// Get active jobs for user 1
const user1Jobs = await store.getActiveJobIdsByUser(userId1);
expect(user1Jobs).toHaveLength(1);
expect(user1Jobs).toContain(streamId1);
expect(user1Jobs).not.toContain(streamId2);
// Get active jobs for user 2
const user2Jobs = await store.getActiveJobIdsByUser(userId2);
expect(user2Jobs).toHaveLength(1);
expect(user2Jobs).toContain(streamId2);
expect(user2Jobs).not.toContain(streamId1);
await store.destroy();
});
test('should work across multiple store instances (horizontal scaling)', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Simulate two server instances
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Instance 1 creates a job
await instance1.createJob(streamId, userId, streamId);
// Instance 2 should see the active job
const activeJobs = await instance2.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(1);
expect(activeJobs).toContain(streamId);
// Instance 1 completes the job
await instance1.updateJob(streamId, { status: 'complete', completedAt: Date.now() });
// Instance 2 should no longer see the job as active
const activeJobsAfter = await instance2.getActiveJobIdsByUser(userId);
expect(activeJobsAfter).toHaveLength(0);
await instance1.destroy();
await instance2.destroy();
});
test('should clean up user jobs set when job is deleted', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Create a job
await store.createJob(streamId, userId, streamId);
// Verify job is in active list
let activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toContain(streamId);
// Delete the job
await store.deleteJob(streamId);
// Job should no longer be in active list
activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).not.toContain(streamId);
await store.destroy();
});
});
describe('Local Graph Cache Optimization', () => {
test('should use local cache when available', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `local-cache-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Create a mock graph
const mockContentParts = [{ type: 'text', text: 'From local cache' }];
const mockRunSteps = [{ id: 'step-1', type: 'message_creation', status: 'completed' }];
const mockGraph = {
getContentParts: () => mockContentParts,
getRunSteps: () => mockRunSteps,
};
// Set graph reference (will be cached locally)
store.setGraph(streamId, mockGraph as unknown as StandardGraph);
// Get content - should come from local cache, not Redis
const result = await store.getContentParts(streamId);
expect(result!.content).toEqual(mockContentParts);
// Get run steps - should come from local cache
const runSteps = await store.getRunSteps(streamId);
expect(runSteps).toEqual(mockRunSteps);
await store.destroy();
});
test('should fall back to Redis when local cache not available', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Instance 1 creates and populates data
const instance1 = new RedisJobStore(ioredisClient);
await instance1.initialize();
const streamId = `fallback-test-${Date.now()}`;
await instance1.createJob(streamId, 'user-1', streamId);
// Add chunks to Redis with correct format
await instance1.appendChunk(streamId, {
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
});
await instance1.appendChunk(streamId, {
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'From Redis' } } },
});
// Save run steps to Redis
await instance1.saveRunSteps!(streamId, [
{
id: 'step-1',
runId: 'run-1',
type: StepTypes.MESSAGE_CREATION,
index: 0,
} as unknown as Agents.RunStep,
]);
// Instance 2 has NO local cache - should fall back to Redis
const instance2 = new RedisJobStore(ioredisClient);
await instance2.initialize();
// Get content - should reconstruct from Redis chunks
const result = await instance2.getContentParts(streamId);
expect(result).not.toBeNull();
expect(result!.content.length).toBeGreaterThan(0);
// Get run steps - should fetch from Redis
const runSteps = await instance2.getRunSteps(streamId);
expect(runSteps).toHaveLength(1);
expect(runSteps[0].id).toBe('step-1');
await instance1.destroy();
await instance2.destroy();
});
});
});

View file

@ -0,0 +1,133 @@
import type { Redis, Cluster } from 'ioredis';
import { logger } from '@librechat/data-schemas';
import type { IJobStore, IEventTransport } from './interfaces/IJobStore';
import { InMemoryJobStore } from './implementations/InMemoryJobStore';
import { InMemoryEventTransport } from './implementations/InMemoryEventTransport';
import { RedisJobStore } from './implementations/RedisJobStore';
import { RedisEventTransport } from './implementations/RedisEventTransport';
import { cacheConfig } from '~/cache/cacheConfig';
import { ioredisClient } from '~/cache/redisClients';
/**
* Configuration for stream services (optional overrides)
*/
export interface StreamServicesConfig {
/**
* Override Redis detection. If not provided, uses cacheConfig.USE_REDIS.
*/
useRedis?: boolean;
/**
* Override Redis client. If not provided, uses ioredisClient from cache.
*/
redisClient?: Redis | Cluster | null;
/**
* Dedicated Redis client for pub/sub subscribing.
* If not provided, will duplicate the main client.
*/
redisSubscriber?: Redis | Cluster | null;
/**
* Options for in-memory job store
*/
inMemoryOptions?: {
ttlAfterComplete?: number;
maxJobs?: number;
};
}
/**
* Stream services result
*/
export interface StreamServices {
jobStore: IJobStore;
eventTransport: IEventTransport;
isRedis: boolean;
}
/**
* Create stream services (job store + event transport).
*
* Automatically detects Redis from cacheConfig.USE_REDIS_STREAMS and uses
* the existing ioredisClient. Falls back to in-memory if Redis
* is not configured or not available.
*
* USE_REDIS_STREAMS defaults to USE_REDIS if not explicitly set,
* allowing users to disable Redis for streams while keeping it for other caches.
*
* @example Auto-detect (uses cacheConfig)
* ```ts
* const services = createStreamServices();
* // Uses Redis if USE_REDIS_STREAMS=true (defaults to USE_REDIS), otherwise in-memory
* ```
*
* @example Force in-memory
* ```ts
* const services = createStreamServices({ useRedis: false });
* ```
*/
export function createStreamServices(config: StreamServicesConfig = {}): StreamServices {
// Use provided config or fall back to cache config (USE_REDIS_STREAMS for stream-specific override)
const useRedis = config.useRedis ?? cacheConfig.USE_REDIS_STREAMS;
const redisClient = config.redisClient ?? ioredisClient;
const { redisSubscriber, inMemoryOptions } = config;
// Check if we should and can use Redis
if (useRedis && redisClient) {
try {
// For subscribing, we need a dedicated connection
// If subscriber not provided, duplicate the main client
let subscriber = redisSubscriber;
if (!subscriber && 'duplicate' in redisClient) {
subscriber = (redisClient as Redis).duplicate();
logger.info('[StreamServices] Duplicated Redis client for subscriber');
}
if (!subscriber) {
logger.warn('[StreamServices] No subscriber client available, falling back to in-memory');
return createInMemoryServices(inMemoryOptions);
}
const jobStore = new RedisJobStore(redisClient);
const eventTransport = new RedisEventTransport(redisClient, subscriber);
logger.info('[StreamServices] Created Redis-backed stream services');
return {
jobStore,
eventTransport,
isRedis: true,
};
} catch (err) {
logger.error(
'[StreamServices] Failed to create Redis services, falling back to in-memory:',
err,
);
return createInMemoryServices(inMemoryOptions);
}
}
return createInMemoryServices(inMemoryOptions);
}
/**
* Create in-memory stream services
*/
function createInMemoryServices(options?: StreamServicesConfig['inMemoryOptions']): StreamServices {
const jobStore = new InMemoryJobStore({
ttlAfterComplete: options?.ttlAfterComplete ?? 300000, // 5 minutes
maxJobs: options?.maxJobs ?? 1000,
});
const eventTransport = new InMemoryEventTransport();
logger.info('[StreamServices] Created in-memory stream services');
return {
jobStore,
eventTransport,
isRedis: false,
};
}

View file

@ -0,0 +1,137 @@
import { EventEmitter } from 'events';
import { logger } from '@librechat/data-schemas';
import type { IEventTransport } from '../interfaces/IJobStore';
interface StreamState {
emitter: EventEmitter;
allSubscribersLeftCallback?: () => void;
}
/**
* In-memory event transport using Node.js EventEmitter.
* For horizontal scaling, replace with RedisEventTransport.
*/
export class InMemoryEventTransport implements IEventTransport {
private streams = new Map<string, StreamState>();
private getOrCreateStream(streamId: string): StreamState {
let state = this.streams.get(streamId);
if (!state) {
const emitter = new EventEmitter();
emitter.setMaxListeners(100);
state = { emitter };
this.streams.set(streamId, state);
}
return state;
}
subscribe(
streamId: string,
handlers: {
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void } {
const state = this.getOrCreateStream(streamId);
const chunkHandler = (event: unknown) => handlers.onChunk(event);
const doneHandler = (event: unknown) => handlers.onDone?.(event);
const errorHandler = (error: string) => handlers.onError?.(error);
state.emitter.on('chunk', chunkHandler);
state.emitter.on('done', doneHandler);
state.emitter.on('error', errorHandler);
logger.debug(
`[InMemoryEventTransport] subscribe ${streamId}: listeners=${state.emitter.listenerCount('chunk')}`,
);
return {
unsubscribe: () => {
const currentState = this.streams.get(streamId);
if (currentState) {
currentState.emitter.off('chunk', chunkHandler);
currentState.emitter.off('done', doneHandler);
currentState.emitter.off('error', errorHandler);
// Check if all subscribers left - cleanup and notify
if (currentState.emitter.listenerCount('chunk') === 0) {
currentState.allSubscribersLeftCallback?.();
// Auto-cleanup the stream entry when no subscribers remain
currentState.emitter.removeAllListeners();
this.streams.delete(streamId);
}
}
},
};
}
emitChunk(streamId: string, event: unknown): void {
const state = this.streams.get(streamId);
state?.emitter.emit('chunk', event);
}
emitDone(streamId: string, event: unknown): void {
const state = this.streams.get(streamId);
state?.emitter.emit('done', event);
}
emitError(streamId: string, error: string): void {
const state = this.streams.get(streamId);
state?.emitter.emit('error', error);
}
getSubscriberCount(streamId: string): number {
const state = this.streams.get(streamId);
return state?.emitter.listenerCount('chunk') ?? 0;
}
onAllSubscribersLeft(streamId: string, callback: () => void): void {
const state = this.getOrCreateStream(streamId);
state.allSubscribersLeftCallback = callback;
}
/**
* Check if this is the first subscriber (for ready signaling)
*/
isFirstSubscriber(streamId: string): boolean {
const state = this.streams.get(streamId);
const count = state?.emitter.listenerCount('chunk') ?? 0;
logger.debug(`[InMemoryEventTransport] isFirstSubscriber ${streamId}: count=${count}`);
return count === 1;
}
/**
* Cleanup a stream's event emitter
*/
cleanup(streamId: string): void {
const state = this.streams.get(streamId);
if (state) {
state.emitter.removeAllListeners();
this.streams.delete(streamId);
}
}
/**
* Get count of tracked streams (for monitoring)
*/
getStreamCount(): number {
return this.streams.size;
}
/**
* Get all tracked stream IDs (for orphan cleanup)
*/
getTrackedStreamIds(): string[] {
return Array.from(this.streams.keys());
}
destroy(): void {
for (const state of this.streams.values()) {
state.emitter.removeAllListeners();
}
this.streams.clear();
logger.debug('[InMemoryEventTransport] Destroyed');
}
}

View file

@ -0,0 +1,303 @@
import { logger } from '@librechat/data-schemas';
import type { StandardGraph } from '@librechat/agents';
import type { Agents } from 'librechat-data-provider';
import type { IJobStore, SerializableJobData, JobStatus } from '~/stream/interfaces/IJobStore';
/**
* Content state for a job - volatile, in-memory only.
* Uses WeakRef to allow garbage collection of graph when no longer needed.
*/
interface ContentState {
contentParts: Agents.MessageContentComplex[];
graphRef: WeakRef<StandardGraph> | null;
}
/**
* In-memory implementation of IJobStore.
* Suitable for single-instance deployments.
* For horizontal scaling, use RedisJobStore.
*
* Content state is tied to jobs:
* - Uses WeakRef to graph for live access to contentParts and contentData (run steps)
* - No chunk persistence needed - same instance handles generation and reconnects
*/
export class InMemoryJobStore implements IJobStore {
private jobs = new Map<string, SerializableJobData>();
private contentState = new Map<string, ContentState>();
private cleanupInterval: NodeJS.Timeout | null = null;
/** Maps userId -> Set of streamIds (conversationIds) for active jobs */
private userJobMap = new Map<string, Set<string>>();
/** Time to keep completed jobs before cleanup (0 = immediate) */
private ttlAfterComplete = 0;
/** Maximum number of concurrent jobs */
private maxJobs = 1000;
constructor(options?: { ttlAfterComplete?: number; maxJobs?: number }) {
if (options?.ttlAfterComplete) {
this.ttlAfterComplete = options.ttlAfterComplete;
}
if (options?.maxJobs) {
this.maxJobs = options.maxJobs;
}
}
async initialize(): Promise<void> {
if (this.cleanupInterval) {
return;
}
this.cleanupInterval = setInterval(() => {
this.cleanup();
}, 60000);
if (this.cleanupInterval.unref) {
this.cleanupInterval.unref();
}
logger.debug('[InMemoryJobStore] Initialized with cleanup interval');
}
async createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<SerializableJobData> {
if (this.jobs.size >= this.maxJobs) {
await this.evictOldest();
}
const job: SerializableJobData = {
streamId,
userId,
status: 'running',
createdAt: Date.now(),
conversationId,
syncSent: false,
};
this.jobs.set(streamId, job);
// Track job by userId for efficient user-scoped queries
let userJobs = this.userJobMap.get(userId);
if (!userJobs) {
userJobs = new Set();
this.userJobMap.set(userId, userJobs);
}
userJobs.add(streamId);
logger.debug(`[InMemoryJobStore] Created job: ${streamId}`);
return job;
}
async getJob(streamId: string): Promise<SerializableJobData | null> {
return this.jobs.get(streamId) ?? null;
}
async updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void> {
const job = this.jobs.get(streamId);
if (!job) {
return;
}
Object.assign(job, updates);
}
async deleteJob(streamId: string): Promise<void> {
this.jobs.delete(streamId);
this.contentState.delete(streamId);
logger.debug(`[InMemoryJobStore] Deleted job: ${streamId}`);
}
async hasJob(streamId: string): Promise<boolean> {
return this.jobs.has(streamId);
}
async getRunningJobs(): Promise<SerializableJobData[]> {
const running: SerializableJobData[] = [];
for (const job of this.jobs.values()) {
if (job.status === 'running') {
running.push(job);
}
}
return running;
}
async cleanup(): Promise<number> {
const now = Date.now();
const toDelete: string[] = [];
for (const [streamId, job] of this.jobs) {
const isFinished = ['complete', 'error', 'aborted'].includes(job.status);
if (isFinished && job.completedAt) {
// TTL of 0 means immediate cleanup, otherwise wait for TTL to expire
if (this.ttlAfterComplete === 0 || now - job.completedAt > this.ttlAfterComplete) {
toDelete.push(streamId);
}
}
}
for (const id of toDelete) {
await this.deleteJob(id);
}
if (toDelete.length > 0) {
logger.debug(`[InMemoryJobStore] Cleaned up ${toDelete.length} expired jobs`);
}
return toDelete.length;
}
private async evictOldest(): Promise<void> {
let oldestId: string | null = null;
let oldestTime = Infinity;
for (const [streamId, job] of this.jobs) {
if (job.createdAt < oldestTime) {
oldestTime = job.createdAt;
oldestId = streamId;
}
}
if (oldestId) {
logger.warn(`[InMemoryJobStore] Evicting oldest job: ${oldestId}`);
await this.deleteJob(oldestId);
}
}
/** Get job count (for monitoring) */
async getJobCount(): Promise<number> {
return this.jobs.size;
}
/** Get job count by status (for monitoring) */
async getJobCountByStatus(status: JobStatus): Promise<number> {
let count = 0;
for (const job of this.jobs.values()) {
if (job.status === status) {
count++;
}
}
return count;
}
async destroy(): Promise<void> {
if (this.cleanupInterval) {
clearInterval(this.cleanupInterval);
this.cleanupInterval = null;
}
this.jobs.clear();
this.contentState.clear();
this.userJobMap.clear();
logger.debug('[InMemoryJobStore] Destroyed');
}
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Also performs self-healing cleanup: removes stale entries for jobs that no longer exist.
*/
async getActiveJobIdsByUser(userId: string): Promise<string[]> {
const trackedIds = this.userJobMap.get(userId);
if (!trackedIds || trackedIds.size === 0) {
return [];
}
const activeIds: string[] = [];
for (const streamId of trackedIds) {
const job = this.jobs.get(streamId);
// Only include if job exists AND is still running
if (job && job.status === 'running') {
activeIds.push(streamId);
} else {
// Self-healing: job completed/deleted but mapping wasn't cleaned - fix it now
trackedIds.delete(streamId);
}
}
// Clean up empty set
if (trackedIds.size === 0) {
this.userJobMap.delete(userId);
}
return activeIds;
}
// ===== Content State Methods =====
/**
* Set the graph reference for a job.
* Uses WeakRef to allow garbage collection when graph is no longer needed.
*/
setGraph(streamId: string, graph: StandardGraph): void {
const existing = this.contentState.get(streamId);
if (existing) {
existing.graphRef = new WeakRef(graph);
} else {
this.contentState.set(streamId, {
contentParts: [],
graphRef: new WeakRef(graph),
});
}
}
/**
* Set content parts reference for a job.
*/
setContentParts(streamId: string, contentParts: Agents.MessageContentComplex[]): void {
const existing = this.contentState.get(streamId);
if (existing) {
existing.contentParts = contentParts;
} else {
this.contentState.set(streamId, { contentParts, graphRef: null });
}
}
/**
* Get content parts for a job.
* Returns live content from stored reference.
*/
async getContentParts(streamId: string): Promise<{
content: Agents.MessageContentComplex[];
} | null> {
const state = this.contentState.get(streamId);
if (!state?.contentParts) {
return null;
}
return {
content: state.contentParts,
};
}
/**
* Get run steps for a job from graph.contentData.
* Uses WeakRef - may return empty if graph has been GC'd.
*/
async getRunSteps(streamId: string): Promise<Agents.RunStep[]> {
const state = this.contentState.get(streamId);
if (!state?.graphRef) {
return [];
}
// Dereference WeakRef - may return undefined if GC'd
const graph = state.graphRef.deref();
return graph?.contentData ?? [];
}
/**
* No-op for in-memory - content available via graph reference.
*/
async appendChunk(): Promise<void> {
// No-op: content available via graph reference
}
/**
* Clear content state for a job.
*/
clearContentState(streamId: string): void {
this.contentState.delete(streamId);
}
}

View file

@ -0,0 +1,318 @@
import type { Redis, Cluster } from 'ioredis';
import { logger } from '@librechat/data-schemas';
import type { IEventTransport } from '~/stream/interfaces/IJobStore';
/**
* Redis key prefixes for pub/sub channels
*/
const CHANNELS = {
/** Main event channel: stream:{streamId}:events (hash tag for cluster compatibility) */
events: (streamId: string) => `stream:{${streamId}}:events`,
};
/**
* Event types for pub/sub messages
*/
const EventTypes = {
CHUNK: 'chunk',
DONE: 'done',
ERROR: 'error',
} as const;
interface PubSubMessage {
type: (typeof EventTypes)[keyof typeof EventTypes];
data?: unknown;
error?: string;
}
/**
* Subscriber state for a stream
*/
interface StreamSubscribers {
count: number;
handlers: Map<
string,
{
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
}
>;
allSubscribersLeftCallbacks: Array<() => void>;
}
/**
* Redis Pub/Sub implementation of IEventTransport.
* Enables real-time event delivery across multiple instances.
*
* Architecture (inspired by https://upstash.com/blog/resumable-llm-streams):
* - Publisher: Emits events to Redis channel when chunks arrive
* - Subscriber: Listens to Redis channel and forwards to SSE clients
* - Decoupled: Generator and consumer don't need direct connection
*
* Note: Requires TWO Redis connections - one for publishing, one for subscribing.
* This is a Redis limitation: a client in subscribe mode can't publish.
*
* @example
* ```ts
* const transport = new RedisEventTransport(publisherClient, subscriberClient);
* transport.subscribe(streamId, { onChunk: (e) => res.write(e) });
* transport.emitChunk(streamId, { text: 'Hello' });
* ```
*/
export class RedisEventTransport implements IEventTransport {
/** Redis client for publishing events */
private publisher: Redis | Cluster;
/** Redis client for subscribing to events (separate connection required) */
private subscriber: Redis | Cluster;
/** Track subscribers per stream */
private streams = new Map<string, StreamSubscribers>();
/** Track which channels we're subscribed to */
private subscribedChannels = new Set<string>();
/** Counter for generating unique subscriber IDs */
private subscriberIdCounter = 0;
/**
* Create a new Redis event transport.
*
* @param publisher - Redis client for publishing (can be shared)
* @param subscriber - Redis client for subscribing (must be dedicated)
*/
constructor(publisher: Redis | Cluster, subscriber: Redis | Cluster) {
this.publisher = publisher;
this.subscriber = subscriber;
// Set up message handler for all subscriptions
this.subscriber.on('message', (channel: string, message: string) => {
this.handleMessage(channel, message);
});
}
/**
* Handle incoming pub/sub message
*/
private handleMessage(channel: string, message: string): void {
// Extract streamId from channel name: stream:{streamId}:events
// Use regex to extract the hash tag content
const match = channel.match(/^stream:\{([^}]+)\}:events$/);
if (!match) {
return;
}
const streamId = match[1];
const streamState = this.streams.get(streamId);
if (!streamState) {
return;
}
try {
const parsed = JSON.parse(message) as PubSubMessage;
for (const [, handlers] of streamState.handlers) {
switch (parsed.type) {
case EventTypes.CHUNK:
handlers.onChunk(parsed.data);
break;
case EventTypes.DONE:
handlers.onDone?.(parsed.data);
break;
case EventTypes.ERROR:
handlers.onError?.(parsed.error ?? 'Unknown error');
break;
}
}
} catch (err) {
logger.error(`[RedisEventTransport] Failed to parse message:`, err);
}
}
/**
* Subscribe to events for a stream.
*
* On first subscriber for a stream, subscribes to the Redis channel.
* Returns unsubscribe function that cleans up when last subscriber leaves.
*/
subscribe(
streamId: string,
handlers: {
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void } {
const channel = CHANNELS.events(streamId);
const subscriberId = `sub_${++this.subscriberIdCounter}`;
// Initialize stream state if needed
if (!this.streams.has(streamId)) {
this.streams.set(streamId, {
count: 0,
handlers: new Map(),
allSubscribersLeftCallbacks: [],
});
}
const streamState = this.streams.get(streamId)!;
streamState.count++;
streamState.handlers.set(subscriberId, handlers);
// Subscribe to Redis channel if this is first subscriber
if (!this.subscribedChannels.has(channel)) {
this.subscribedChannels.add(channel);
this.subscriber.subscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
});
}
// Return unsubscribe function
return {
unsubscribe: () => {
const state = this.streams.get(streamId);
if (!state) {
return;
}
state.handlers.delete(subscriberId);
state.count--;
// If last subscriber left, unsubscribe from Redis and notify
if (state.count === 0) {
this.subscriber.unsubscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to unsubscribe from ${channel}:`, err);
});
this.subscribedChannels.delete(channel);
// Call all-subscribers-left callbacks
for (const callback of state.allSubscribersLeftCallbacks) {
try {
callback();
} catch (err) {
logger.error(`[RedisEventTransport] Error in allSubscribersLeft callback:`, err);
}
}
this.streams.delete(streamId);
}
},
};
}
/**
* Publish a chunk event to all subscribers across all instances.
*/
emitChunk(streamId: string, event: unknown): void {
const channel = CHANNELS.events(streamId);
const message: PubSubMessage = { type: EventTypes.CHUNK, data: event };
this.publisher.publish(channel, JSON.stringify(message)).catch((err) => {
logger.error(`[RedisEventTransport] Failed to publish chunk:`, err);
});
}
/**
* Publish a done event to all subscribers.
*/
emitDone(streamId: string, event: unknown): void {
const channel = CHANNELS.events(streamId);
const message: PubSubMessage = { type: EventTypes.DONE, data: event };
this.publisher.publish(channel, JSON.stringify(message)).catch((err) => {
logger.error(`[RedisEventTransport] Failed to publish done:`, err);
});
}
/**
* Publish an error event to all subscribers.
*/
emitError(streamId: string, error: string): void {
const channel = CHANNELS.events(streamId);
const message: PubSubMessage = { type: EventTypes.ERROR, error };
this.publisher.publish(channel, JSON.stringify(message)).catch((err) => {
logger.error(`[RedisEventTransport] Failed to publish error:`, err);
});
}
/**
* Get subscriber count for a stream (local instance only).
*
* Note: In a multi-instance setup, this only returns local subscriber count.
* For global count, would need to track in Redis (e.g., with a counter key).
*/
getSubscriberCount(streamId: string): number {
return this.streams.get(streamId)?.count ?? 0;
}
/**
* Check if this is the first subscriber (local instance only).
*/
isFirstSubscriber(streamId: string): boolean {
return this.getSubscriberCount(streamId) === 1;
}
/**
* Register callback for when all subscribers leave.
*/
onAllSubscribersLeft(streamId: string, callback: () => void): void {
const state = this.streams.get(streamId);
if (state) {
state.allSubscribersLeftCallbacks.push(callback);
} else {
// Create state just for the callback
this.streams.set(streamId, {
count: 0,
handlers: new Map(),
allSubscribersLeftCallbacks: [callback],
});
}
}
/**
* Get all tracked stream IDs (for orphan cleanup)
*/
getTrackedStreamIds(): string[] {
return Array.from(this.streams.keys());
}
/**
* Cleanup resources for a specific stream.
*/
cleanup(streamId: string): void {
const channel = CHANNELS.events(streamId);
const state = this.streams.get(streamId);
if (state) {
// Clear all handlers
state.handlers.clear();
state.allSubscribersLeftCallbacks = [];
}
// Unsubscribe from Redis channel
if (this.subscribedChannels.has(channel)) {
this.subscriber.unsubscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to cleanup ${channel}:`, err);
});
this.subscribedChannels.delete(channel);
}
this.streams.delete(streamId);
}
/**
* Destroy all resources.
*/
destroy(): void {
// Unsubscribe from all channels
for (const channel of this.subscribedChannels) {
this.subscriber.unsubscribe(channel).catch(() => {
// Ignore errors during shutdown
});
}
this.subscribedChannels.clear();
this.streams.clear();
// Note: Don't close Redis connections - they may be shared
logger.info('[RedisEventTransport] Destroyed');
}
}

View file

@ -0,0 +1,835 @@
import { logger } from '@librechat/data-schemas';
import { createContentAggregator } from '@librechat/agents';
import type { IJobStore, SerializableJobData, JobStatus } from '~/stream/interfaces/IJobStore';
import type { StandardGraph } from '@librechat/agents';
import type { Agents } from 'librechat-data-provider';
import type { Redis, Cluster } from 'ioredis';
/**
* Key prefixes for Redis storage.
* All keys include the streamId for easy cleanup.
* Note: streamId === conversationId, so no separate mapping needed.
*
* IMPORTANT: Uses hash tags {streamId} for Redis Cluster compatibility.
* All keys for the same stream hash to the same slot, enabling:
* - Pipeline operations across related keys
* - Atomic multi-key operations
*/
const KEYS = {
/** Job metadata: stream:{streamId}:job */
job: (streamId: string) => `stream:{${streamId}}:job`,
/** Chunk stream (Redis Streams): stream:{streamId}:chunks */
chunks: (streamId: string) => `stream:{${streamId}}:chunks`,
/** Run steps: stream:{streamId}:runsteps */
runSteps: (streamId: string) => `stream:{${streamId}}:runsteps`,
/** Running jobs set for cleanup (global set - single slot) */
runningJobs: 'stream:running',
/** User's active jobs set: stream:user:{userId}:jobs */
userJobs: (userId: string) => `stream:user:{${userId}}:jobs`,
};
/**
* Default TTL values in seconds.
* Can be overridden via constructor options.
*/
const DEFAULT_TTL = {
/** TTL for completed jobs (5 minutes) */
completed: 300,
/** TTL for running jobs/chunks (20 minutes - failsafe for crashed jobs) */
running: 1200,
/** TTL for chunks after completion (0 = delete immediately) */
chunksAfterComplete: 0,
/** TTL for run steps after completion (0 = delete immediately) */
runStepsAfterComplete: 0,
};
/**
* Redis implementation of IJobStore.
* Enables horizontal scaling with multi-instance deployments.
*
* Storage strategy:
* - Job metadata: Redis Hash (fast field access)
* - Chunks: Redis Streams (append-only, efficient for streaming)
* - Run steps: Redis String (JSON serialized)
*
* Note: streamId === conversationId, so getJob(conversationId) works directly.
*
* @example
* ```ts
* import { ioredisClient } from '~/cache';
* const store = new RedisJobStore(ioredisClient);
* await store.initialize();
* ```
*/
/**
* Configuration options for RedisJobStore
*/
export interface RedisJobStoreOptions {
/** TTL for completed jobs in seconds (default: 300 = 5 minutes) */
completedTtl?: number;
/** TTL for running jobs/chunks in seconds (default: 1200 = 20 minutes) */
runningTtl?: number;
/** TTL for chunks after completion in seconds (default: 0 = delete immediately) */
chunksAfterCompleteTtl?: number;
/** TTL for run steps after completion in seconds (default: 0 = delete immediately) */
runStepsAfterCompleteTtl?: number;
}
export class RedisJobStore implements IJobStore {
private redis: Redis | Cluster;
private cleanupInterval: NodeJS.Timeout | null = null;
private ttl: typeof DEFAULT_TTL;
/** Whether Redis client is in cluster mode (affects pipeline usage) */
private isCluster: boolean;
/**
* Local cache for graph references on THIS instance.
* Enables fast reconnects when client returns to the same server.
* Uses WeakRef to allow garbage collection when graph is no longer needed.
*/
private localGraphCache = new Map<string, WeakRef<StandardGraph>>();
/** Cleanup interval in ms (1 minute) */
private cleanupIntervalMs = 60000;
constructor(redis: Redis | Cluster, options?: RedisJobStoreOptions) {
this.redis = redis;
this.ttl = {
completed: options?.completedTtl ?? DEFAULT_TTL.completed,
running: options?.runningTtl ?? DEFAULT_TTL.running,
chunksAfterComplete: options?.chunksAfterCompleteTtl ?? DEFAULT_TTL.chunksAfterComplete,
runStepsAfterComplete: options?.runStepsAfterCompleteTtl ?? DEFAULT_TTL.runStepsAfterComplete,
};
// Detect cluster mode using ioredis's isCluster property
this.isCluster = (redis as Cluster).isCluster === true;
}
async initialize(): Promise<void> {
if (this.cleanupInterval) {
return;
}
// Start periodic cleanup
this.cleanupInterval = setInterval(() => {
this.cleanup().catch((err) => {
logger.error('[RedisJobStore] Cleanup error:', err);
});
}, this.cleanupIntervalMs);
if (this.cleanupInterval.unref) {
this.cleanupInterval.unref();
}
logger.info('[RedisJobStore] Initialized with cleanup interval');
}
async createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<SerializableJobData> {
const job: SerializableJobData = {
streamId,
userId,
status: 'running',
createdAt: Date.now(),
conversationId,
syncSent: false,
};
const key = KEYS.job(streamId);
const userJobsKey = KEYS.userJobs(userId);
// For cluster mode, we can't pipeline keys on different slots
// The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots
if (this.isCluster) {
await this.redis.hmset(key, this.serializeJob(job));
await this.redis.expire(key, this.ttl.running);
await this.redis.sadd(KEYS.runningJobs, streamId);
await this.redis.sadd(userJobsKey, streamId);
} else {
const pipeline = this.redis.pipeline();
pipeline.hmset(key, this.serializeJob(job));
pipeline.expire(key, this.ttl.running);
pipeline.sadd(KEYS.runningJobs, streamId);
pipeline.sadd(userJobsKey, streamId);
await pipeline.exec();
}
logger.debug(`[RedisJobStore] Created job: ${streamId}`);
return job;
}
async getJob(streamId: string): Promise<SerializableJobData | null> {
const data = await this.redis.hgetall(KEYS.job(streamId));
if (!data || Object.keys(data).length === 0) {
return null;
}
return this.deserializeJob(data);
}
async updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void> {
const key = KEYS.job(streamId);
const exists = await this.redis.exists(key);
if (!exists) {
return;
}
const serialized = this.serializeJob(updates as SerializableJobData);
if (Object.keys(serialized).length === 0) {
return;
}
await this.redis.hmset(key, serialized);
// If status changed to complete/error/aborted, update TTL and remove from running set
// Note: userJobs cleanup is handled lazily via self-healing in getActiveJobIdsByUser
if (updates.status && ['complete', 'error', 'aborted'].includes(updates.status)) {
// In cluster mode, separate runningJobs (global) from stream-specific keys
if (this.isCluster) {
await this.redis.expire(key, this.ttl.completed);
await this.redis.srem(KEYS.runningJobs, streamId);
if (this.ttl.chunksAfterComplete === 0) {
await this.redis.del(KEYS.chunks(streamId));
} else {
await this.redis.expire(KEYS.chunks(streamId), this.ttl.chunksAfterComplete);
}
if (this.ttl.runStepsAfterComplete === 0) {
await this.redis.del(KEYS.runSteps(streamId));
} else {
await this.redis.expire(KEYS.runSteps(streamId), this.ttl.runStepsAfterComplete);
}
} else {
const pipeline = this.redis.pipeline();
pipeline.expire(key, this.ttl.completed);
pipeline.srem(KEYS.runningJobs, streamId);
if (this.ttl.chunksAfterComplete === 0) {
pipeline.del(KEYS.chunks(streamId));
} else {
pipeline.expire(KEYS.chunks(streamId), this.ttl.chunksAfterComplete);
}
if (this.ttl.runStepsAfterComplete === 0) {
pipeline.del(KEYS.runSteps(streamId));
} else {
pipeline.expire(KEYS.runSteps(streamId), this.ttl.runStepsAfterComplete);
}
await pipeline.exec();
}
}
}
async deleteJob(streamId: string): Promise<void> {
// Clear local caches
this.localGraphCache.delete(streamId);
// Note: userJobs cleanup is handled lazily via self-healing in getActiveJobIdsByUser
// In cluster mode, separate runningJobs (global) from stream-specific keys (same slot)
if (this.isCluster) {
// Stream-specific keys all hash to same slot due to {streamId}
const pipeline = this.redis.pipeline();
pipeline.del(KEYS.job(streamId));
pipeline.del(KEYS.chunks(streamId));
pipeline.del(KEYS.runSteps(streamId));
await pipeline.exec();
// Global set is on different slot - execute separately
await this.redis.srem(KEYS.runningJobs, streamId);
} else {
const pipeline = this.redis.pipeline();
pipeline.del(KEYS.job(streamId));
pipeline.del(KEYS.chunks(streamId));
pipeline.del(KEYS.runSteps(streamId));
pipeline.srem(KEYS.runningJobs, streamId);
await pipeline.exec();
}
logger.debug(`[RedisJobStore] Deleted job: ${streamId}`);
}
async hasJob(streamId: string): Promise<boolean> {
const exists = await this.redis.exists(KEYS.job(streamId));
return exists === 1;
}
async getRunningJobs(): Promise<SerializableJobData[]> {
const streamIds = await this.redis.smembers(KEYS.runningJobs);
if (streamIds.length === 0) {
return [];
}
const jobs: SerializableJobData[] = [];
for (const streamId of streamIds) {
const job = await this.getJob(streamId);
if (job && job.status === 'running') {
jobs.push(job);
}
}
return jobs;
}
async cleanup(): Promise<number> {
const now = Date.now();
const streamIds = await this.redis.smembers(KEYS.runningJobs);
let cleaned = 0;
// Clean up stale local graph cache entries (WeakRefs that were collected)
for (const [streamId, graphRef] of this.localGraphCache) {
if (!graphRef.deref()) {
this.localGraphCache.delete(streamId);
}
}
for (const streamId of streamIds) {
const job = await this.getJob(streamId);
// Job no longer exists (TTL expired) - remove from set
if (!job) {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
cleaned++;
continue;
}
// Job completed but still in running set (shouldn't happen, but handle it)
if (job.status !== 'running') {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
cleaned++;
continue;
}
// Stale running job (failsafe - running for > configured TTL)
if (now - job.createdAt > this.ttl.running * 1000) {
logger.warn(`[RedisJobStore] Cleaning up stale job: ${streamId}`);
await this.deleteJob(streamId);
cleaned++;
}
}
if (cleaned > 0) {
logger.debug(`[RedisJobStore] Cleaned up ${cleaned} jobs`);
}
return cleaned;
}
async getJobCount(): Promise<number> {
// This is approximate - counts jobs in running set + scans for job keys
// For exact count, would need to scan all job:* keys
const runningCount = await this.redis.scard(KEYS.runningJobs);
return runningCount;
}
async getJobCountByStatus(status: JobStatus): Promise<number> {
if (status === 'running') {
return this.redis.scard(KEYS.runningJobs);
}
// For other statuses, we'd need to scan - return 0 for now
// In production, consider maintaining separate sets per status if needed
return 0;
}
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Also performs self-healing cleanup: removes stale entries for jobs that no longer exist.
*
* @param userId - The user ID to query
* @returns Array of conversation IDs with active jobs
*/
async getActiveJobIdsByUser(userId: string): Promise<string[]> {
const userJobsKey = KEYS.userJobs(userId);
const trackedIds = await this.redis.smembers(userJobsKey);
if (trackedIds.length === 0) {
return [];
}
const activeIds: string[] = [];
const staleIds: string[] = [];
for (const streamId of trackedIds) {
const job = await this.getJob(streamId);
// Only include if job exists AND is still running
if (job && job.status === 'running') {
activeIds.push(streamId);
} else {
// Self-healing: job completed/deleted but mapping wasn't cleaned - mark for removal
staleIds.push(streamId);
}
}
// Clean up stale entries
if (staleIds.length > 0) {
await this.redis.srem(userJobsKey, ...staleIds);
logger.debug(
`[RedisJobStore] Self-healed ${staleIds.length} stale job entries for user ${userId}`,
);
}
return activeIds;
}
async destroy(): Promise<void> {
if (this.cleanupInterval) {
clearInterval(this.cleanupInterval);
this.cleanupInterval = null;
}
// Clear local caches
this.localGraphCache.clear();
// Don't close the Redis connection - it's shared
logger.info('[RedisJobStore] Destroyed');
}
// ===== Content State Methods =====
// For Redis, content is primarily reconstructed from chunks.
// However, we keep a LOCAL graph cache for fast same-instance reconnects.
/**
* Store graph reference in local cache.
* This enables fast reconnects when client returns to the same instance.
* Falls back to Redis chunk reconstruction for cross-instance reconnects.
*
* @param streamId - The stream identifier
* @param graph - The graph instance (stored as WeakRef)
*/
setGraph(streamId: string, graph: StandardGraph): void {
this.localGraphCache.set(streamId, new WeakRef(graph));
}
/**
* No-op for Redis - content parts are reconstructed from chunks.
* Metadata (agentId, groupId) is embedded directly on content parts by the agent runtime.
*/
setContentParts(_streamId: string, _contentParts: Agents.MessageContentComplex[]): void {
// Content parts are reconstructed from chunks during getContentParts
// No separate storage needed
}
/**
* Get aggregated content - tries local cache first, falls back to Redis reconstruction.
*
* Optimization: If this instance has the live graph (same-instance reconnect),
* we return the content directly without Redis round-trip.
* For cross-instance reconnects, we reconstruct from Redis Streams.
*
* @param streamId - The stream identifier
* @returns Content parts array or null if not found
*/
async getContentParts(streamId: string): Promise<{
content: Agents.MessageContentComplex[];
} | null> {
// 1. Try local graph cache first (fast path for same-instance reconnect)
const graphRef = this.localGraphCache.get(streamId);
if (graphRef) {
const graph = graphRef.deref();
if (graph) {
const localParts = graph.getContentParts();
if (localParts && localParts.length > 0) {
return {
content: localParts,
};
}
} else {
// WeakRef was collected, remove from cache
this.localGraphCache.delete(streamId);
}
}
// 2. Fall back to Redis chunk reconstruction (cross-instance reconnect)
const chunks = await this.getChunks(streamId);
if (chunks.length === 0) {
return null;
}
// Use the same content aggregator as live streaming
const { contentParts, aggregateContent } = createContentAggregator();
// Valid event types for content aggregation
const validEvents = new Set([
'on_run_step',
'on_message_delta',
'on_reasoning_delta',
'on_run_step_delta',
'on_run_step_completed',
'on_agent_update',
]);
for (const chunk of chunks) {
const event = chunk as { event?: string; data?: unknown };
if (!event.event || !event.data || !validEvents.has(event.event)) {
continue;
}
// Pass event string directly - GraphEvents values are lowercase strings
// eslint-disable-next-line @typescript-eslint/no-explicit-any
aggregateContent({ event: event.event as any, data: event.data as any });
}
// Filter out undefined entries
const filtered: Agents.MessageContentComplex[] = [];
for (const part of contentParts) {
if (part !== undefined) {
filtered.push(part);
}
}
return {
content: filtered,
};
}
/**
* Get run steps - tries local cache first, falls back to Redis.
*
* Optimization: If this instance has the live graph, we get run steps
* directly without Redis round-trip.
*
* @param streamId - The stream identifier
* @returns Run steps array
*/
async getRunSteps(streamId: string): Promise<Agents.RunStep[]> {
// 1. Try local graph cache first (fast path for same-instance reconnect)
const graphRef = this.localGraphCache.get(streamId);
if (graphRef) {
const graph = graphRef.deref();
if (graph) {
const localSteps = graph.getRunSteps();
if (localSteps && localSteps.length > 0) {
return localSteps;
}
}
// Note: Don't delete from cache here - graph may still be valid
// but just not have run steps yet
}
// 2. Fall back to Redis (cross-instance reconnect)
const key = KEYS.runSteps(streamId);
const data = await this.redis.get(key);
if (!data) {
return [];
}
try {
return JSON.parse(data);
} catch {
return [];
}
}
/**
* Clear content state for a job.
* Removes both local cache and Redis data.
*/
clearContentState(streamId: string): void {
// Clear local caches immediately
this.localGraphCache.delete(streamId);
// Fire and forget - async cleanup for Redis
this.clearContentStateAsync(streamId).catch((err) => {
logger.error(`[RedisJobStore] Failed to clear content state for ${streamId}:`, err);
});
}
/**
* Clear content state async.
*/
private async clearContentStateAsync(streamId: string): Promise<void> {
const pipeline = this.redis.pipeline();
pipeline.del(KEYS.chunks(streamId));
pipeline.del(KEYS.runSteps(streamId));
await pipeline.exec();
}
/**
* Append a streaming chunk to Redis Stream.
* Uses XADD for efficient append-only storage.
* Sets TTL on first chunk to ensure cleanup if job crashes.
*/
async appendChunk(streamId: string, event: unknown): Promise<void> {
const key = KEYS.chunks(streamId);
const added = await this.redis.xadd(key, '*', 'event', JSON.stringify(event));
// Set TTL on first chunk (when stream is created)
// Subsequent chunks inherit the stream's TTL
if (added) {
const len = await this.redis.xlen(key);
if (len === 1) {
await this.redis.expire(key, this.ttl.running);
}
}
}
/**
* Get all chunks from Redis Stream.
*/
private async getChunks(streamId: string): Promise<unknown[]> {
const key = KEYS.chunks(streamId);
const entries = await this.redis.xrange(key, '-', '+');
return entries
.map(([, fields]) => {
const eventIdx = fields.indexOf('event');
if (eventIdx >= 0 && eventIdx + 1 < fields.length) {
try {
return JSON.parse(fields[eventIdx + 1]);
} catch {
return null;
}
}
return null;
})
.filter(Boolean);
}
/**
* Save run steps for resume state.
*/
async saveRunSteps(streamId: string, runSteps: Agents.RunStep[]): Promise<void> {
const key = KEYS.runSteps(streamId);
await this.redis.set(key, JSON.stringify(runSteps), 'EX', this.ttl.running);
}
// ===== Consumer Group Methods =====
// These enable tracking which chunks each client has seen.
// Based on https://upstash.com/blog/resumable-llm-streams
/**
* Create a consumer group for a stream.
* Used to track which chunks a client has already received.
*
* @param streamId - The stream identifier
* @param groupName - Unique name for the consumer group (e.g., session ID)
* @param startFrom - Where to start reading ('0' = from beginning, '$' = only new)
*/
async createConsumerGroup(
streamId: string,
groupName: string,
startFrom: '0' | '$' = '0',
): Promise<void> {
const key = KEYS.chunks(streamId);
try {
await this.redis.xgroup('CREATE', key, groupName, startFrom, 'MKSTREAM');
logger.debug(`[RedisJobStore] Created consumer group ${groupName} for ${streamId}`);
} catch (err) {
// BUSYGROUP error means group already exists - that's fine
const error = err as Error;
if (!error.message?.includes('BUSYGROUP')) {
throw err;
}
}
}
/**
* Read chunks from a consumer group (only unseen chunks).
* This is the key to the resumable stream pattern.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name
* @param consumerName - Name of the consumer within the group
* @param count - Maximum number of chunks to read (default: all available)
* @returns Array of { id, event } where id is the Redis stream entry ID
*/
async readChunksFromGroup(
streamId: string,
groupName: string,
consumerName: string = 'consumer-1',
count?: number,
): Promise<Array<{ id: string; event: unknown }>> {
const key = KEYS.chunks(streamId);
try {
// XREADGROUP GROUP groupName consumerName [COUNT count] STREAMS key >
// The '>' means only read new messages not yet delivered to this consumer
let result;
if (count) {
result = await this.redis.xreadgroup(
'GROUP',
groupName,
consumerName,
'COUNT',
count,
'STREAMS',
key,
'>',
);
} else {
result = await this.redis.xreadgroup('GROUP', groupName, consumerName, 'STREAMS', key, '>');
}
if (!result || result.length === 0) {
return [];
}
// Result format: [[streamKey, [[id, [field, value, ...]], ...]]]
const [, messages] = result[0] as [string, Array<[string, string[]]>];
const chunks: Array<{ id: string; event: unknown }> = [];
for (const [id, fields] of messages) {
const eventIdx = fields.indexOf('event');
if (eventIdx >= 0 && eventIdx + 1 < fields.length) {
try {
chunks.push({
id,
event: JSON.parse(fields[eventIdx + 1]),
});
} catch {
// Skip malformed entries
}
}
}
return chunks;
} catch (err) {
const error = err as Error;
// NOGROUP error means the group doesn't exist yet
if (error.message?.includes('NOGROUP')) {
return [];
}
throw err;
}
}
/**
* Acknowledge that chunks have been processed.
* This tells Redis we've successfully delivered these chunks to the client.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name
* @param messageIds - Array of Redis stream entry IDs to acknowledge
*/
async acknowledgeChunks(
streamId: string,
groupName: string,
messageIds: string[],
): Promise<void> {
if (messageIds.length === 0) {
return;
}
const key = KEYS.chunks(streamId);
await this.redis.xack(key, groupName, ...messageIds);
}
/**
* Delete a consumer group.
* Called when a client disconnects and won't reconnect.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name to delete
*/
async deleteConsumerGroup(streamId: string, groupName: string): Promise<void> {
const key = KEYS.chunks(streamId);
try {
await this.redis.xgroup('DESTROY', key, groupName);
logger.debug(`[RedisJobStore] Deleted consumer group ${groupName} for ${streamId}`);
} catch {
// Ignore errors - group may not exist
}
}
/**
* Get pending chunks for a consumer (chunks delivered but not acknowledged).
* Useful for recovering from crashes.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name
* @param consumerName - Consumer name
*/
async getPendingChunks(
streamId: string,
groupName: string,
consumerName: string = 'consumer-1',
): Promise<Array<{ id: string; event: unknown }>> {
const key = KEYS.chunks(streamId);
try {
// Read pending messages (delivered but not acked) by using '0' instead of '>'
const result = await this.redis.xreadgroup(
'GROUP',
groupName,
consumerName,
'STREAMS',
key,
'0',
);
if (!result || result.length === 0) {
return [];
}
const [, messages] = result[0] as [string, Array<[string, string[]]>];
const chunks: Array<{ id: string; event: unknown }> = [];
for (const [id, fields] of messages) {
const eventIdx = fields.indexOf('event');
if (eventIdx >= 0 && eventIdx + 1 < fields.length) {
try {
chunks.push({
id,
event: JSON.parse(fields[eventIdx + 1]),
});
} catch {
// Skip malformed entries
}
}
}
return chunks;
} catch {
return [];
}
}
/**
* Serialize job data for Redis hash storage.
* Converts complex types to strings.
*/
private serializeJob(job: Partial<SerializableJobData>): Record<string, string> {
const result: Record<string, string> = {};
for (const [key, value] of Object.entries(job)) {
if (value === undefined) {
continue;
}
if (typeof value === 'object') {
result[key] = JSON.stringify(value);
} else if (typeof value === 'boolean') {
result[key] = value ? '1' : '0';
} else {
result[key] = String(value);
}
}
return result;
}
/**
* Deserialize job data from Redis hash.
*/
private deserializeJob(data: Record<string, string>): SerializableJobData {
return {
streamId: data.streamId,
userId: data.userId,
status: data.status as JobStatus,
createdAt: parseInt(data.createdAt, 10),
completedAt: data.completedAt ? parseInt(data.completedAt, 10) : undefined,
conversationId: data.conversationId || undefined,
error: data.error || undefined,
userMessage: data.userMessage ? JSON.parse(data.userMessage) : undefined,
responseMessageId: data.responseMessageId || undefined,
sender: data.sender || undefined,
syncSent: data.syncSent === '1',
finalEvent: data.finalEvent || undefined,
endpoint: data.endpoint || undefined,
iconURL: data.iconURL || undefined,
model: data.model || undefined,
promptTokens: data.promptTokens ? parseInt(data.promptTokens, 10) : undefined,
};
}
}

View file

@ -0,0 +1,4 @@
export * from './InMemoryJobStore';
export * from './InMemoryEventTransport';
export * from './RedisJobStore';
export * from './RedisEventTransport';

View file

@ -0,0 +1,22 @@
export {
GenerationJobManager,
GenerationJobManagerClass,
type GenerationJobManagerOptions,
} from './GenerationJobManager';
export type {
AbortResult,
SerializableJobData,
JobStatus,
IJobStore,
IEventTransport,
} from './interfaces/IJobStore';
export { createStreamServices } from './createStreamServices';
export type { StreamServicesConfig, StreamServices } from './createStreamServices';
// Implementations (for advanced use cases)
export { InMemoryJobStore } from './implementations/InMemoryJobStore';
export { InMemoryEventTransport } from './implementations/InMemoryEventTransport';
export { RedisJobStore } from './implementations/RedisJobStore';
export { RedisEventTransport } from './implementations/RedisEventTransport';

View file

@ -0,0 +1,256 @@
import type { Agents } from 'librechat-data-provider';
import type { StandardGraph } from '@librechat/agents';
/**
* Job status enum
*/
export type JobStatus = 'running' | 'complete' | 'error' | 'aborted';
/**
* Serializable job data - no object references, suitable for Redis/external storage
*/
export interface SerializableJobData {
streamId: string;
userId: string;
status: JobStatus;
createdAt: number;
completedAt?: number;
conversationId?: string;
error?: string;
/** User message metadata */
userMessage?: {
messageId: string;
parentMessageId?: string;
conversationId?: string;
text?: string;
};
/** Response message ID for reconnection */
responseMessageId?: string;
/** Sender name for UI display */
sender?: string;
/** Whether sync has been sent to a client */
syncSent: boolean;
/** Serialized final event for replay */
finalEvent?: string;
/** Endpoint metadata for abort handling - avoids storing functions */
endpoint?: string;
iconURL?: string;
model?: string;
promptTokens?: number;
}
/**
* Result returned from aborting a job - contains all data needed
* for token spending and message saving without storing callbacks
*/
export interface AbortResult {
/** Whether the abort was successful */
success: boolean;
/** The job data at time of abort */
jobData: SerializableJobData | null;
/** Aggregated content from the stream */
content: Agents.MessageContentComplex[];
/** Final event to send to client */
finalEvent: unknown;
}
/**
* Resume state for reconnecting clients
*/
export interface ResumeState {
runSteps: Agents.RunStep[];
aggregatedContent: Agents.MessageContentComplex[];
userMessage?: SerializableJobData['userMessage'];
responseMessageId?: string;
conversationId?: string;
sender?: string;
}
/**
* Interface for job storage backend.
* Implementations can use in-memory Map, Redis, KV store, etc.
*
* Content state is tied to jobs:
* - In-memory: Holds WeakRef to graph for live content/run steps access
* - Redis: Persists chunks, reconstructs content on reconnect
*
* This consolidates job metadata + content state into a single interface.
*/
export interface IJobStore {
/** Initialize the store (e.g., connect to Redis, start cleanup intervals) */
initialize(): Promise<void>;
/** Create a new job */
createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<SerializableJobData>;
/** Get a job by streamId (streamId === conversationId) */
getJob(streamId: string): Promise<SerializableJobData | null>;
/** Update job data */
updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void>;
/** Delete a job */
deleteJob(streamId: string): Promise<void>;
/** Check if job exists */
hasJob(streamId: string): Promise<boolean>;
/** Get all running jobs (for cleanup) */
getRunningJobs(): Promise<SerializableJobData[]>;
/** Cleanup expired jobs */
cleanup(): Promise<number>;
/** Get total job count */
getJobCount(): Promise<number>;
/** Get job count by status */
getJobCountByStatus(status: JobStatus): Promise<number>;
/** Destroy the store and release resources */
destroy(): Promise<void>;
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Also performs self-healing cleanup of stale entries.
*
* @param userId - The user ID to query
* @returns Array of conversation IDs with active jobs
*/
getActiveJobIdsByUser(userId: string): Promise<string[]>;
// ===== Content State Methods =====
// These methods manage volatile content state tied to each job.
// In-memory: Uses WeakRef to graph for live access
// Redis: Persists chunks and reconstructs on demand
/**
* Set the graph reference for a job (in-memory only).
* The graph provides live access to contentParts and contentData (run steps).
*
* In-memory: Stores WeakRef to graph
* Redis: No-op (graph not transferable, uses chunks instead)
*
* @param streamId - The stream identifier
* @param graph - The StandardGraph instance
*/
setGraph(streamId: string, graph: StandardGraph): void;
/**
* Set content parts reference for a job.
*
* In-memory: Stores direct reference to content array
* Redis: No-op (content built from chunks)
*
* @param streamId - The stream identifier
* @param contentParts - The content parts array
*/
setContentParts(streamId: string, contentParts: Agents.MessageContentComplex[]): void;
/**
* Get aggregated content for a job.
*
* In-memory: Returns live content from graph.contentParts or stored reference
* Redis: Reconstructs from stored chunks
*
* @param streamId - The stream identifier
* @returns Content parts or null if not available
*/
getContentParts(streamId: string): Promise<{
content: Agents.MessageContentComplex[];
} | null>;
/**
* Get run steps for a job (for resume state).
*
* In-memory: Returns live run steps from graph.contentData
* Redis: Fetches from persistent storage
*
* @param streamId - The stream identifier
* @returns Run steps or empty array
*/
getRunSteps(streamId: string): Promise<Agents.RunStep[]>;
/**
* Append a streaming chunk for later reconstruction.
*
* In-memory: No-op (content available via graph reference)
* Redis: Uses XADD for append-only log efficiency
*
* @param streamId - The stream identifier
* @param event - The SSE event to append
*/
appendChunk(streamId: string, event: unknown): Promise<void>;
/**
* Clear all content state for a job.
* Called on job completion/cleanup.
*
* @param streamId - The stream identifier
*/
clearContentState(streamId: string): void;
/**
* Save run steps to persistent storage.
* In-memory: No-op (run steps accessed via graph reference)
* Redis: Persists for resume across instances
*
* @param streamId - The stream identifier
* @param runSteps - Run steps to save
*/
saveRunSteps?(streamId: string, runSteps: Agents.RunStep[]): Promise<void>;
}
/**
* Interface for pub/sub event transport.
* Implementations can use EventEmitter, Redis Pub/Sub, etc.
*/
export interface IEventTransport {
/** Subscribe to events for a stream */
subscribe(
streamId: string,
handlers: {
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void };
/** Publish a chunk event */
emitChunk(streamId: string, event: unknown): void;
/** Publish a done event */
emitDone(streamId: string, event: unknown): void;
/** Publish an error event */
emitError(streamId: string, error: string): void;
/** Get subscriber count for a stream */
getSubscriberCount(streamId: string): number;
/** Check if this is the first subscriber (for ready signaling) */
isFirstSubscriber(streamId: string): boolean;
/** Listen for all subscribers leaving */
onAllSubscribersLeft(streamId: string, callback: () => void): void;
/** Cleanup transport resources for a specific stream */
cleanup(streamId: string): void;
/** Get all tracked stream IDs (for orphan cleanup) */
getTrackedStreamIds(): string[];
/** Destroy all transport resources */
destroy(): void;
}

View file

@ -0,0 +1 @@
export * from './IJobStore';

View file

@ -29,12 +29,25 @@ export interface AudioProcessingResult {
bytes: number;
}
/** Google video block format */
export interface GoogleVideoBlock {
type: 'media';
mimeType: string;
data: string;
}
/** OpenRouter video block format */
export interface OpenRouterVideoBlock {
type: 'video_url';
video_url: {
url: string;
};
}
export type VideoBlock = GoogleVideoBlock | OpenRouterVideoBlock;
export interface VideoResult {
videos: Array<{
type: string;
mimeType: string;
data: string;
}>;
videos: VideoBlock[];
files: Array<{
file_id?: string;
temp_file_id?: string;
@ -100,12 +113,26 @@ export interface DocumentResult {
}>;
}
export interface AudioResult {
audios: Array<{
type: string;
mimeType: string;
/** Google audio block format */
export interface GoogleAudioBlock {
type: 'media';
mimeType: string;
data: string;
}
/** OpenRouter audio block format */
export interface OpenRouterAudioBlock {
type: 'input_audio';
input_audio: {
data: string;
}>;
format: string;
};
}
export type AudioBlock = GoogleAudioBlock | OpenRouterAudioBlock;
export interface AudioResult {
audios: AudioBlock[];
files: Array<{
file_id?: string;
temp_file_id?: string;

View file

@ -13,3 +13,4 @@ export type * from './openai';
export * from './prompts';
export * from './run';
export * from './tokens';
export * from './stream';

View file

@ -0,0 +1,49 @@
import type { EventEmitter } from 'events';
import type { Agents } from 'librechat-data-provider';
import type { ServerSentEvent } from '~/types';
export interface GenerationJobMetadata {
userId: string;
conversationId?: string;
/** User message data for rebuilding submission on reconnect */
userMessage?: Agents.UserMessageMeta;
/** Response message ID for tracking */
responseMessageId?: string;
/** Sender label for the response (e.g., "GPT-4.1", "Claude") */
sender?: string;
/** Endpoint identifier for abort handling */
endpoint?: string;
/** Icon URL for UI display */
iconURL?: string;
/** Model name for token tracking */
model?: string;
/** Prompt token count for abort token spending */
promptTokens?: number;
}
export type GenerationJobStatus = 'running' | 'complete' | 'error' | 'aborted';
export interface GenerationJob {
streamId: string;
emitter: EventEmitter;
status: GenerationJobStatus;
createdAt: number;
completedAt?: number;
abortController: AbortController;
error?: string;
metadata: GenerationJobMetadata;
readyPromise: Promise<void>;
resolveReady: () => void;
/** Final event when job completes */
finalEvent?: ServerSentEvent;
/** Flag to indicate if a sync event was already sent (prevent duplicate replays) */
syncSent?: boolean;
}
export type ContentPart = Agents.ContentPart;
export type ResumeState = Agents.ResumeState;
export type ChunkHandler = (event: ServerSentEvent) => void;
export type DoneHandler = (event: ServerSentEvent) => void;
export type ErrorHandler = (error: string) => void;
export type UnsubscribeFn = () => void;

View file

@ -0,0 +1,196 @@
/**
* Integration tests for math function with actual config patterns.
* These tests verify that real environment variable expressions from .env.example
* are correctly evaluated by the math function.
*/
import { math } from './math';
describe('math - integration with real config patterns', () => {
describe('SESSION_EXPIRY patterns', () => {
test('should evaluate default SESSION_EXPIRY (15 minutes)', () => {
const result = math('1000 * 60 * 15');
expect(result).toBe(900000); // 15 minutes in ms
});
test('should evaluate 30 minute session', () => {
const result = math('1000 * 60 * 30');
expect(result).toBe(1800000); // 30 minutes in ms
});
test('should evaluate 1 hour session', () => {
const result = math('1000 * 60 * 60');
expect(result).toBe(3600000); // 1 hour in ms
});
});
describe('REFRESH_TOKEN_EXPIRY patterns', () => {
test('should evaluate default REFRESH_TOKEN_EXPIRY (7 days)', () => {
const result = math('(1000 * 60 * 60 * 24) * 7');
expect(result).toBe(604800000); // 7 days in ms
});
test('should evaluate 1 day refresh token', () => {
const result = math('1000 * 60 * 60 * 24');
expect(result).toBe(86400000); // 1 day in ms
});
test('should evaluate 30 day refresh token', () => {
const result = math('(1000 * 60 * 60 * 24) * 30');
expect(result).toBe(2592000000); // 30 days in ms
});
});
describe('BAN_DURATION patterns', () => {
test('should evaluate default BAN_DURATION (2 hours)', () => {
const result = math('1000 * 60 * 60 * 2');
expect(result).toBe(7200000); // 2 hours in ms
});
test('should evaluate 24 hour ban', () => {
const result = math('1000 * 60 * 60 * 24');
expect(result).toBe(86400000); // 24 hours in ms
});
});
describe('Redis config patterns', () => {
test('should evaluate REDIS_RETRY_MAX_DELAY', () => {
expect(math('3000')).toBe(3000);
});
test('should evaluate REDIS_RETRY_MAX_ATTEMPTS', () => {
expect(math('10')).toBe(10);
});
test('should evaluate REDIS_CONNECT_TIMEOUT', () => {
expect(math('10000')).toBe(10000);
});
test('should evaluate REDIS_MAX_LISTENERS', () => {
expect(math('40')).toBe(40);
});
test('should evaluate REDIS_DELETE_CHUNK_SIZE', () => {
expect(math('1000')).toBe(1000);
});
});
describe('MCP config patterns', () => {
test('should evaluate MCP_OAUTH_DETECTION_TIMEOUT', () => {
expect(math('5000')).toBe(5000);
});
test('should evaluate MCP_CONNECTION_CHECK_TTL', () => {
expect(math('60000')).toBe(60000); // 1 minute
});
test('should evaluate MCP_USER_CONNECTION_IDLE_TIMEOUT (15 minutes)', () => {
const result = math('15 * 60 * 1000');
expect(result).toBe(900000); // 15 minutes in ms
});
test('should evaluate MCP_REGISTRY_CACHE_TTL', () => {
expect(math('5000')).toBe(5000); // 5 seconds
});
});
describe('Leader election config patterns', () => {
test('should evaluate LEADER_LEASE_DURATION (25 seconds)', () => {
expect(math('25')).toBe(25);
});
test('should evaluate LEADER_RENEW_INTERVAL (10 seconds)', () => {
expect(math('10')).toBe(10);
});
test('should evaluate LEADER_RENEW_ATTEMPTS', () => {
expect(math('3')).toBe(3);
});
test('should evaluate LEADER_RENEW_RETRY_DELAY (0.5 seconds)', () => {
expect(math('0.5')).toBe(0.5);
});
});
describe('OpenID config patterns', () => {
test('should evaluate OPENID_JWKS_URL_CACHE_TIME (10 minutes)', () => {
const result = math('600000');
expect(result).toBe(600000); // 10 minutes in ms
});
test('should evaluate custom cache time expression', () => {
const result = math('1000 * 60 * 10');
expect(result).toBe(600000); // 10 minutes in ms
});
});
describe('simulated process.env usage', () => {
const originalEnv = process.env;
beforeEach(() => {
process.env = { ...originalEnv };
});
afterEach(() => {
process.env = originalEnv;
});
test('should work with SESSION_EXPIRY from env', () => {
process.env.SESSION_EXPIRY = '1000 * 60 * 15';
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
test('should work with REFRESH_TOKEN_EXPIRY from env', () => {
process.env.REFRESH_TOKEN_EXPIRY = '(1000 * 60 * 60 * 24) * 7';
const result = math(process.env.REFRESH_TOKEN_EXPIRY, 604800000);
expect(result).toBe(604800000);
});
test('should work with BAN_DURATION from env', () => {
process.env.BAN_DURATION = '1000 * 60 * 60 * 2';
const result = math(process.env.BAN_DURATION, 7200000);
expect(result).toBe(7200000);
});
test('should use fallback when env var is undefined', () => {
delete process.env.SESSION_EXPIRY;
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
test('should use fallback when env var is empty string', () => {
process.env.SESSION_EXPIRY = '';
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
test('should use fallback when env var has invalid expression', () => {
process.env.SESSION_EXPIRY = 'invalid';
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
});
describe('time calculation helpers', () => {
// Helper functions to make time calculations more readable
const seconds = (n: number) => n * 1000;
const minutes = (n: number) => seconds(n * 60);
const hours = (n: number) => minutes(n * 60);
const days = (n: number) => hours(n * 24);
test('should match helper calculations', () => {
// Verify our math function produces same results as programmatic calculations
expect(math('1000 * 60 * 15')).toBe(minutes(15));
expect(math('1000 * 60 * 60 * 2')).toBe(hours(2));
expect(math('(1000 * 60 * 60 * 24) * 7')).toBe(days(7));
});
test('should handle complex expressions', () => {
// 2 hours + 30 minutes
expect(math('(1000 * 60 * 60 * 2) + (1000 * 60 * 30)')).toBe(hours(2) + minutes(30));
// Half a day
expect(math('(1000 * 60 * 60 * 24) / 2')).toBe(days(1) / 2);
});
});
});

View file

@ -0,0 +1,326 @@
import { math } from './math';
describe('math', () => {
describe('number input passthrough', () => {
test('should return number as-is when input is a number', () => {
expect(math(42)).toBe(42);
});
test('should return zero when input is 0', () => {
expect(math(0)).toBe(0);
});
test('should return negative numbers as-is', () => {
expect(math(-10)).toBe(-10);
});
test('should return decimal numbers as-is', () => {
expect(math(0.5)).toBe(0.5);
});
test('should return very large numbers as-is', () => {
expect(math(Number.MAX_SAFE_INTEGER)).toBe(Number.MAX_SAFE_INTEGER);
});
});
describe('simple string number parsing', () => {
test('should parse simple integer string', () => {
expect(math('42')).toBe(42);
});
test('should parse zero string', () => {
expect(math('0')).toBe(0);
});
test('should parse negative number string', () => {
expect(math('-10')).toBe(-10);
});
test('should parse decimal string', () => {
expect(math('0.5')).toBe(0.5);
});
test('should parse string with leading/trailing spaces', () => {
expect(math(' 42 ')).toBe(42);
});
test('should parse large number string', () => {
expect(math('9007199254740991')).toBe(Number.MAX_SAFE_INTEGER);
});
});
describe('mathematical expressions - multiplication', () => {
test('should evaluate simple multiplication', () => {
expect(math('2 * 3')).toBe(6);
});
test('should evaluate chained multiplication (BAN_DURATION pattern: 1000 * 60 * 60 * 2)', () => {
// 2 hours in milliseconds
expect(math('1000 * 60 * 60 * 2')).toBe(7200000);
});
test('should evaluate SESSION_EXPIRY pattern (1000 * 60 * 15)', () => {
// 15 minutes in milliseconds
expect(math('1000 * 60 * 15')).toBe(900000);
});
test('should evaluate multiplication without spaces', () => {
expect(math('2*3')).toBe(6);
});
});
describe('mathematical expressions - addition and subtraction', () => {
test('should evaluate simple addition', () => {
expect(math('2 + 3')).toBe(5);
});
test('should evaluate simple subtraction', () => {
expect(math('10 - 3')).toBe(7);
});
test('should evaluate mixed addition and subtraction', () => {
expect(math('10 + 5 - 3')).toBe(12);
});
test('should handle negative results', () => {
expect(math('3 - 10')).toBe(-7);
});
});
describe('mathematical expressions - division', () => {
test('should evaluate simple division', () => {
expect(math('10 / 2')).toBe(5);
});
test('should evaluate division resulting in decimal', () => {
expect(math('7 / 2')).toBe(3.5);
});
});
describe('mathematical expressions - parentheses', () => {
test('should evaluate expression with parentheses (REFRESH_TOKEN_EXPIRY pattern)', () => {
// 7 days in milliseconds: (1000 * 60 * 60 * 24) * 7
expect(math('(1000 * 60 * 60 * 24) * 7')).toBe(604800000);
});
test('should evaluate nested parentheses', () => {
expect(math('((2 + 3) * 4)')).toBe(20);
});
test('should respect operator precedence with parentheses', () => {
expect(math('2 * (3 + 4)')).toBe(14);
});
});
describe('mathematical expressions - modulo', () => {
test('should evaluate modulo operation', () => {
expect(math('10 % 3')).toBe(1);
});
test('should evaluate modulo with larger numbers', () => {
expect(math('100 % 7')).toBe(2);
});
});
describe('complex real-world expressions', () => {
test('should evaluate MCP_USER_CONNECTION_IDLE_TIMEOUT pattern (15 * 60 * 1000)', () => {
// 15 minutes in milliseconds
expect(math('15 * 60 * 1000')).toBe(900000);
});
test('should evaluate Redis default TTL (5000)', () => {
expect(math('5000')).toBe(5000);
});
test('should evaluate LEADER_RENEW_RETRY_DELAY decimal (0.5)', () => {
expect(math('0.5')).toBe(0.5);
});
test('should evaluate BAN_DURATION default (7200000)', () => {
// 2 hours in milliseconds
expect(math('7200000')).toBe(7200000);
});
test('should evaluate expression with mixed operators and parentheses', () => {
// (1 hour + 30 min) in ms
expect(math('(1000 * 60 * 60) + (1000 * 60 * 30)')).toBe(5400000);
});
});
describe('fallback value behavior', () => {
test('should return fallback when input is undefined', () => {
expect(math(undefined, 100)).toBe(100);
});
test('should return fallback when input is null', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math(null, 100)).toBe(100);
});
test('should return fallback when input contains invalid characters', () => {
expect(math('abc', 100)).toBe(100);
});
test('should return fallback when input has SQL injection attempt', () => {
expect(math('1; DROP TABLE users;', 100)).toBe(100);
});
test('should return fallback when input has function call attempt', () => {
expect(math('console.log("hacked")', 100)).toBe(100);
});
test('should return fallback when input is empty string', () => {
expect(math('', 100)).toBe(100);
});
test('should return zero fallback when specified', () => {
expect(math(undefined, 0)).toBe(0);
});
test('should use number input even when fallback is provided', () => {
expect(math(42, 100)).toBe(42);
});
test('should use valid string even when fallback is provided', () => {
expect(math('42', 100)).toBe(42);
});
});
describe('error cases without fallback', () => {
test('should throw error when input is undefined without fallback', () => {
expect(() => math(undefined)).toThrow('str is undefined, but should be a string');
});
test('should throw error when input is null without fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(() => math(null)).toThrow('str is object, but should be a string');
});
test('should throw error when input contains invalid characters without fallback', () => {
expect(() => math('abc')).toThrow('Invalid characters in string');
});
test('should throw error when input has letter characters', () => {
expect(() => math('10x')).toThrow('Invalid characters in string');
});
test('should throw error when input has special characters', () => {
expect(() => math('10!')).toThrow('Invalid characters in string');
});
test('should throw error for malicious code injection', () => {
expect(() => math('process.exit(1)')).toThrow('Invalid characters in string');
});
test('should throw error for require injection', () => {
expect(() => math('require("fs")')).toThrow('Invalid characters in string');
});
});
describe('security - input validation', () => {
test('should reject strings with alphabetic characters', () => {
expect(() => math('Math.PI')).toThrow('Invalid characters in string');
});
test('should reject strings with brackets', () => {
expect(() => math('[1,2,3]')).toThrow('Invalid characters in string');
});
test('should reject strings with curly braces', () => {
expect(() => math('{}')).toThrow('Invalid characters in string');
});
test('should reject strings with semicolons', () => {
expect(() => math('1;2')).toThrow('Invalid characters in string');
});
test('should reject strings with quotes', () => {
expect(() => math('"test"')).toThrow('Invalid characters in string');
});
test('should reject strings with backticks', () => {
expect(() => math('`test`')).toThrow('Invalid characters in string');
});
test('should reject strings with equals sign', () => {
expect(() => math('x=1')).toThrow('Invalid characters in string');
});
test('should reject strings with ampersand', () => {
expect(() => math('1 && 2')).toThrow('Invalid characters in string');
});
test('should reject strings with pipe', () => {
expect(() => math('1 || 2')).toThrow('Invalid characters in string');
});
});
describe('edge cases', () => {
test('should handle expression resulting in Infinity with fallback', () => {
// Division by zero returns Infinity, which is technically a number
expect(math('1 / 0')).toBe(Infinity);
});
test('should handle very small decimals', () => {
expect(math('0.001')).toBe(0.001);
});
test('should handle scientific notation format', () => {
// Note: 'e' is not in the allowed character set, so this should fail
expect(() => math('1e3')).toThrow('Invalid characters in string');
});
test('should handle expression with only whitespace with fallback', () => {
expect(math(' ', 100)).toBe(100);
});
test('should handle +number syntax', () => {
expect(math('+42')).toBe(42);
});
test('should handle expression starting with negative', () => {
expect(math('-5 + 10')).toBe(5);
});
test('should handle multiple decimal points with fallback', () => {
// Invalid syntax should return fallback value
expect(math('1.2.3', 100)).toBe(100);
});
test('should throw for multiple decimal points without fallback', () => {
expect(() => math('1.2.3')).toThrow();
});
});
describe('type coercion edge cases', () => {
test('should handle object input with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math({}, 100)).toBe(100);
});
test('should handle array input with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math([], 100)).toBe(100);
});
test('should handle boolean true with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math(true, 100)).toBe(100);
});
test('should handle boolean false with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math(false, 100)).toBe(100);
});
test('should throw for object input without fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(() => math({})).toThrow('str is object, but should be a string');
});
test('should throw for array input without fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(() => math([])).toThrow('str is object, but should be a string');
});
});
});

View file

@ -1,3 +1,5 @@
import { evaluate } from 'mathjs';
/**
* Evaluates a mathematical expression provided as a string and returns the result.
*
@ -5,6 +7,8 @@
* If the input is not a string or contains invalid characters, an error is thrown.
* If the evaluated result is not a number, an error is thrown.
*
* Uses mathjs for safe expression evaluation instead of eval().
*
* @param str - The mathematical expression to evaluate, or a number.
* @param fallbackValue - The default value to return if the input is not a string or number, or if the evaluated result is not a number.
*
@ -32,14 +36,22 @@ export function math(str: string | number | undefined, fallbackValue?: number):
throw new Error('Invalid characters in string');
}
const value = eval(str);
try {
const value = evaluate(str);
if (typeof value !== 'number') {
if (typeof value !== 'number') {
if (fallback) {
return fallbackValue;
}
throw new Error(`[math] str did not evaluate to a number but to a ${typeof value}`);
}
return value;
} catch (error) {
if (fallback) {
return fallbackValue;
}
throw new Error(`[math] str did not evaluate to a number but to a ${typeof value}`);
const originalMessage = error instanceof Error ? error.message : String(error);
throw new Error(`[math] Error while evaluating mathematical expression: ${originalMessage}`);
}
return value;
}

View file

@ -27,4 +27,4 @@ export function sanitizeTitle(rawTitle: string): string {
// Step 5: Return trimmed result or fallback if empty
return trimmed.length > 0 ? trimmed : DEFAULT_FALLBACK;
}
}

View file

@ -8,7 +8,7 @@
"target": "es2015",
"moduleResolution": "node",
"allowSyntheticDefaultImports": true,
"lib": ["es2017", "dom", "ES2021.String"],
"lib": ["es2017", "dom", "ES2021.String", "ES2021.WeakRef"],
"allowJs": true,
"skipLibCheck": true,
"esModuleInterop": true,

View file

@ -1,6 +1,6 @@
{
"name": "@librechat/client",
"version": "0.4.1",
"version": "0.4.3",
"description": "React components for LibreChat",
"repository": {
"type": "git",
@ -35,10 +35,10 @@
"@dicebear/core": "^9.2.2",
"@headlessui/react": "^2.1.2",
"@radix-ui/react-accordion": "^1.2.11",
"@radix-ui/react-alert-dialog": "^1.1.15",
"@radix-ui/react-alert-dialog": "1.0.2",
"@radix-ui/react-checkbox": "^1.0.3",
"@radix-ui/react-collapsible": "^1.1.11",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-dialog": "1.0.2",
"@radix-ui/react-dropdown-menu": "^2.1.1",
"@radix-ui/react-hover-card": "^1.0.5",
"@radix-ui/react-icons": "^1.3.0",

View file

@ -39,7 +39,7 @@ const AccordionContent = React.forwardRef<
>(({ className = '', children, ...props }, ref) => (
<AccordionPrimitive.Content
ref={ref}
className="overflow-x-visible text-sm data-[state=closed]:animate-accordion-up data-[state=open]:animate-accordion-down"
className="overflow-y-hidden overflow-x-visible text-sm transition-opacity data-[state=closed]:animate-accordion-up data-[state=open]:animate-accordion-down data-[state=closed]:opacity-0 data-[state=open]:opacity-100"
{...props}
>
<div className={cn('pb-4 pt-0', className)}>{children}</div>

View file

@ -103,7 +103,7 @@ const Menu: React.FC<MenuProps> = ({
>
<Ariakit.MenuButton
className={cn(
'group flex w-full cursor-pointer items-center justify-between gap-2 rounded-lg px-3 py-3.5 text-sm text-text-primary outline-none transition-colors duration-200 hover:bg-surface-hover focus:bg-surface-hover md:px-2.5 md:py-2',
'group flex w-full cursor-pointer items-center justify-between gap-2 rounded-lg px-3 py-3.5 text-sm text-text-primary outline-none hover:bg-surface-hover focus:bg-surface-hover md:px-2.5 md:py-2',
itemClassName,
)}
disabled={item.disabled}
@ -138,7 +138,7 @@ const Menu: React.FC<MenuProps> = ({
key={`${keyPrefix ?? ''}${index}-${item.id ?? ''}`}
id={item.id}
className={cn(
'group flex w-full cursor-pointer items-center gap-2 rounded-lg px-3 py-3.5 text-sm text-text-primary outline-none transition-colors duration-200 hover:bg-surface-hover focus:bg-surface-hover md:px-2.5 md:py-2',
'group flex w-full cursor-pointer items-center gap-2 rounded-lg px-3 py-3.5 text-sm text-text-primary outline-none hover:bg-surface-hover focus:bg-surface-hover md:px-2.5 md:py-2',
itemClassName,
item.className,
)}

View file

@ -0,0 +1,54 @@
import * as React from 'react';
import { cn } from '~/utils';
export interface FilterInputProps
extends Omit<React.InputHTMLAttributes<HTMLInputElement>, 'placeholder'> {
/** The label text shown in the floating label */
label: string;
/** Unique identifier for the input - used to link label */
inputId: string;
/** Container className for custom styling */
containerClassName?: string;
}
/**
* A standardized filter/search input component with a floating label
* that animates up when focused or has a value.
*
* @example
* <FilterInput
* inputId="bookmarks-filter"
* label={localize('com_ui_bookmarks_filter')}
* value={searchQuery}
* onChange={(e) => setSearchQuery(e.target.value)}
* />
*/
const FilterInput = React.forwardRef<HTMLInputElement, FilterInputProps>(
({ className, label, inputId, containerClassName, ...props }, ref) => {
return (
<div className={cn('relative', containerClassName)}>
<input
id={inputId}
ref={ref}
placeholder=" "
aria-label={label}
className={cn(
'peer flex h-10 w-full rounded-lg border border-border-light bg-transparent px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
className,
)}
{...props}
/>
<label
htmlFor={inputId}
className="pointer-events-none absolute left-3 top-1/2 -translate-y-1/2 text-sm text-text-secondary transition-all duration-200 peer-focus:top-0 peer-focus:bg-background peer-focus:px-1 peer-focus:text-xs peer-[:not(:placeholder-shown)]:top-0 peer-[:not(:placeholder-shown)]:bg-background peer-[:not(:placeholder-shown)]:px-1 peer-[:not(:placeholder-shown)]:text-xs"
>
{label}
</label>
</div>
);
},
);
FilterInput.displayName = 'FilterInput';
export { FilterInput };

View file

@ -7,7 +7,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(({ className, ...pr
return (
<input
className={cn(
'flex h-10 w-full rounded-lg border border-input bg-transparent px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
'flex h-10 w-full rounded-lg border border-border-light bg-transparent px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
className ?? '',
)}
ref={ref}

View file

@ -3,13 +3,19 @@ import * as DialogPrimitive from '@radix-ui/react-dialog';
import { X } from 'lucide-react';
import { cn } from '~/utils';
const DialogDepthContext = React.createContext(0);
interface OGDialogProps extends DialogPrimitive.DialogProps {
triggerRef?: React.RefObject<HTMLButtonElement | HTMLInputElement | HTMLDivElement | null>;
triggerRefs?: React.RefObject<HTMLButtonElement | HTMLInputElement | HTMLDivElement | null>[];
}
const Dialog = React.forwardRef<HTMLDivElement, OGDialogProps>(
({ children, triggerRef, triggerRefs, onOpenChange, ...props }, _ref) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
({ children, triggerRef, triggerRefs, onOpenChange, ...props }, ref) => {
const parentDepth = React.useContext(DialogDepthContext);
const currentDepth = parentDepth + 1;
const handleOpenChange = (open: boolean) => {
if (!open && triggerRef?.current) {
setTimeout(() => {
@ -29,9 +35,11 @@ const Dialog = React.forwardRef<HTMLDivElement, OGDialogProps>(
};
return (
<DialogPrimitive.Root {...props} onOpenChange={handleOpenChange}>
{children}
</DialogPrimitive.Root>
<DialogDepthContext.Provider value={currentDepth}>
<DialogPrimitive.Root {...props} onOpenChange={handleOpenChange}>
{children}
</DialogPrimitive.Root>
</DialogDepthContext.Provider>
);
},
);
@ -45,16 +53,22 @@ const DialogClose = DialogPrimitive.Close;
export const DialogOverlay = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Overlay>
>(({ className, ...props }, ref) => (
<DialogPrimitive.Overlay
ref={ref}
className={cn(
'fixed inset-0 z-50 bg-black/80 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0',
className,
)}
{...props}
/>
));
>(({ className, style, ...props }, ref) => {
const depth = React.useContext(DialogDepthContext);
const overlayZIndex = 50 + (depth - 1) * 60;
return (
<DialogPrimitive.Overlay
ref={ref}
style={{ ...style, zIndex: overlayZIndex }}
className={cn(
'fixed inset-0 bg-black/80 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0',
className,
)}
{...props}
/>
);
});
DialogOverlay.displayName = DialogPrimitive.Overlay.displayName;
type DialogContentProps = React.ComponentPropsWithoutRef<typeof DialogPrimitive.Content> & {
@ -73,34 +87,47 @@ const DialogContent = React.forwardRef<
overlayClassName,
showCloseButton = true,
children,
style,
onEscapeKeyDown: propsOnEscapeKeyDown,
...props
},
ref,
) => {
/* Handle Escape key to prevent closing dialog if a tooltip is open
const depth = React.useContext(DialogDepthContext);
const contentZIndex = 100 + (depth - 1) * 60;
/* Handle Escape key to prevent closing dialog if a tooltip or dropdown is open
(this is a workaround in order to achieve WCAG compliance which requires
that our tooltips be dismissable with Escape key) */
const handleEscapeKeyDown = React.useCallback(
(event: KeyboardEvent) => {
const tooltips = document.querySelectorAll('.tooltip');
const dropdownMenus = document.querySelectorAll('[role="menu"]');
for (const tooltip of Array.from(tooltips)) {
for (const tooltip of tooltips) {
const computedStyle = window.getComputedStyle(tooltip);
const opacity = parseFloat(computedStyle.opacity);
if (
tooltip.parentElement &&
computedStyle.display !== 'none' &&
computedStyle.visibility !== 'hidden' &&
opacity > 0
parseFloat(computedStyle.opacity) > 0
) {
event.preventDefault();
return;
}
}
for (const dropdownMenu of dropdownMenus) {
const computedStyle = window.getComputedStyle(dropdownMenu);
if (
computedStyle.display !== 'none' &&
computedStyle.visibility !== 'hidden' &&
parseFloat(computedStyle.opacity) > 0
) {
event.preventDefault();
return;
}
}
// Call the original handler if it exists
propsOnEscapeKeyDown?.(event);
},
[propsOnEscapeKeyDown],
@ -111,9 +138,10 @@ const DialogContent = React.forwardRef<
<DialogOverlay className={overlayClassName} />
<DialogPrimitive.Content
ref={ref}
style={{ ...style, zIndex: contentZIndex }}
onEscapeKeyDown={handleEscapeKeyDown}
className={cn(
'max-w-11/12 fixed left-[50%] top-[50%] z-50 grid max-h-[90vh] w-full translate-x-[-50%] translate-y-[-50%] gap-4 overflow-y-auto rounded-2xl bg-background p-6 text-text-primary shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%]',
'max-w-11/12 fixed left-[50%] top-[50%] grid max-h-[90vh] w-full translate-x-[-50%] translate-y-[-50%] gap-4 overflow-y-auto rounded-2xl bg-background p-6 text-text-primary shadow-lg duration-200 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%] data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%]',
className,
)}
{...props}
@ -122,7 +150,6 @@ const DialogContent = React.forwardRef<
{showCloseButton && (
<DialogPrimitive.Close className="absolute right-4 top-4 rounded-sm opacity-70 ring-ring-primary ring-offset-background transition-opacity hover:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:pointer-events-none data-[state=open]:bg-accent data-[state=open]:text-muted-foreground">
<X className="h-6 w-6" aria-hidden="true" />
{/* eslint-disable-next-line i18next/no-literal-string */}
<span className="sr-only">Close</span>
</DialogPrimitive.Close>
)}

View file

@ -1,4 +1,4 @@
import React, { useState, useRef, useLayoutEffect, useCallback, memo } from 'react';
import React, { useState, useRef, useLayoutEffect, useEffect, useCallback, memo } from 'react';
import { useLocalize } from '~/hooks';
interface Option {
@ -25,8 +25,9 @@ const Radio = memo(function Radio({
fullWidth = false,
}: RadioProps) {
const localize = useLocalize();
const [currentValue, setCurrentValue] = useState<string>(value ?? '');
const buttonRefs = useRef<(HTMLButtonElement | null)[]>([]);
const [isMounted, setIsMounted] = useState(false);
const [currentValue, setCurrentValue] = useState<string>(value ?? '');
const [backgroundStyle, setBackgroundStyle] = useState<React.CSSProperties>({});
const handleChange = (newValue: string) => {
@ -51,9 +52,21 @@ const Radio = memo(function Radio({
}
}, [currentValue, options]);
// Mark as mounted after dialog animations settle
// Timeout ensures we wait for CSS transitions to complete
useEffect(() => {
const timeout = setTimeout(() => {
setIsMounted(true);
}, 50);
return () => clearTimeout(timeout);
}, []);
useLayoutEffect(() => {
updateBackgroundStyle();
}, [updateBackgroundStyle]);
if (isMounted) {
updateBackgroundStyle();
}
}, [isMounted, updateBackgroundStyle]);
useLayoutEffect(() => {
if (value !== undefined) {
@ -81,7 +94,7 @@ const Radio = memo(function Radio({
className={`relative ${fullWidth ? 'flex' : 'inline-flex'} items-center rounded-lg bg-muted p-1 ${className}`}
role="radiogroup"
>
{selectedIndex >= 0 && (
{selectedIndex >= 0 && isMounted && (
<div
className="pointer-events-none absolute inset-y-1 rounded-md border border-border/50 bg-background shadow-sm transition-all duration-300 ease-out"
style={backgroundStyle}

View file

@ -0,0 +1,106 @@
import * as React from 'react';
import { useState, useCallback } from 'react';
import { Eye, EyeOff, Copy, Check } from 'lucide-react';
import { cn } from '~/utils';
export interface SecretInputProps
extends Omit<React.InputHTMLAttributes<HTMLInputElement>, 'type'> {
/** Show copy button */
showCopy?: boolean;
/** Callback when value is copied */
onCopy?: () => void;
/** Duration in ms to show checkmark after copy (default: 2000) */
copyFeedbackDuration?: number;
}
const SecretInput = React.forwardRef<HTMLInputElement, SecretInputProps>(
(
{ className, showCopy = false, onCopy, copyFeedbackDuration = 2000, disabled, value, ...props },
ref,
) => {
const [isVisible, setIsVisible] = useState(false);
const [isCopied, setIsCopied] = useState(false);
const toggleVisibility = useCallback(() => {
setIsVisible((prev) => !prev);
}, []);
const handleCopy = useCallback(async () => {
if (isCopied || disabled) {
return;
}
const textToCopy = typeof value === 'string' ? value : '';
if (!textToCopy) {
return;
}
try {
await navigator.clipboard.writeText(textToCopy);
setIsCopied(true);
onCopy?.();
setTimeout(() => {
setIsCopied(false);
}, copyFeedbackDuration);
} catch (err) {
console.error('Failed to copy:', err);
}
}, [value, isCopied, disabled, onCopy, copyFeedbackDuration]);
return (
<div className="relative flex items-center">
<input
type={isVisible ? 'text' : 'password'}
className={cn(
'flex h-10 w-full rounded-lg border border-input bg-transparent px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none disabled:cursor-not-allowed disabled:opacity-50',
showCopy ? 'pr-20' : 'pr-10',
className ?? '',
)}
ref={ref}
disabled={disabled}
value={value}
autoComplete="off"
spellCheck={false}
{...props}
/>
<div className="absolute right-1 flex items-center gap-0.5">
{showCopy && (
<button
type="button"
onClick={handleCopy}
disabled={disabled || !value}
className={cn(
'flex size-8 items-center justify-center rounded-md text-text-secondary transition-colors',
disabled || !value
? 'cursor-not-allowed opacity-50'
: 'hover:bg-surface-hover hover:text-text-primary',
)}
aria-label={isCopied ? 'Copied' : 'Copy to clipboard'}
>
{isCopied ? <Check className="size-4" /> : <Copy className="size-4" />}
</button>
)}
<button
type="button"
onClick={toggleVisibility}
disabled={disabled}
className={cn(
'flex size-8 items-center justify-center rounded-md text-text-secondary transition-colors',
disabled
? 'cursor-not-allowed opacity-50'
: 'hover:bg-surface-hover hover:text-text-primary',
)}
aria-label={isVisible ? 'Hide password' : 'Show password'}
>
{isVisible ? <EyeOff className="size-4" /> : <Eye className="size-4" />}
</button>
</div>
</div>
);
},
);
SecretInput.displayName = 'SecretInput';
export { SecretInput };

View file

@ -5,10 +5,16 @@ import { useToast } from '~/hooks';
export function Toast() {
const { toast, onOpenChange } = useToast();
const severityClassName = {
/* Going up by 100 units in terms of darkness (eg bg-green-500 to bg-green-600) for
* bg colors produces colors that are too visually dissimilar to LibreChat's standard color palette.
* These colors were derived by adjusting the values in the HSV color space using CCA
* until the 4.5:1 contrast ratio threshold was met against white text while maintaining
* a relatively recognizable color scheme for toasts without compromising accessibility.
* */
[NotificationSeverity.INFO]: 'border-gray-500 bg-gray-500',
[NotificationSeverity.SUCCESS]: 'border-green-500 bg-green-500',
[NotificationSeverity.WARNING]: 'border-orange-600 bg-orange-600',
[NotificationSeverity.ERROR]: 'border-red-500 bg-red-500',
[NotificationSeverity.SUCCESS]: 'border-[#02855E] bg-[#02855E]',
[NotificationSeverity.WARNING]: 'border-[#C75209] bg-[#C75209]',
[NotificationSeverity.ERROR]: 'border-[#E02F1F] bg-[#E02F1F]',
};
return (

View file

@ -1,5 +1,5 @@
.tooltip {
z-index: 50;
z-index: 150;
cursor: pointer;
pointer-events: auto;
border-radius: 0.275rem;

View file

@ -9,6 +9,8 @@ export * from './DropdownMenu';
export * from './HoverCard';
export * from './Input';
export * from './InputNumber';
export * from './SecretInput';
export * from './FilterInput';
export * from './Label';
export * from './OriginalDialog';
export * from './QuestionMark';
@ -30,12 +32,12 @@ export * from './InputOTP';
export * from './MultiSearch';
export * from './Resizable';
export * from './Select';
export { default as DataTable } from './DataTable';
export { default as Radio } from './Radio';
export { default as Badge } from './Badge';
export { default as Avatar } from './Avatar';
export { default as Combobox } from './Combobox';
export { default as Dropdown } from './Dropdown';
export { default as DataTable } from './DataTable';
export { default as SplitText } from './SplitText';
export { default as FormInput } from './FormInput';
export { default as PixelCard } from './PixelCard';

View file

@ -1,6 +1,6 @@
{
"name": "librechat-data-provider",
"version": "0.8.200",
"version": "0.8.210",
"description": "data services for librechat apps",
"main": "dist/index.js",
"module": "dist/index.es.js",
@ -23,7 +23,7 @@
"build:watch": "rollup -c -w",
"rollup:api": "npx rollup -c server-rollup.config.js --bundleConfigAsCjs",
"test": "jest --coverage --watch",
"test:ci": "jest --coverage --ci",
"test:ci": "jest --coverage --ci --logHeapUsage",
"verify": "npm run test:ci",
"b:clean": "bun run rimraf dist",
"b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs"
@ -48,7 +48,6 @@
"@babel/preset-env": "^7.21.5",
"@babel/preset-react": "^7.18.6",
"@babel/preset-typescript": "^7.21.0",
"@langchain/core": "^0.3.62",
"@rollup/plugin-alias": "^5.1.0",
"@rollup/plugin-commonjs": "^29.0.0",
"@rollup/plugin-json": "^6.1.0",

View file

@ -66,6 +66,8 @@ export const messages = (params: q.MessagesListParams) => {
export const messagesArtifacts = (messageId: string) => `${messagesRoot}/artifact/${messageId}`;
export const messagesBranch = () => `${messagesRoot}/branch`;
const shareRoot = `${BASE_URL}/api/share`;
export const shareMessages = (shareId: string) => `${shareRoot}/${shareId}`;
export const getSharedLink = (conversationId: string) => `${shareRoot}/link/${conversationId}`;
@ -101,7 +103,8 @@ export const conversations = (params: q.ConversationListParams) => {
export const conversationById = (id: string) => `${conversationsRoot}/${id}`;
export const genTitle = () => `${conversationsRoot}/gen_title`;
export const genTitle = (conversationId: string) =>
`${conversationsRoot}/gen_title/${encodeURIComponent(conversationId)}`;
export const updateConversation = () => `${conversationsRoot}/update`;
@ -226,6 +229,8 @@ export const agents = ({ path = '', options }: { path?: string; options?: object
return url;
};
export const activeJobs = () => `${BASE_URL}/api/agents/chat/active`;
export const mcp = {
tools: `${BASE_URL}/api/mcp/tools`,
servers: `${BASE_URL}/api/mcp/servers`,

View file

@ -849,6 +849,11 @@ export const configSchema = z.object({
includedTools: z.array(z.string()).optional(),
filteredTools: z.array(z.string()).optional(),
mcpServers: MCPServersSchema.optional(),
mcpSettings: z
.object({
allowedDomains: z.array(z.string()).optional(),
})
.optional(),
interface: interfaceSchema,
turnstile: turnstileSchema.optional(),
fileStrategy: fileSourceSchema.default(FileSources.local),
@ -1234,6 +1239,7 @@ export enum InfiniteCollections {
*/
export enum Time {
ONE_DAY = 86400000,
TWELVE_HOURS = 43200000,
ONE_HOUR = 3600000,
THIRTY_MINUTES = 1800000,
TEN_MINUTES = 600000,
@ -1595,7 +1601,7 @@ export enum TTSProviders {
/** Enum for app-wide constants */
export enum Constants {
/** Key for the app's version. */
VERSION = 'v0.8.1',
VERSION = 'v0.8.2-rc1',
/** Key for the Custom Config's version (librechat.yaml). */
CONFIG_VERSION = '1.3.1',
/** Standard value for the first message's `parentMessageId` value, to indicate no parent exists. */

View file

@ -5,6 +5,7 @@ import * as s from './schemas';
export default function createPayload(submission: t.TSubmission) {
const {
isEdited,
addedConvo,
userMessage,
isContinued,
isTemporary,
@ -32,6 +33,7 @@ export default function createPayload(submission: t.TSubmission) {
...userMessage,
...endpointOption,
endpoint,
addedConvo,
isTemporary,
isRegenerate,
editedContent,

View file

@ -724,7 +724,7 @@ export function archiveConversation(
}
export function genTitle(payload: m.TGenTitleRequest): Promise<m.TGenTitleResponse> {
return request.post(endpoints.genTitle(), payload);
return request.get(endpoints.genTitle(payload.conversationId));
}
export const listMessages = (params?: q.MessagesListParams): Promise<q.MessagesListResponse> => {
@ -756,6 +756,12 @@ export const editArtifact = async ({
return request.post(endpoints.messagesArtifacts(messageId), params);
};
export const branchMessage = async (
payload: m.TBranchMessageRequest,
): Promise<m.TBranchMessageResponse> => {
return request.post(endpoints.messagesBranch(), payload);
};
export function getMessagesByConvoId(conversationId: string): Promise<s.TMessage[]> {
if (
conversationId === config.Constants.NEW_CONVO ||
@ -1037,3 +1043,12 @@ export function getGraphApiToken(params: q.GraphTokenParams): Promise<q.GraphTok
export function getDomainServerBaseUrl(): string {
return `${endpoints.apiBaseUrl()}/api`;
}
/* Active Jobs */
export interface ActiveJobsResponse {
activeJobIds: string[];
}
export const getActiveJobs = (): Promise<ActiveJobsResponse> => {
return request.get(endpoints.activeJobs());
};

View file

@ -53,7 +53,11 @@ export const fullMimeTypesList = [
'image/heic',
'image/heif',
'application/x-tar',
'application/x-sh',
'application/typescript',
'application/sql',
'application/yaml',
'application/vnd.coffeescript',
'application/xml',
'application/zip',
'image/svg',
@ -140,7 +144,7 @@ export const textMimeTypes =
/^(text\/(x-c|x-csharp|tab-separated-values|x-c\+\+|x-h|x-java|html|markdown|x-php|x-python|x-script\.python|x-ruby|x-tex|plain|css|vtt|javascript|csv|xml))$/;
export const applicationMimeTypes =
/^(application\/(epub\+zip|csv|json|pdf|x-tar|typescript|vnd\.openxmlformats-officedocument\.(wordprocessingml\.document|presentationml\.presentation|spreadsheetml\.sheet)|xml|zip))$/;
/^(application\/(epub\+zip|csv|json|pdf|x-tar|x-sh|typescript|sql|yaml|vnd\.coffeescript|vnd\.openxmlformats-officedocument\.(wordprocessingml\.document|presentationml\.presentation|spreadsheetml\.sheet)|xml|zip))$/;
export const imageMimeTypes = /^image\/(jpeg|gif|png|webp|heic|heif)$/;
@ -180,24 +184,110 @@ export const codeInterpreterMimeTypes = [
];
export const codeTypeMapping: { [key: string]: string } = {
c: 'text/x-c',
cs: 'text/x-csharp',
cpp: 'text/x-c++',
h: 'text/x-h',
md: 'text/markdown',
php: 'text/x-php',
py: 'text/x-python',
rb: 'text/x-ruby',
tex: 'text/x-tex',
js: 'text/javascript',
sh: 'application/x-sh',
ts: 'application/typescript',
tar: 'application/x-tar',
zip: 'application/zip',
yml: 'application/x-yaml',
yaml: 'application/x-yaml',
log: 'text/plain',
tsv: 'text/tab-separated-values',
c: 'text/x-c', // .c - C source
cs: 'text/x-csharp', // .cs - C# source
cpp: 'text/x-c++', // .cpp - C++ source
h: 'text/x-h', // .h - C/C++ header
md: 'text/markdown', // .md - Markdown
php: 'text/x-php', // .php - PHP source
py: 'text/x-python', // .py - Python source
rb: 'text/x-ruby', // .rb - Ruby source
tex: 'text/x-tex', // .tex - LaTeX source
js: 'text/javascript', // .js - JavaScript source
sh: 'application/x-sh', // .sh - Shell script
ts: 'application/typescript', // .ts - TypeScript source
tar: 'application/x-tar', // .tar - Tar archive
zip: 'application/zip', // .zip - ZIP archive
log: 'text/plain', // .log - Log file
tsv: 'text/tab-separated-values', // .tsv - Tab-separated values
yml: 'application/yaml', // .yml - YAML
yaml: 'application/yaml', // .yaml - YAML
sql: 'application/sql', // .sql - SQL (IANA registered)
dart: 'text/plain', // .dart - Dart source
coffee: 'application/vnd.coffeescript', // .coffee - CoffeeScript (IANA registered)
go: 'text/plain', // .go - Go source
rs: 'text/plain', // .rs - Rust source
swift: 'text/plain', // .swift - Swift source
kt: 'text/plain', // .kt - Kotlin source
kts: 'text/plain', // .kts - Kotlin script
scala: 'text/plain', // .scala - Scala source
lua: 'text/plain', // .lua - Lua source
r: 'text/plain', // .r - R source
pl: 'text/plain', // .pl - Perl source
pm: 'text/plain', // .pm - Perl module
groovy: 'text/plain', // .groovy - Groovy source
gradle: 'text/plain', // .gradle - Gradle build script
clj: 'text/plain', // .clj - Clojure source
cljs: 'text/plain', // .cljs - ClojureScript source
cljc: 'text/plain', // .cljc - Clojure common source
elm: 'text/plain', // .elm - Elm source
erl: 'text/plain', // .erl - Erlang source
hrl: 'text/plain', // .hrl - Erlang header
ex: 'text/plain', // .ex - Elixir source
exs: 'text/plain', // .exs - Elixir script
hs: 'text/plain', // .hs - Haskell source
lhs: 'text/plain', // .lhs - Literate Haskell source
ml: 'text/plain', // .ml - OCaml source
mli: 'text/plain', // .mli - OCaml interface
fs: 'text/plain', // .fs - F# source
fsx: 'text/plain', // .fsx - F# script
lisp: 'text/plain', // .lisp - Lisp source
cl: 'text/plain', // .cl - Common Lisp source
scm: 'text/plain', // .scm - Scheme source
rkt: 'text/plain', // .rkt - Racket source
jsx: 'text/plain', // .jsx - React JSX
tsx: 'text/plain', // .tsx - React TSX
vue: 'text/plain', // .vue - Vue component
svelte: 'text/plain', // .svelte - Svelte component
astro: 'text/plain', // .astro - Astro component
scss: 'text/plain', // .scss - SCSS source
sass: 'text/plain', // .sass - Sass source
less: 'text/plain', // .less - Less source
styl: 'text/plain', // .styl - Stylus source
toml: 'text/plain', // .toml - TOML config
ini: 'text/plain', // .ini - INI config
cfg: 'text/plain', // .cfg - Config file
conf: 'text/plain', // .conf - Config file
env: 'text/plain', // .env - Environment file
properties: 'text/plain', // .properties - Java properties
graphql: 'text/plain', // .graphql - GraphQL schema/query
gql: 'text/plain', // .gql - GraphQL schema/query
proto: 'text/plain', // .proto - Protocol Buffers
dockerfile: 'text/plain', // Dockerfile
makefile: 'text/plain', // Makefile
cmake: 'text/plain', // .cmake - CMake script
rake: 'text/plain', // .rake - Rake task
gemspec: 'text/plain', // .gemspec - Ruby gem spec
bash: 'text/plain', // .bash - Bash script
zsh: 'text/plain', // .zsh - Zsh script
fish: 'text/plain', // .fish - Fish script
ps1: 'text/plain', // .ps1 - PowerShell script
psm1: 'text/plain', // .psm1 - PowerShell module
bat: 'text/plain', // .bat - Batch script
cmd: 'text/plain', // .cmd - Windows command script
asm: 'text/plain', // .asm - Assembly source
s: 'text/plain', // .s - Assembly source
v: 'text/plain', // .v - V or Verilog source
zig: 'text/plain', // .zig - Zig source
nim: 'text/plain', // .nim - Nim source
cr: 'text/plain', // .cr - Crystal source
d: 'text/plain', // .d - D source
pas: 'text/plain', // .pas - Pascal source
pp: 'text/plain', // .pp - Pascal/Puppet source
f90: 'text/plain', // .f90 - Fortran 90 source
f95: 'text/plain', // .f95 - Fortran 95 source
f03: 'text/plain', // .f03 - Fortran 2003 source
jl: 'text/plain', // .jl - Julia source
m: 'text/plain', // .m - Objective-C/MATLAB source
mm: 'text/plain', // .mm - Objective-C++ source
ada: 'text/plain', // .ada - Ada source
adb: 'text/plain', // .adb - Ada body
ads: 'text/plain', // .ads - Ada spec
cob: 'text/plain', // .cob - COBOL source
cbl: 'text/plain', // .cbl - COBOL source
tcl: 'text/plain', // .tcl - Tcl source
awk: 'text/plain', // .awk - AWK script
sed: 'text/plain', // .sed - Sed script
};
/** Maps image extensions to MIME types for formats browsers may not recognize */

View file

@ -60,6 +60,8 @@ export enum QueryKeys {
/* MCP Servers */
mcpServers = 'mcpServers',
mcpServer = 'mcpServer',
/* Active Jobs */
activeJobs = 'activeJobs',
}
// Dynamic query keys that require parameters

View file

@ -22,6 +22,12 @@ export type TModelSpec = {
* - If omitted, the spec appears as a standalone item at the top level
*/
group?: string;
/**
* Optional icon URL for the group this spec belongs to.
* Only needs to be set on one spec per group - the first one found with a groupIcon will be used.
* Can be a URL or an endpoint name to use its icon.
*/
groupIcon?: string | EModelEndpoint;
showIconInMenu?: boolean;
showIconInHeader?: boolean;
iconURL?: string | EModelEndpoint; // Allow using project-included icons
@ -40,6 +46,7 @@ export const tModelSpecSchema = z.object({
default: z.boolean().optional(),
description: z.string().optional(),
group: z.string().optional(),
groupIcon: z.union([z.string(), eModelEndpointSchema]).optional(),
showIconInMenu: z.boolean().optional(),
showIconInHeader: z.boolean().optional(),
iconURL: z.union([z.string(), eModelEndpointSchema]).optional(),

View file

@ -197,7 +197,7 @@ const extractOmniVersion = (modelStr: string): string => {
return '';
};
export const getResponseSender = (endpointOption: t.TEndpointOption): string => {
export const getResponseSender = (endpointOption: Partial<t.TEndpointOption>): string => {
const {
model: _m,
endpoint: _e,
@ -216,10 +216,11 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string =>
if (
[EModelEndpoint.openAI, EModelEndpoint.bedrock, EModelEndpoint.azureOpenAI].includes(endpoint)
) {
if (chatGptLabel) {
return chatGptLabel;
} else if (modelLabel) {
if (modelLabel) {
return modelLabel;
} else if (chatGptLabel) {
// @deprecated - prefer modelLabel
return chatGptLabel;
} else if (model && extractOmniVersion(model)) {
return extractOmniVersion(model);
} else if (model && (model.includes('mistral') || model.includes('codestral'))) {
@ -255,6 +256,7 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string =>
if (modelLabel) {
return modelLabel;
} else if (chatGptLabel) {
// @deprecated - prefer modelLabel
return chatGptLabel;
} else if (model && extractOmniVersion(model)) {
return extractOmniVersion(model);
@ -414,3 +416,138 @@ export function replaceSpecialVars({ text, user }: { text: string; user?: t.TUse
return result;
}
/**
* Parsed ephemeral agent ID result
*/
export type ParsedEphemeralAgentId = {
endpoint: string;
model: string;
sender?: string;
index?: number;
};
/**
* Encodes an ephemeral agent ID from endpoint, model, optional sender, and optional index.
* Uses __ to replace : (reserved in graph node names) and ___ to separate sender.
*
* Format: endpoint__model___sender or endpoint__model___sender____index (if index provided)
*
* @example
* encodeEphemeralAgentId({ endpoint: 'openAI', model: 'gpt-4o', sender: 'GPT-4o' })
* // => 'openAI__gpt-4o___GPT-4o'
*
* @example
* encodeEphemeralAgentId({ endpoint: 'openAI', model: 'gpt-4o', sender: 'GPT-4o', index: 1 })
* // => 'openAI__gpt-4o___GPT-4o____1'
*/
export function encodeEphemeralAgentId({
endpoint,
model,
sender,
index,
}: {
endpoint: string;
model: string;
sender?: string;
index?: number;
}): string {
const base = `${endpoint}:${model}`.replace(/:/g, '__');
let result = base;
if (sender) {
// Use ___ as separator before sender to distinguish from __ in model names
result = `${base}___${sender.replace(/:/g, '__')}`;
}
if (index != null) {
// Use ____ (4 underscores) as separator for index
result = `${result}____${index}`;
}
return result;
}
/**
* Parses an ephemeral agent ID back into its components.
* Returns undefined if the ID doesn't match the expected format.
*
* Format: endpoint__model___sender or endpoint__model___sender____index
* - ____ (4 underscores) separates optional index suffix
* - ___ (triple underscore) separates model from optional sender
* - __ (double underscore) replaces : in endpoint/model names
*
* @example
* parseEphemeralAgentId('openAI__gpt-4o___GPT-4o')
* // => { endpoint: 'openAI', model: 'gpt-4o', sender: 'GPT-4o' }
*
* @example
* parseEphemeralAgentId('openAI__gpt-4o___GPT-4o____1')
* // => { endpoint: 'openAI', model: 'gpt-4o', sender: 'GPT-4o', index: 1 }
*/
export function parseEphemeralAgentId(agentId: string): ParsedEphemeralAgentId | undefined {
if (!agentId.includes('__')) {
return undefined;
}
// First check for index suffix (separated by ____)
let index: number | undefined;
let workingId = agentId;
if (agentId.includes('____')) {
const lastIndexSep = agentId.lastIndexOf('____');
const indexStr = agentId.slice(lastIndexSep + 4);
const parsedIndex = parseInt(indexStr, 10);
if (!isNaN(parsedIndex)) {
index = parsedIndex;
workingId = agentId.slice(0, lastIndexSep);
}
}
// Check for sender (separated by ___)
let sender: string | undefined;
let mainPart = workingId;
if (workingId.includes('___')) {
const [before, after] = workingId.split('___');
mainPart = before;
// Restore colons in sender if any
sender = after?.replace(/__/g, ':');
}
const [endpoint, ...modelParts] = mainPart.split('__');
if (!endpoint || modelParts.length === 0) {
return undefined;
}
// Restore colons in model name (model names can contain colons like claude-3:opus)
const model = modelParts.join(':');
return { endpoint, model, sender, index };
}
/**
* Checks if an agent ID represents an ephemeral (non-saved) agent.
* Real agent IDs always start with "agent_", so anything else is ephemeral.
*/
export function isEphemeralAgentId(agentId: string | null | undefined): boolean {
return !agentId?.startsWith('agent_');
}
/**
* Strips the index suffix (____N) from an agent ID if present.
* Works with both ephemeral and real agent IDs.
*
* @example
* stripAgentIdSuffix('agent_abc123____1') // => 'agent_abc123'
* stripAgentIdSuffix('openAI__gpt-4o___GPT-4o____1') // => 'openAI__gpt-4o___GPT-4o'
* stripAgentIdSuffix('agent_abc123') // => 'agent_abc123' (unchanged)
*/
export function stripAgentIdSuffix(agentId: string): string {
return agentId.replace(/____\d+$/, '');
}
/**
* Appends an index suffix (____N) to an agent ID.
* Used to distinguish parallel agents with the same base ID.
*
* @example
* appendAgentIdSuffix('agent_abc123', 1) // => 'agent_abc123____1'
* appendAgentIdSuffix('openAI__gpt-4o___GPT-4o', 1) // => 'openAI__gpt-4o___GPT-4o____1'
*/
export function appendAgentIdSuffix(agentId: string, index: number): string {
return `${agentId}____${index}`;
}

View file

@ -49,7 +49,8 @@ export const documentSupportedProviders = new Set<string>([
EModelEndpoint.anthropic,
EModelEndpoint.openAI,
EModelEndpoint.custom,
EModelEndpoint.azureOpenAI,
// handled in AttachFileMenu and DragDropModal since azureOpenAI only supports documents with Use Responses API set to true
// EModelEndpoint.azureOpenAI,
EModelEndpoint.google,
Providers.VERTEXAI,
Providers.MISTRALAI,

View file

@ -109,6 +109,8 @@ export type TPayload = Partial<TMessage> &
isTemporary: boolean;
ephemeralAgent?: TEphemeralAgent | null;
editedContent?: TEditedContent | null;
/** Added conversation for multi-convo feature */
addedConvo?: TConversation;
};
export type TEditedContent =
@ -136,6 +138,8 @@ export type TSubmission = {
clientTimestamp?: string;
ephemeralAgent?: TEphemeralAgent | null;
editedContent?: TEditedContent | null;
/** Added conversation for multi-convo feature */
addedConvo?: TConversation;
};
export type EventSubmission = Omit<TSubmission, 'initialResponse'> & { initialResponse: TMessage };

View file

@ -33,11 +33,26 @@ export namespace Agents {
image_url: string | { url: string; detail?: ImageDetail };
};
export type MessageContentVideoUrl = {
type: ContentTypes.VIDEO_URL;
video_url: { url: string };
};
export type MessageContentInputAudio = {
type: ContentTypes.INPUT_AUDIO;
input_audio: {
data: string;
format: string;
};
};
export type MessageContentComplex =
| ReasoningContentText
| AgentUpdate
| MessageContentText
| MessageContentImageUrl
| MessageContentVideoUrl
| MessageContentInputAudio
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| (Record<string, any> & { type?: ContentTypes | string })
// eslint-disable-next-line @typescript-eslint/no-explicit-any
@ -166,11 +181,40 @@ export namespace Agents {
type: StepTypes;
id: string; // #new
runId?: string; // #new
agentId?: string; // #new
index: number; // #new
stepIndex?: number; // #new
/** Group ID for parallel content - parts with same groupId are displayed in columns */
groupId?: number; // #new
stepDetails: StepDetails;
usage: null | object;
};
/** Content part for aggregated message content */
export interface ContentPart {
type: string;
text?: string;
[key: string]: unknown;
}
/** User message metadata for rebuilding submission on reconnect */
export interface UserMessageMeta {
messageId: string;
parentMessageId?: string;
conversationId?: string;
text?: string;
}
/** State data sent to reconnecting clients */
export interface ResumeState {
runSteps: RunStep[];
/** Aggregated content parts - can be MessageContentComplex[] or ContentPart[] */
aggregatedContent?: MessageContentComplex[];
userMessage?: UserMessageMeta;
responseMessageId?: string;
conversationId?: string;
sender?: string;
}
/**
* Represents a run step delta i.e. any changed fields on a run step during
* streaming.
@ -266,6 +310,8 @@ export namespace Agents {
| ContentTypes.THINK
| ContentTypes.TEXT
| ContentTypes.IMAGE_URL
| ContentTypes.VIDEO_URL
| ContentTypes.INPUT_AUDIO
| string;
}

View file

@ -166,6 +166,7 @@ export type AgentModelParameters = {
top_p: AgentParameterValue;
frequency_penalty: AgentParameterValue;
presence_penalty: AgentParameterValue;
useResponsesApi?: boolean;
};
export interface AgentBaseResource {
@ -466,8 +467,17 @@ export type PartMetadata = {
action?: boolean;
auth?: string;
expires_at?: number;
/** Index indicating parallel sibling content (same stepIndex in multi-agent runs) */
siblingIndex?: number;
/** Agent ID for parallel agent rendering - identifies which agent produced this content */
agentId?: string;
/** Group ID for parallel content - parts with same groupId are displayed in columns */
groupId?: number;
};
/** Metadata for parallel content rendering - subset of PartMetadata */
export type ContentMetadata = Pick<PartMetadata, 'agentId' | 'groupId'>;
export type ContentPart = (
| CodeToolCall
| RetrievalToolCall
@ -482,18 +492,18 @@ export type ContentPart = (
export type TextData = (Text & PartMetadata) | undefined;
export type TMessageContentParts =
| {
| ({
type: ContentTypes.ERROR;
text?: string | TextData;
error?: string;
}
| { type: ContentTypes.THINK; think?: string | TextData }
| {
} & ContentMetadata)
| ({ type: ContentTypes.THINK; think?: string | TextData } & ContentMetadata)
| ({
type: ContentTypes.TEXT;
text?: string | TextData;
tool_call_ids?: string[];
}
| {
} & ContentMetadata)
| ({
type: ContentTypes.TOOL_CALL;
tool_call: (
| CodeToolCall
@ -503,10 +513,12 @@ export type TMessageContentParts =
| Agents.AgentToolCall
) &
PartMetadata;
}
| { type: ContentTypes.IMAGE_FILE; image_file: ImageFile & PartMetadata }
| Agents.AgentUpdate
| Agents.MessageContentImageUrl;
} & ContentMetadata)
| ({ type: ContentTypes.IMAGE_FILE; image_file: ImageFile & PartMetadata } & ContentMetadata)
| (Agents.AgentUpdate & ContentMetadata)
| (Agents.MessageContentImageUrl & ContentMetadata)
| (Agents.MessageContentVideoUrl & ContentMetadata)
| (Agents.MessageContentInputAudio & ContentMetadata);
export type StreamContentData = TMessageContentParts & {
/** The index of the current content part */

View file

@ -381,6 +381,20 @@ export type EditArtifactOptions = MutationOptions<
Error
>;
export type TBranchMessageRequest = {
messageId: string;
agentId: string;
};
export type TBranchMessageResponse = types.TMessage;
export type BranchMessageOptions = MutationOptions<
TBranchMessageResponse,
TBranchMessageRequest,
unknown,
Error
>;
export type TLogoutResponse = {
message: string;
redirect?: string;

View file

@ -5,6 +5,8 @@ export enum ContentTypes {
TOOL_CALL = 'tool_call',
IMAGE_FILE = 'image_file',
IMAGE_URL = 'image_url',
VIDEO_URL = 'video_url',
INPUT_AUDIO = 'input_audio',
AGENT_UPDATE = 'agent_update',
ERROR = 'error',
}

View file

@ -1,5 +1,4 @@
import type { Logger as WinstonLogger } from 'winston';
import type { RunnableConfig } from '@langchain/core/runnables';
export type SearchRefType = 'search' | 'image' | 'news' | 'video' | 'ref';
@ -174,16 +173,6 @@ export interface CohereRerankerResponse {
export type SafeSearchLevel = 0 | 1 | 2;
export type Logger = WinstonLogger;
export interface SearchToolConfig extends SearchConfig, ProcessSourcesConfig, FirecrawlConfig {
logger?: Logger;
safeSearch?: SafeSearchLevel;
jinaApiKey?: string;
jinaApiUrl?: string;
cohereApiKey?: string;
rerankerType?: RerankerType;
onSearchResults?: (results: SearchResult, runnableConfig?: RunnableConfig) => void;
onGetHighlights?: (link: string) => void;
}
export interface MediaReference {
originalUrl: string;
title?: string;
@ -290,18 +279,6 @@ export interface FirecrawlScraperConfig {
logger?: Logger;
}
export type GetSourcesParams = {
query: string;
date?: DATE_RANGE;
country?: string;
numResults?: number;
safeSearch?: SearchToolConfig['safeSearch'];
images?: boolean;
videos?: boolean;
news?: boolean;
type?: 'search' | 'images' | 'videos' | 'news';
};
/** Serper API */
export interface VideoResult {
title?: string;
@ -609,12 +586,3 @@ export interface SearXNGResult {
publishedDate?: string;
img_src?: string;
}
export type ProcessSourcesFields = {
result: SearchResult;
numElements: number;
query: string;
news: boolean;
proMode: boolean;
onGetHighlights: SearchToolConfig['onGetHighlights'];
};

View file

@ -1,6 +1,6 @@
{
"name": "@librechat/data-schemas",
"version": "0.0.31",
"version": "0.0.32",
"description": "Mongoose schemas and models for LibreChat",
"type": "module",
"main": "dist/index.cjs",

View file

@ -60,7 +60,8 @@ export const AppService = async (params?: {
const availableTools = systemTools;
const mcpConfig = config.mcpServers || null;
const mcpServersConfig = config.mcpServers || null;
const mcpSettings = config.mcpSettings || null;
const registration = config.registration ?? configDefaults.registration;
const interfaceConfig = await loadDefaultInterface({ config, configDefaults });
const turnstileConfig = loadTurnstileConfig(config, configDefaults);
@ -74,7 +75,8 @@ export const AppService = async (params?: {
speech,
balance,
transactions,
mcpConfig,
mcpConfig: mcpServersConfig,
mcpSettings,
webSearch,
fileStrategy,
registration,

View file

@ -4,7 +4,7 @@ export * from './crypto';
export * from './schema';
export * from './utils';
export { createModels } from './models';
export { createMethods } from './methods';
export { createMethods, DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY } from './methods';
export type * from './types';
export type * from './methods';
export { default as logger } from './config/winston';

View file

@ -723,7 +723,9 @@ describe('AclEntry Model Tests', () => {
expect(permissionsMap.size).toBe(2); // Only resource1 and resource2 for user
expect(permissionsMap.get(resource1.toString())).toBe(PermissionBits.VIEW);
expect(permissionsMap.get(resource2.toString())).toBe(PermissionBits.VIEW | PermissionBits.EDIT);
expect(permissionsMap.get(resource2.toString())).toBe(
PermissionBits.VIEW | PermissionBits.EDIT,
);
expect(permissionsMap.get(resource3.toString())).toBeUndefined(); // User has no access
});
@ -772,7 +774,9 @@ describe('AclEntry Model Tests', () => {
expect(permissionsMap.size).toBe(2);
/** Resource1 should have VIEW | EDIT (from user + group) */
expect(permissionsMap.get(resource1.toString())).toBe(PermissionBits.VIEW | PermissionBits.EDIT);
expect(permissionsMap.get(resource1.toString())).toBe(
PermissionBits.VIEW | PermissionBits.EDIT,
);
/** Resource2 should have only VIEW (from user) */
expect(permissionsMap.get(resource2.toString())).toBe(PermissionBits.VIEW);
});
@ -847,7 +851,9 @@ describe('AclEntry Model Tests', () => {
);
expect(permissionsMap.size).toBe(2);
expect(permissionsMap.get(resource1.toString())).toBe(PermissionBits.VIEW | PermissionBits.EDIT);
expect(permissionsMap.get(resource1.toString())).toBe(
PermissionBits.VIEW | PermissionBits.EDIT,
);
expect(permissionsMap.get(resource2.toString())).toBe(PermissionBits.VIEW);
});
@ -903,7 +909,9 @@ describe('AclEntry Model Tests', () => {
/** Resources 20-29: USER VIEW | GROUP EDIT */
for (let i = 20; i < 30; i++) {
expect(permissionsMap.get(resources[i].toString())).toBe(PermissionBits.VIEW | PermissionBits.EDIT);
expect(permissionsMap.get(resources[i].toString())).toBe(
PermissionBits.VIEW | PermissionBits.EDIT,
);
}
/** Resources 30-39: GROUP EDIT only */

View file

@ -1,7 +1,9 @@
import { createSessionMethods, type SessionMethods } from './session';
import { createSessionMethods, DEFAULT_REFRESH_TOKEN_EXPIRY, type SessionMethods } from './session';
import { createTokenMethods, type TokenMethods } from './token';
import { createRoleMethods, type RoleMethods } from './role';
import { createUserMethods, type UserMethods } from './user';
import { createUserMethods, DEFAULT_SESSION_EXPIRY, type UserMethods } from './user';
export { DEFAULT_REFRESH_TOKEN_EXPIRY, DEFAULT_SESSION_EXPIRY };
import { createKeyMethods, type KeyMethods } from './key';
import { createFileMethods, type FileMethods } from './file';
/* Memories */

View file

@ -5,8 +5,8 @@ import logger from '~/config/winston';
import { nanoid } from 'nanoid';
const NORMALIZED_LIMIT_DEFAULT = 20;
const MAX_CREATE_RETRIES = 3;
const RETRY_BASE_DELAY_MS = 10;
const MAX_CREATE_RETRIES = 5;
const RETRY_BASE_DELAY_MS = 25;
/**
* Helper to check if an error is a MongoDB duplicate key error.

View file

@ -12,8 +12,8 @@ export class SessionError extends Error {
}
}
const { REFRESH_TOKEN_EXPIRY } = process.env ?? {};
const expires = REFRESH_TOKEN_EXPIRY ? eval(REFRESH_TOKEN_EXPIRY) : 1000 * 60 * 60 * 24 * 7; // 7 days default
/** Default refresh token expiry: 7 days in milliseconds */
export const DEFAULT_REFRESH_TOKEN_EXPIRY = 1000 * 60 * 60 * 24 * 7;
// Factory function that takes mongoose instance and returns the methods
export function createSessionMethods(mongoose: typeof import('mongoose')) {
@ -28,11 +28,13 @@ export function createSessionMethods(mongoose: typeof import('mongoose')) {
throw new SessionError('User ID is required', 'INVALID_USER_ID');
}
const expiresIn = options.expiresIn ?? DEFAULT_REFRESH_TOKEN_EXPIRY;
try {
const Session = mongoose.models.Session;
const currentSession = new Session({
user: userId,
expiration: options.expiration || new Date(Date.now() + expires),
expiration: options.expiration || new Date(Date.now() + expiresIn),
});
const refreshToken = await generateRefreshToken(currentSession);
@ -105,7 +107,10 @@ export function createSessionMethods(mongoose: typeof import('mongoose')) {
async function updateExpiration(
session: t.ISession | string,
newExpiration?: Date,
options: t.UpdateExpirationOptions = {},
): Promise<t.ISession> {
const expiresIn = options.expiresIn ?? DEFAULT_REFRESH_TOKEN_EXPIRY;
try {
const Session = mongoose.models.Session;
const sessionDoc = typeof session === 'string' ? await Session.findById(session) : session;
@ -114,7 +119,7 @@ export function createSessionMethods(mongoose: typeof import('mongoose')) {
throw new SessionError('Session not found', 'SESSION_NOT_FOUND');
}
sessionDoc.expiration = newExpiration || new Date(Date.now() + expires);
sessionDoc.expiration = newExpiration || new Date(Date.now() + expiresIn);
return await sessionDoc.save();
} catch (error) {
logger.error('[updateExpiration] Error updating session:', error);
@ -208,7 +213,9 @@ export function createSessionMethods(mongoose: typeof import('mongoose')) {
}
try {
const expiresIn = session.expiration ? session.expiration.getTime() : Date.now() + expires;
const expiresIn = session.expiration
? session.expiration.getTime()
: Date.now() + DEFAULT_REFRESH_TOKEN_EXPIRY;
if (!session.expiration) {
session.expiration = new Date(expiresIn);

View file

@ -31,11 +31,10 @@ describe('User Methods', () => {
} as IUser;
afterEach(() => {
delete process.env.SESSION_EXPIRY;
delete process.env.JWT_SECRET;
});
it('should default to 15 minutes when SESSION_EXPIRY is not set', async () => {
it('should default to 15 minutes when expiresIn is not provided', async () => {
process.env.JWT_SECRET = 'test-secret';
mockSignPayload.mockResolvedValue('mocked-token');
@ -49,16 +48,15 @@ describe('User Methods', () => {
email: mockUser.email,
},
secret: 'test-secret',
expirationTime: 900, // 15 minutes in seconds
expirationTime: 900, // 15 minutes in seconds (DEFAULT_SESSION_EXPIRY / 1000)
});
});
it('should default to 15 minutes when SESSION_EXPIRY is empty string', async () => {
process.env.SESSION_EXPIRY = '';
it('should default to 15 minutes when expiresIn is undefined', async () => {
process.env.JWT_SECRET = 'test-secret';
mockSignPayload.mockResolvedValue('mocked-token');
await userMethods.generateToken(mockUser);
await userMethods.generateToken(mockUser, undefined);
expect(mockSignPayload).toHaveBeenCalledWith({
payload: {
@ -68,16 +66,15 @@ describe('User Methods', () => {
email: mockUser.email,
},
secret: 'test-secret',
expirationTime: 900, // 15 minutes in seconds
expirationTime: 900, // 15 minutes in seconds (DEFAULT_SESSION_EXPIRY / 1000)
});
});
it('should use custom expiry when SESSION_EXPIRY is set to a valid expression', async () => {
process.env.SESSION_EXPIRY = '1000 * 60 * 30'; // 30 minutes
it('should use custom expiry when expiresIn is provided', async () => {
process.env.JWT_SECRET = 'test-secret';
mockSignPayload.mockResolvedValue('mocked-token');
await userMethods.generateToken(mockUser);
await userMethods.generateToken(mockUser, 1000 * 60 * 30); // 30 minutes
expect(mockSignPayload).toHaveBeenCalledWith({
payload: {
@ -91,12 +88,12 @@ describe('User Methods', () => {
});
});
it('should default to 15 minutes when SESSION_EXPIRY evaluates to falsy value', async () => {
process.env.SESSION_EXPIRY = '0'; // This will evaluate to 0, which is falsy
it('should use 0 when expiresIn is 0', async () => {
process.env.JWT_SECRET = 'test-secret';
mockSignPayload.mockResolvedValue('mocked-token');
await userMethods.generateToken(mockUser);
// When 0 is passed, it should use 0 (caller's responsibility to pass valid value)
await userMethods.generateToken(mockUser, 0);
expect(mockSignPayload).toHaveBeenCalledWith({
payload: {
@ -106,7 +103,7 @@ describe('User Methods', () => {
email: mockUser.email,
},
secret: 'test-secret',
expirationTime: 900, // 15 minutes in seconds
expirationTime: 0, // 0 seconds
});
});
@ -119,45 +116,13 @@ describe('User Methods', () => {
});
it('should return the token from signPayload', async () => {
process.env.SESSION_EXPIRY = '1000 * 60 * 60'; // 1 hour
process.env.JWT_SECRET = 'test-secret';
const expectedToken = 'generated-jwt-token';
mockSignPayload.mockResolvedValue(expectedToken);
const token = await userMethods.generateToken(mockUser);
const token = await userMethods.generateToken(mockUser, 1000 * 60 * 60); // 1 hour
expect(token).toBe(expectedToken);
});
it('should handle invalid SESSION_EXPIRY expressions gracefully', async () => {
process.env.SESSION_EXPIRY = 'invalid expression';
process.env.JWT_SECRET = 'test-secret';
mockSignPayload.mockResolvedValue('mocked-token');
// Mock console.warn to verify it's called
const consoleWarnSpy = jest.spyOn(console, 'warn').mockImplementation();
await userMethods.generateToken(mockUser);
// Should use default value when eval fails
expect(mockSignPayload).toHaveBeenCalledWith({
payload: {
id: mockUser._id,
username: mockUser.username,
provider: mockUser.provider,
email: mockUser.email,
},
secret: 'test-secret',
expirationTime: 900, // 15 minutes in seconds (default)
});
// Verify warning was logged
expect(consoleWarnSpy).toHaveBeenCalledWith(
'Invalid SESSION_EXPIRY expression, using default:',
expect.any(SyntaxError),
);
consoleWarnSpy.mockRestore();
});
});
});

View file

@ -2,6 +2,9 @@ import mongoose, { FilterQuery } from 'mongoose';
import type { IUser, BalanceConfig, CreateUserRequest, UserDeleteResult } from '~/types';
import { signPayload } from '~/crypto';
/** Default JWT session expiry: 15 minutes in milliseconds */
export const DEFAULT_SESSION_EXPIRY = 1000 * 60 * 15;
/** Factory function that takes mongoose instance and returns the methods */
export function createUserMethods(mongoose: typeof import('mongoose')) {
/**
@ -161,24 +164,15 @@ export function createUserMethods(mongoose: typeof import('mongoose')) {
/**
* Generates a JWT token for a given user.
* @param user - The user object
* @param expiresIn - Optional expiry time in milliseconds. Default: 15 minutes
*/
async function generateToken(user: IUser): Promise<string> {
async function generateToken(user: IUser, expiresIn?: number): Promise<string> {
if (!user) {
throw new Error('No user provided');
}
let expires = 1000 * 60 * 15;
if (process.env.SESSION_EXPIRY !== undefined && process.env.SESSION_EXPIRY !== '') {
try {
const evaluated = eval(process.env.SESSION_EXPIRY);
if (evaluated) {
expires = evaluated;
}
} catch (error) {
console.warn('Invalid SESSION_EXPIRY expression, using default:', error);
}
}
const expires = expiresIn ?? DEFAULT_SESSION_EXPIRY;
return await signPayload({
payload: {

View file

@ -5,5 +5,8 @@ import type * as t from '~/types';
* Creates or returns the AgentCategory model using the provided mongoose instance and schema
*/
export function createAgentCategoryModel(mongoose: typeof import('mongoose')) {
return mongoose.models.AgentCategory || mongoose.model<t.IAgentCategory>('AgentCategory', agentCategorySchema);
}
return (
mongoose.models.AgentCategory ||
mongoose.model<t.IAgentCategory>('AgentCategory', agentCategorySchema)
);
}

View file

@ -0,0 +1,110 @@
import { MongoMemoryServer } from 'mongodb-memory-server';
import mongoose from 'mongoose';
import { EModelEndpoint } from 'librechat-data-provider';
import { createConversationModel } from '~/models/convo';
import { createMessageModel } from '~/models/message';
import { SchemaWithMeiliMethods } from '~/models/plugins/mongoMeili';
const mockAddDocuments = jest.fn();
const mockIndex = jest.fn().mockReturnValue({
getRawInfo: jest.fn(),
updateSettings: jest.fn(),
addDocuments: mockAddDocuments,
getDocuments: jest.fn().mockReturnValue({ results: [] }),
});
jest.mock('meilisearch', () => {
return {
MeiliSearch: jest.fn().mockImplementation(() => {
return {
index: mockIndex,
};
}),
};
});
describe('Meilisearch Mongoose plugin', () => {
const OLD_ENV = process.env;
let mongoServer: MongoMemoryServer;
beforeAll(async () => {
process.env = {
...OLD_ENV,
// Set a fake meilisearch host/key so that we activate the meilisearch plugin
MEILI_HOST: 'foo',
MEILI_MASTER_KEY: 'bar',
};
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
beforeEach(() => {
mockAddDocuments.mockClear();
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
process.env = OLD_ENV;
});
test('saving conversation indexes w/ meilisearch', async () => {
await createConversationModel(mongoose).create({
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Test Conversation',
endpoint: EModelEndpoint.openAI,
});
expect(mockAddDocuments).toHaveBeenCalled();
});
test('saving TTL conversation does NOT index w/ meilisearch', async () => {
await createConversationModel(mongoose).create({
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Test Conversation',
endpoint: EModelEndpoint.openAI,
expiredAt: new Date(),
});
expect(mockAddDocuments).not.toHaveBeenCalled();
});
test('saving messages indexes w/ meilisearch', async () => {
await createMessageModel(mongoose).create({
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
});
expect(mockAddDocuments).toHaveBeenCalled();
});
test('saving TTL messages does NOT index w/ meilisearch', async () => {
await createMessageModel(mongoose).create({
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
expiredAt: new Date(),
});
expect(mockAddDocuments).not.toHaveBeenCalled();
});
test('sync w/ meili does not include TTL documents', async () => {
const conversationModel = createConversationModel(mongoose) as SchemaWithMeiliMethods;
await conversationModel.create({
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Test Conversation',
endpoint: EModelEndpoint.openAI,
expiredAt: new Date(),
});
await conversationModel.syncWithMeili();
expect(mockAddDocuments).not.toHaveBeenCalled();
});
});

View file

@ -183,7 +183,9 @@ const createMeiliMongooseModel = ({
);
// Build query with resume capability
const query: FilterQuery<unknown> = {};
const query: FilterQuery<unknown> = {
expiredAt: { $exists: false }, // Do not sync TTL documents
};
if (options?.resumeFromId) {
query._id = { $gt: options.resumeFromId };
}
@ -430,6 +432,11 @@ const createMeiliMongooseModel = ({
this: DocumentWithMeiliIndex,
next: CallbackWithoutResultAndOptionalError,
): Promise<void> {
// If this conversation or message has a TTL, don't index it
if (!_.isNil(this.expiredAt)) {
return next();
}
const object = this.preprocessObjectForIndex!();
const maxRetries = 3;
let retryCount = 0;

View file

@ -38,7 +38,7 @@ const transactionSchema: Schema<ITransaction> = new Schema(
},
model: {
type: String,
index: true
index: true,
},
context: {
type: String,

View file

@ -82,6 +82,8 @@ export interface AppConfig {
speech?: TCustomConfig['speech'];
/** MCP server configuration */
mcpConfig?: TCustomConfig['mcpServers'] | null;
/** MCP settings (domain allowlist, etc.) */
mcpSettings?: TCustomConfig['mcpSettings'] | null;
/** File configuration */
fileConfig?: TFileConfig;
/** Secure image links configuration */

View file

@ -8,6 +8,13 @@ export interface ISession extends Document {
export interface CreateSessionOptions {
expiration?: Date;
/** Duration in milliseconds for session expiry. Default: 7 days */
expiresIn?: number;
}
export interface UpdateExpirationOptions {
/** Duration in milliseconds for session expiry. Default: 7 days */
expiresIn?: number;
}
export interface SessionSearchParams {