Merge branch 'dev' into feat/context-window-ui

This commit is contained in:
Marco Beretta 2025-12-29 02:07:54 +01:00
commit cb8322ca85
No known key found for this signature in database
GPG key ID: D918033D8E74CC11
407 changed files with 25479 additions and 19894 deletions

View file

@ -1,6 +1,6 @@
{
"name": "@librechat/api",
"version": "1.7.0",
"version": "1.7.10",
"type": "commonjs",
"description": "MCP services for LibreChat",
"main": "dist/index.js",
@ -23,7 +23,8 @@
"test:cache-integration:core": "jest --testPathPatterns=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
"test:cache-integration:cluster": "jest --testPathPatterns=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
"test:cache-integration:mcp": "jest --testPathPatterns=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
"test:cache-integration": "npm run test:cache-integration:core && npm run test:cache-integration:cluster && npm run test:cache-integration:mcp",
"test:cache-integration:stream": "jest --testPathPatterns=\"src/stream/.*\\.stream_integration\\.spec\\.ts$\" --coverage=false --runInBand --forceExit",
"test:cache-integration": "npm run test:cache-integration:core && npm run test:cache-integration:cluster && npm run test:cache-integration:mcp && npm run test:cache-integration:stream",
"verify": "npm run test:ci",
"b:clean": "bun run rimraf dist",
"b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs",
@ -78,15 +79,17 @@
"registry": "https://registry.npmjs.org/"
},
"peerDependencies": {
"@aws-sdk/client-bedrock-runtime": "^3.941.0",
"@aws-sdk/client-s3": "^3.758.0",
"@azure/identity": "^4.7.0",
"@azure/search-documents": "^12.0.0",
"@azure/storage-blob": "^12.27.0",
"@keyv/redis": "^4.3.3",
"@langchain/core": "^0.3.79",
"@librechat/agents": "^3.0.50",
"@langchain/core": "^0.3.80",
"@librechat/agents": "^3.0.61",
"@librechat/data-schemas": "*",
"@modelcontextprotocol/sdk": "^1.24.3",
"@modelcontextprotocol/sdk": "^1.25.1",
"@smithy/node-http-handler": "^4.4.5",
"axios": "^1.12.1",
"connect-redis": "^8.1.0",
"diff": "^7.0.0",
@ -95,12 +98,14 @@
"express-session": "^1.18.2",
"firebase": "^11.0.2",
"form-data": "^4.0.4",
"https-proxy-agent": "^7.0.6",
"ioredis": "^5.3.2",
"js-yaml": "^4.1.1",
"jsonwebtoken": "^9.0.0",
"keyv": "^5.3.2",
"keyv-file": "^5.1.2",
"librechat-data-provider": "*",
"mathjs": "^15.1.0",
"memorystore": "^1.6.7",
"mongoose": "^8.12.1",
"node-fetch": "2.7.0",

View file

@ -17,6 +17,7 @@ import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
import type { ObjectId, MemoryMethods } from '@librechat/data-schemas';
import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
import type { Response as ServerResponse } from 'express';
import { GenerationJobManager } from '~/stream/GenerationJobManager';
import { Tokenizer } from '~/utils';
type RequiredMemoryMethods = Pick<
@ -250,6 +251,7 @@ export class BasicToolEndHandler implements EventHandler {
constructor(callback?: ToolEndCallback) {
this.callback = callback;
}
handle(
event: string,
data: StreamEventData | undefined,
@ -282,6 +284,7 @@ export async function processMemory({
llmConfig,
tokenLimit,
totalTokens = 0,
streamId = null,
}: {
res: ServerResponse;
setMemory: MemoryMethods['setMemory'];
@ -296,6 +299,7 @@ export async function processMemory({
tokenLimit?: number;
totalTokens?: number;
llmConfig?: Partial<LLMConfig>;
streamId?: string | null;
}): Promise<(TAttachment | null)[] | undefined> {
try {
const memoryTool = createMemoryTool({
@ -363,7 +367,7 @@ ${memory ?? 'No existing memories'}`;
}
const artifactPromises: Promise<TAttachment | null>[] = [];
const memoryCallback = createMemoryCallback({ res, artifactPromises });
const memoryCallback = createMemoryCallback({ res, artifactPromises, streamId });
const customHandlers = {
[GraphEvents.TOOL_END]: new BasicToolEndHandler(memoryCallback),
};
@ -416,6 +420,7 @@ export async function createMemoryProcessor({
memoryMethods,
conversationId,
config = {},
streamId = null,
}: {
res: ServerResponse;
messageId: string;
@ -423,6 +428,7 @@ export async function createMemoryProcessor({
userId: string | ObjectId;
memoryMethods: RequiredMemoryMethods;
config?: MemoryConfig;
streamId?: string | null;
}): Promise<[string, (messages: BaseMessage[]) => Promise<(TAttachment | null)[] | undefined>]> {
const { validKeys, instructions, llmConfig, tokenLimit } = config;
const finalInstructions = instructions || getDefaultInstructions(validKeys, tokenLimit);
@ -443,6 +449,7 @@ export async function createMemoryProcessor({
llmConfig,
messageId,
tokenLimit,
streamId,
conversationId,
memory: withKeys,
totalTokens: totalTokens || 0,
@ -461,10 +468,12 @@ async function handleMemoryArtifact({
res,
data,
metadata,
streamId = null,
}: {
res: ServerResponse;
data: ToolEndData;
metadata?: ToolEndMetadata;
streamId?: string | null;
}) {
const output = data?.output as ToolMessage | undefined;
if (!output) {
@ -490,7 +499,11 @@ async function handleMemoryArtifact({
if (!res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
if (streamId) {
GenerationJobManager.emitChunk(streamId, { event: 'attachment', data: attachment });
} else {
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
}
return attachment;
}
@ -499,14 +512,17 @@ async function handleMemoryArtifact({
* @param params - The parameters object
* @param params.res - The server response object
* @param params.artifactPromises - Array to collect artifact promises
* @param params.streamId - The stream ID for resumable mode, or null for standard mode
* @returns The memory callback function
*/
export function createMemoryCallback({
res,
artifactPromises,
streamId = null,
}: {
res: ServerResponse;
artifactPromises: Promise<Partial<TAttachment> | null>[];
streamId?: string | null;
}): ToolEndCallback {
return async (data: ToolEndData, metadata?: Record<string, unknown>) => {
const output = data?.output as ToolMessage | undefined;
@ -515,7 +531,7 @@ export function createMemoryCallback({
return;
}
artifactPromises.push(
handleMemoryArtifact({ res, data, metadata }).catch((error) => {
handleMemoryArtifact({ res, data, metadata, streamId }).catch((error) => {
logger.error('Error processing memory artifact content:', error);
return null;
}),
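
A minimal usage sketch (illustrative, not part of the diff above) of how the new streamId parameter routes memory attachments; createMemoryCallback and GenerationJobManager.emitChunk come from the changes above, while res, artifactPromises, and the job id are placeholders:

const artifactPromises: Promise<TAttachment | null>[] = [];
// Resumable mode: attachments are emitted through the generation job rather than the raw response.
const resumableCallback = createMemoryCallback({ res, artifactPromises, streamId: 'job-123' });
// Standard mode: omitting streamId (or passing null) keeps the existing res.write SSE frames.
const standardCallback = createMemoryCallback({ res, artifactPromises, streamId: null });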

View file

@ -1,5 +1,10 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { isEmailDomainAllowed, isActionDomainAllowed } from './domain';
import {
isEmailDomainAllowed,
isActionDomainAllowed,
extractMCPServerDomain,
isMCPDomainAllowed,
} from './domain';
describe('isEmailDomainAllowed', () => {
afterEach(() => {
@ -213,3 +218,209 @@ describe('isActionDomainAllowed', () => {
});
});
});
describe('extractMCPServerDomain', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('URL extraction', () => {
it('should extract domain from HTTPS URL', () => {
const config = { url: 'https://api.example.com/sse' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
});
it('should extract domain from HTTP URL', () => {
const config = { url: 'http://api.example.com/sse' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
});
it('should extract domain from WebSocket URL', () => {
const config = { url: 'wss://ws.example.com' };
expect(extractMCPServerDomain(config)).toBe('ws.example.com');
});
it('should handle URL with port', () => {
const config = { url: 'https://localhost:3001/sse' };
expect(extractMCPServerDomain(config)).toBe('localhost');
});
it('should strip www prefix', () => {
const config = { url: 'https://www.example.com/api' };
expect(extractMCPServerDomain(config)).toBe('example.com');
});
it('should handle URL with path and query parameters', () => {
const config = { url: 'https://api.example.com/v1/sse?token=abc' };
expect(extractMCPServerDomain(config)).toBe('api.example.com');
});
});
describe('stdio transports (no URL)', () => {
it('should return null for stdio transport with command only', () => {
const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null when url is undefined', () => {
const config = { command: 'node', args: ['server.js'] };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for empty object', () => {
const config = {};
expect(extractMCPServerDomain(config)).toBeNull();
});
});
describe('invalid URLs', () => {
it('should return null for invalid URL format', () => {
const config = { url: 'not-a-valid-url' };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for empty URL string', () => {
const config = { url: '' };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for non-string url', () => {
const config = { url: 12345 };
expect(extractMCPServerDomain(config)).toBeNull();
});
it('should return null for null url', () => {
const config = { url: null };
expect(extractMCPServerDomain(config)).toBeNull();
});
});
});
describe('isMCPDomainAllowed', () => {
afterEach(() => {
jest.clearAllMocks();
});
describe('stdio transports (always allowed)', () => {
it('should allow stdio transport regardless of allowlist', async () => {
const config = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
it('should allow stdio transport even with empty allowlist', async () => {
const config = { command: 'node', args: ['server.js'] };
expect(await isMCPDomainAllowed(config, [])).toBe(true);
});
it('should allow stdio transport when no URL present', async () => {
const config = {};
expect(await isMCPDomainAllowed(config, ['restricted.com'])).toBe(true);
});
});
describe('permissive defaults (no restrictions)', () => {
it('should allow all domains when allowedDomains is null', async () => {
const config = { url: 'https://any-domain.com/sse' };
expect(await isMCPDomainAllowed(config, null)).toBe(true);
});
it('should allow all domains when allowedDomains is undefined', async () => {
const config = { url: 'https://any-domain.com/sse' };
expect(await isMCPDomainAllowed(config, undefined)).toBe(true);
});
it('should allow all domains when allowedDomains is empty array', async () => {
const config = { url: 'https://any-domain.com/sse' };
expect(await isMCPDomainAllowed(config, [])).toBe(true);
});
});
describe('exact domain matching', () => {
const allowedDomains = ['example.com', 'localhost', 'trusted-mcp.com'];
it('should allow exact domain match', async () => {
const config = { url: 'https://example.com/api' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow localhost', async () => {
const config = { url: 'http://localhost:3001/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should reject non-allowed domain', async () => {
const config = { url: 'https://malicious.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
});
it('should reject subdomain when only parent is allowed', async () => {
const config = { url: 'https://api.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
});
});
describe('wildcard domain matching', () => {
const allowedDomains = ['*.example.com', 'localhost'];
it('should allow subdomain with wildcard', async () => {
const config = { url: 'https://api.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow any subdomain with wildcard', async () => {
const config = { url: 'https://staging.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow base domain with wildcard', async () => {
const config = { url: 'https://example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should allow nested subdomain with wildcard', async () => {
const config = { url: 'https://deep.nested.example.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(true);
});
it('should reject different domain even with wildcard', async () => {
const config = { url: 'https://api.other.com/sse' };
expect(await isMCPDomainAllowed(config, allowedDomains)).toBe(false);
});
});
describe('case insensitivity', () => {
it('should match domains case-insensitively', async () => {
const config = { url: 'https://EXAMPLE.COM/sse' };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
it('should match with uppercase in allowlist', async () => {
const config = { url: 'https://example.com/sse' };
expect(await isMCPDomainAllowed(config, ['EXAMPLE.COM'])).toBe(true);
});
it('should match with mixed case', async () => {
const config = { url: 'https://Api.Example.Com/sse' };
expect(await isMCPDomainAllowed(config, ['*.example.com'])).toBe(true);
});
});
describe('www prefix handling', () => {
it('should strip www prefix from URL before matching', async () => {
const config = { url: 'https://www.example.com/sse' };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
it('should match www in allowlist to non-www URL', async () => {
const config = { url: 'https://example.com/sse' };
expect(await isMCPDomainAllowed(config, ['www.example.com'])).toBe(true);
});
});
describe('invalid URL handling', () => {
it('should allow config with invalid URL (treated as stdio)', async () => {
const config = { url: 'not-a-valid-url' };
expect(await isMCPDomainAllowed(config, ['example.com'])).toBe(true);
});
});
});

View file

@ -96,3 +96,45 @@ export async function isActionDomainAllowed(
return false;
}
/**
* Extracts domain from MCP server config URL.
* Returns null for stdio transports (no URL) or invalid URLs.
* @param config - MCP server configuration (accepts any config with optional url field)
*/
export function extractMCPServerDomain(config: Record<string, unknown>): string | null {
const url = config.url;
// Stdio transports don't have URLs; return null so callers can treat them as always allowed
if (!url || typeof url !== 'string') {
return null;
}
try {
const parsedUrl = new URL(url);
return parsedUrl.hostname.replace(/^www\./i, '');
} catch {
return null;
}
}
/**
* Validates MCP server domain against allowedDomains.
* Reuses isActionDomainAllowed for consistent validation logic.
* Stdio transports (no URL) are always allowed.
* @param config - MCP server configuration with optional url field
* @param allowedDomains - List of allowed domains (with wildcard support)
*/
export async function isMCPDomainAllowed(
config: Record<string, unknown>,
allowedDomains?: string[] | null,
): Promise<boolean> {
const domain = extractMCPServerDomain(config);
// Stdio transports don't have domains - always allowed
if (!domain) {
return true;
}
// Reuse existing validation logic (includes wildcard support)
return isActionDomainAllowed(domain, allowedDomains);
}
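
An illustrative sketch (not part of the diff above) of the two new helpers, using configs that mirror the test fixtures:

const remote = { url: 'https://api.example.com/sse' };
const stdio = { command: 'npx', args: ['-y', '@modelcontextprotocol/server-puppeteer'] };

extractMCPServerDomain(remote); // 'api.example.com'
extractMCPServerDomain(stdio); // null (no URL to validate)
await isMCPDomainAllowed(remote, ['*.example.com']); // true (wildcard match)
await isMCPDomainAllowed(remote, ['trusted-mcp.com']); // false
await isMCPDomainAllowed(stdio, ['trusted-mcp.com']); // true (stdio transports are always allowed)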

View file

@ -10,6 +10,7 @@ describe('cacheConfig', () => {
delete process.env.REDIS_KEY_PREFIX_VAR;
delete process.env.REDIS_KEY_PREFIX;
delete process.env.USE_REDIS;
delete process.env.USE_REDIS_STREAMS;
delete process.env.USE_REDIS_CLUSTER;
delete process.env.REDIS_PING_INTERVAL;
delete process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES;
@ -130,6 +131,53 @@ describe('cacheConfig', () => {
});
});
describe('USE_REDIS_STREAMS configuration', () => {
test('should default to USE_REDIS value when USE_REDIS_STREAMS is not set', async () => {
process.env.USE_REDIS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(true);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(true);
});
test('should default to false when both USE_REDIS and USE_REDIS_STREAMS are not set', async () => {
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(false);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(false);
});
test('should be false when explicitly set to false even if USE_REDIS is true', async () => {
process.env.USE_REDIS = 'true';
process.env.USE_REDIS_STREAMS = 'false';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(true);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(false);
});
test('should be true when explicitly set to true', async () => {
process.env.USE_REDIS = 'true';
process.env.USE_REDIS_STREAMS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS_STREAMS).toBe(true);
});
test('should allow streams without general Redis (explicit override)', async () => {
// Edge case: someone might want streams with Redis but not general caching
// This would require REDIS_URI but not USE_REDIS
process.env.USE_REDIS_STREAMS = 'true';
process.env.REDIS_URI = 'redis://localhost:6379';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.USE_REDIS).toBe(false);
expect(cacheConfig.USE_REDIS_STREAMS).toBe(true);
});
});
describe('REDIS_CA file reading', () => {
test('should be null when REDIS_CA is not set', async () => {
const { cacheConfig } = await import('../cacheConfig');

View file

@ -17,6 +17,14 @@ if (USE_REDIS && !process.env.REDIS_URI) {
throw new Error('USE_REDIS is enabled but REDIS_URI is not set.');
}
// USE_REDIS_STREAMS controls whether Redis is used for resumable stream job storage.
// Defaults to true if USE_REDIS is enabled but USE_REDIS_STREAMS is not explicitly set.
// Set to 'false' to use in-memory storage for streams while keeping Redis for other caches.
const USE_REDIS_STREAMS =
process.env.USE_REDIS_STREAMS !== undefined
? isEnabled(process.env.USE_REDIS_STREAMS)
: USE_REDIS;
// Comma-separated list of cache namespaces that should be forced to use in-memory storage
// even when Redis is enabled. This allows selective performance optimization for specific caches.
const FORCED_IN_MEMORY_CACHE_NAMESPACES = process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES
@ -60,6 +68,7 @@ const getRedisCA = (): string | null => {
const cacheConfig = {
FORCED_IN_MEMORY_CACHE_NAMESPACES,
USE_REDIS,
USE_REDIS_STREAMS,
REDIS_URI: process.env.REDIS_URI,
REDIS_USERNAME: process.env.REDIS_USERNAME,
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
@ -112,6 +121,13 @@ const cacheConfig = {
* @default 1000
*/
REDIS_SCAN_COUNT: math(process.env.REDIS_SCAN_COUNT, 1000),
/**
* TTL in milliseconds for MCP registry read-through cache.
* This cache reduces redundant lookups within a single request flow.
* @default 5000 (5 seconds)
*/
MCP_REGISTRY_CACHE_TTL: math(process.env.MCP_REGISTRY_CACHE_TTL, 5000),
};
export { cacheConfig };
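
A hedged configuration sketch (example values, not part of this commit) showing how the new flags compose; the import path is illustrative:

// .env: keep Redis for general caches, opt streams out of Redis, shorten the MCP registry cache.
// USE_REDIS=true
// REDIS_URI=redis://localhost:6379
// USE_REDIS_STREAMS=false
// MCP_REGISTRY_CACHE_TTL=2000
import { cacheConfig } from './cacheConfig';
cacheConfig.USE_REDIS; // true
cacheConfig.USE_REDIS_STREAMS; // false (the explicit override wins over the USE_REDIS default)
cacheConfig.MCP_REGISTRY_CACHE_TTL; // 2000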

View file

@ -130,7 +130,7 @@ export async function fetchModels({
const options: {
headers: Record<string, string>;
timeout: number;
httpsAgent?: HttpsProxyAgent;
httpsAgent?: HttpsProxyAgent<string>;
} = {
headers: {
...(headers ?? {}),

View file

@ -79,6 +79,21 @@ export async function encodeAndFormatAudios(
mimeType: file.type,
data: content,
});
} else if (provider === Providers.OPENROUTER) {
// Extract format from filename extension (e.g., 'audio.mp3' -> 'mp3')
// OpenRouter expects format values like: wav, mp3, aiff, aac, ogg, flac, m4a, pcm16, pcm24
// Note: MIME types don't always match the format name (e.g., 'audio/mpeg' corresponds to mp3, not mpeg), which is why we use the file extension instead
const format = file.filename.split('.').pop()?.toLowerCase();
if (!format) {
throw new Error(`Could not extract audio format from filename: ${file.filename}`);
}
result.audios.push({
type: 'input_audio',
input_audio: {
data: content,
format,
},
});
}
result.files.push(metadata);
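
For context, an illustrative sketch (not part of the diff above) of what the OpenRouter branch produces for an mp3 upload; the base64 payload is a placeholder:

const file = { filename: 'voice.mp3', type: 'audio/mpeg' };
const format = file.filename.split('.').pop()?.toLowerCase(); // 'mp3' (the MIME subtype would wrongly suggest 'mpeg')
const audioPart = { type: 'input_audio', input_audio: { data: '<base64 content>', format } };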

View file

@ -79,6 +79,13 @@ export async function encodeAndFormatVideos(
mimeType: file.type,
data: content,
});
} else if (provider === Providers.OPENROUTER) {
result.videos.push({
type: 'video_url',
video_url: {
url: `data:${file.type};base64,${content}`,
},
});
}
result.files.push(metadata);

View file

@ -129,22 +129,61 @@ export class FlowStateManager<T = unknown> {
return new Promise<T>((resolve, reject) => {
const checkInterval = 2000;
let elapsedTime = 0;
let isCleanedUp = false;
let intervalId: NodeJS.Timeout | null = null;
// Idempotent cleanup, safe to call from both the abort handler and the polling loop
const cleanup = () => {
if (isCleanedUp) return;
isCleanedUp = true;
if (intervalId) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
}
if (signal && abortHandler) {
signal.removeEventListener('abort', abortHandler);
}
};
// Abort handler - rejects as soon as the signal fires instead of waiting for the next poll
const abortHandler = async () => {
cleanup();
logger.warn(`[${flowKey}] Flow aborted (immediate)`);
const message = `${type} flow aborted`;
try {
await this.keyv.delete(flowKey);
} catch {
// Ignore delete errors during abort
}
reject(new Error(message));
};
// Register abort handler immediately if signal provided
if (signal) {
if (signal.aborted) {
// Already aborted, reject immediately
cleanup();
reject(new Error(`${type} flow aborted`));
return;
}
signal.addEventListener('abort', abortHandler, { once: true });
}
intervalId = setInterval(async () => {
if (isCleanedUp) return;
const intervalId = setInterval(async () => {
try {
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (!flowState) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.error(`[${flowKey}] Flow state not found`);
reject(new Error(`${type} Flow state not found`));
return;
}
if (signal?.aborted) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.warn(`[${flowKey}] Flow aborted`);
const message = `${type} flow aborted`;
await this.keyv.delete(flowKey);
@ -153,8 +192,7 @@ export class FlowStateManager<T = unknown> {
}
if (flowState.status !== 'PENDING') {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.debug(`[${flowKey}] Flow completed`);
if (flowState.status === 'COMPLETED' && flowState.result !== undefined) {
@ -168,8 +206,7 @@ export class FlowStateManager<T = unknown> {
elapsedTime += checkInterval;
if (elapsedTime >= this.ttl) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
logger.error(
`[${flowKey}] Flow timed out | Elapsed time: ${elapsedTime} | TTL: ${this.ttl}`,
);
@ -179,8 +216,7 @@ export class FlowStateManager<T = unknown> {
logger.debug(`[${flowKey}] Flow state elapsed time: ${elapsedTime}, checking again...`);
} catch (error) {
logger.error(`[${flowKey}] Error checking flow state:`, error);
clearInterval(intervalId);
this.intervals.delete(intervalId);
cleanup();
reject(error);
}
}, checkInterval);
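
A minimal sketch (not part of this commit) of the new abort path; flowManager and flowMetadata stand in for real instances, and the optional signal argument matches the createFlow call in the connection factory change below:

const controller = new AbortController();
const pending = flowManager.createFlow('user123:my-server', 'mcp_oauth', flowMetadata, controller.signal);
// Aborting now rejects immediately via the registered abort handler,
// instead of waiting up to 2s for the next polling tick.
controller.abort();
await pending.catch((err) => logger.warn(err.message)); // "mcp_oauth flow aborted"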

View file

@ -9,6 +9,7 @@ export * from './mcp/connection';
export * from './mcp/oauth';
export * from './mcp/auth';
export * from './mcp/zod';
export * from './mcp/errors';
/* Utilities */
export * from './mcp/utils';
export * from './utils';
@ -38,6 +39,8 @@ export * from './tools';
export * from './web';
/* Cache */
export * from './cache';
/* Stream */
export * from './stream';
/* types */
export type * from './mcp/types';
export type * from './flow/types';

View file

@ -1,7 +1,7 @@
import { logger } from '@librechat/data-schemas';
import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js';
import type { TokenMethods } from '@librechat/data-schemas';
import type { MCPOAuthTokens, MCPOAuthFlowMetadata, OAuthMetadata } from '~/mcp/oauth';
import type { MCPOAuthTokens, OAuthMetadata } from '~/mcp/oauth';
import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types';
import type * as t from './types';
@ -173,9 +173,10 @@ export class MCPConnectionFactory {
// Create the flow state so the OAuth callback can find it
// We spawn this in the background without waiting for it
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch(() => {
// Pass signal so the flow can be aborted if the request is cancelled
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
// The OAuth callback will resolve this flow, so we expect it to timeout here
// which is fine - we just need the flow state to exist
// or it will be aborted if the request is cancelled - both are fine
});
if (this.oauthStart) {
@ -354,56 +355,26 @@ export class MCPConnectionFactory {
/** Check if there's already an ongoing OAuth flow for this flowId */
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
if (existingFlow && existingFlow.status === 'PENDING') {
// If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
if (existingFlow) {
logger.debug(
`${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`,
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
);
/** Tokens from existing flow to complete */
const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth');
if (typeof this.oauthEnd === 'function') {
await this.oauthEnd();
}
logger.info(
`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`,
);
/** Client information from the existing flow metadata */
const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata;
const clientInfo = existingMetadata?.clientInfo;
return { tokens, clientInfo };
}
// Clean up old completed/failed flows, but only if they're actually stale
// This prevents race conditions where we delete a flow that's still being processed
if (existingFlow && existingFlow.status !== 'PENDING') {
const STALE_FLOW_THRESHOLD = 2 * 60 * 1000; // 2 minutes
const { isStale, age, status } = await this.flowManager.isFlowStale(
flowId,
'mcp_oauth',
STALE_FLOW_THRESHOLD,
);
if (isStale) {
try {
try {
if (existingFlow.status === 'PENDING') {
await this.flowManager.failFlow(
flowId,
'mcp_oauth',
new Error('Cancelled for new OAuth request'),
);
} else {
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
logger.debug(
`${this.logPrefix} Cleared stale ${status} OAuth flow (age: ${Math.round(age / 1000)}s)`,
);
} catch (error) {
logger.warn(`${this.logPrefix} Failed to clear stale OAuth flow`, error);
}
} else {
logger.debug(
`${this.logPrefix} Skipping cleanup of recent ${status} flow (age: ${Math.round(age / 1000)}s, threshold: ${STALE_FLOW_THRESHOLD / 1000}s)`,
);
// If flow is recent but not pending, something might be wrong
if (status === 'FAILED') {
logger.warn(
`${this.logPrefix} Recent OAuth flow failed, will retry after ${Math.round((STALE_FLOW_THRESHOLD - age) / 1000)}s`,
);
}
} catch (error) {
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
}
// Continue to start a new flow below
}
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);

View file

@ -188,3 +188,344 @@ describe('MCPConnection Error Detection', () => {
});
});
});
/**
* Tests for extractSSEErrorMessage function.
* This function extracts meaningful error messages from SSE transport errors,
* particularly handling the "SSE error: undefined" case from the MCP SDK.
*/
describe('extractSSEErrorMessage', () => {
/**
* Standalone implementation of extractSSEErrorMessage for testing.
* This mirrors the function in connection.ts.
* Keep in sync with the actual implementation.
*/
function extractSSEErrorMessage(error: unknown): {
message: string;
code?: number;
isProxyHint: boolean;
isTransient: boolean;
} {
if (!error || typeof error !== 'object') {
return {
message: 'Unknown SSE transport error',
isProxyHint: true,
isTransient: true,
};
}
const errorObj = error as { message?: string; code?: number; event?: unknown };
const rawMessage = errorObj.message ?? '';
const code = errorObj.code;
// Handle the common "SSE error: undefined" case
if (rawMessage === 'SSE error: undefined' || rawMessage === 'undefined' || !rawMessage) {
return {
message:
'SSE connection closed. This can occur due to: (1) idle connection timeout (normal), ' +
'(2) reverse proxy buffering (check proxy_buffering config), or (3) network interruption.',
code,
isProxyHint: true,
isTransient: true,
};
}
// Check for timeout patterns with case-insensitive matching
const lowerMessage = rawMessage.toLowerCase();
if (
rawMessage.includes('ETIMEDOUT') ||
rawMessage.includes('ESOCKETTIMEDOUT') ||
lowerMessage.includes('timed out') ||
lowerMessage.includes('timeout after') ||
lowerMessage.includes('request timeout')
) {
return {
message: `SSE connection timed out: ${rawMessage}. If behind a reverse proxy, increase proxy_read_timeout.`,
code,
isProxyHint: true,
isTransient: true,
};
}
// Connection reset is often transient
if (rawMessage.includes('ECONNRESET')) {
return {
message: `SSE connection reset: ${rawMessage}. The server or proxy may have restarted.`,
code,
isProxyHint: false,
isTransient: true,
};
}
// Connection refused is more serious
if (rawMessage.includes('ECONNREFUSED')) {
return {
message: `SSE connection refused: ${rawMessage}. Verify the MCP server is running and accessible.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// DNS failure
if (rawMessage.includes('ENOTFOUND') || rawMessage.includes('getaddrinfo')) {
return {
message: `SSE DNS resolution failed: ${rawMessage}. Check the server URL is correct.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// Check for HTTP status codes
const statusMatch = rawMessage.match(/\b(4\d{2}|5\d{2})\b/);
if (statusMatch) {
const statusCode = parseInt(statusMatch[1], 10);
const isServerError = statusCode >= 500 && statusCode < 600;
return {
message: rawMessage,
code: statusCode,
isProxyHint: statusCode === 502 || statusCode === 503 || statusCode === 504,
isTransient: isServerError,
};
}
return {
message: rawMessage,
code,
isProxyHint: false,
isTransient: false,
};
}
describe('undefined/empty error handling', () => {
it('should handle "SSE error: undefined" from MCP SDK', () => {
const error = { message: 'SSE error: undefined', code: undefined };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should handle empty message', () => {
const error = { message: '' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isTransient).toBe(true);
});
it('should handle message "undefined"', () => {
const error = { message: 'undefined' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isTransient).toBe(true);
});
it('should handle null error', () => {
const result = extractSSEErrorMessage(null);
expect(result.message).toBe('Unknown SSE transport error');
expect(result.isTransient).toBe(true);
});
it('should handle undefined error', () => {
const result = extractSSEErrorMessage(undefined);
expect(result.message).toBe('Unknown SSE transport error');
expect(result.isTransient).toBe(true);
});
it('should handle non-object error', () => {
const result = extractSSEErrorMessage('string error');
expect(result.message).toBe('Unknown SSE transport error');
expect(result.isTransient).toBe(true);
});
});
describe('timeout errors', () => {
it('should detect ETIMEDOUT', () => {
const error = { message: 'connect ETIMEDOUT 1.2.3.4:443' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.message).toContain('proxy_read_timeout');
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect ESOCKETTIMEDOUT', () => {
const error = { message: 'ESOCKETTIMEDOUT' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should detect "timed out" (case insensitive)', () => {
const error = { message: 'Connection Timed Out' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should detect "timeout after"', () => {
const error = { message: 'Request timeout after 60000ms' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should detect "request timeout"', () => {
const error = { message: 'Request Timeout' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection timed out');
expect(result.isTransient).toBe(true);
});
it('should NOT match "timeout" in unrelated context', () => {
// URL containing "timeout" should not trigger timeout detection
const error = { message: 'Failed to connect to https://api.example.com/timeout-settings' };
const result = extractSSEErrorMessage(error);
expect(result.message).not.toContain('SSE connection timed out');
expect(result.message).toBe('Failed to connect to https://api.example.com/timeout-settings');
});
});
describe('connection errors', () => {
it('should detect ECONNRESET as transient', () => {
const error = { message: 'read ECONNRESET' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection reset');
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(true);
});
it('should detect ECONNREFUSED as non-transient', () => {
const error = { message: 'connect ECONNREFUSED 127.0.0.1:8080' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection refused');
expect(result.message).toContain('Verify the MCP server is running');
expect(result.isTransient).toBe(false);
});
});
describe('DNS errors', () => {
it('should detect ENOTFOUND', () => {
const error = { message: 'getaddrinfo ENOTFOUND unknown.host.com' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE DNS resolution failed');
expect(result.message).toContain('Check the server URL');
expect(result.isTransient).toBe(false);
});
it('should detect getaddrinfo errors', () => {
const error = { message: 'getaddrinfo EAI_AGAIN example.com' };
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE DNS resolution failed');
expect(result.isTransient).toBe(false);
});
});
describe('HTTP status code errors', () => {
it('should detect 502 as proxy hint and transient', () => {
const error = { message: 'Non-200 status code (502): Bad Gateway' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(502);
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect 503 as proxy hint and transient', () => {
const error = { message: 'Error: Service Unavailable (503)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(503);
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect 504 as proxy hint and transient', () => {
const error = { message: 'Gateway Timeout 504' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(504);
expect(result.isProxyHint).toBe(true);
expect(result.isTransient).toBe(true);
});
it('should detect 500 as transient but not proxy hint', () => {
const error = { message: 'Internal Server Error (500)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(500);
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(true);
});
it('should detect 404 as non-transient', () => {
const error = { message: 'Not Found (404)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(404);
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(false);
});
it('should detect 401 as non-transient', () => {
const error = { message: 'Unauthorized (401)' };
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(401);
expect(result.isTransient).toBe(false);
});
});
describe('SseError from MCP SDK', () => {
it('should handle SseError with event property', () => {
const error = {
message: 'SSE error: undefined',
code: undefined,
event: { type: 'error', code: undefined, message: undefined },
};
const result = extractSSEErrorMessage(error);
expect(result.message).toContain('SSE connection closed');
expect(result.isTransient).toBe(true);
});
it('should preserve code from SseError', () => {
const error = {
message: 'SSE error: Server sent HTTP 204, not reconnecting',
code: 204,
};
const result = extractSSEErrorMessage(error);
expect(result.code).toBe(204);
});
});
describe('regular error messages', () => {
it('should pass through regular error messages', () => {
const error = { message: 'Some specific error message', code: 42 };
const result = extractSSEErrorMessage(error);
expect(result.message).toBe('Some specific error message');
expect(result.code).toBe(42);
expect(result.isProxyHint).toBe(false);
expect(result.isTransient).toBe(false);
});
});
});

View file

@ -331,12 +331,14 @@ describe('MCPConnectionFactory', () => {
expect(deleteCallOrder).toBeLessThan(createCallOrder);
// Verify createFlow was called with fresh metadata
// 4th arg is the abort signal (undefined in this test since no signal was provided)
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
'user123:test-server',
'mcp_oauth',
expect.objectContaining({
codeVerifier: 'new-code-verifier-xyz',
}),
undefined,
);
});
});

View file

@ -68,6 +68,145 @@ function isStreamableHTTPOptions(options: t.MCPOptions): options is t.Streamable
const FIVE_MINUTES = 5 * 60 * 1000;
const DEFAULT_TIMEOUT = 60000;
/** SSE connections through proxies may need longer initial handshake time */
const SSE_CONNECT_TIMEOUT = 120000;
/**
* Headers for SSE connections.
*
* Headers we intentionally DO NOT include:
* - Accept: text/event-stream - Already set by eventsource library AND MCP SDK
* - X-Accel-Buffering: This is a RESPONSE header for Nginx, not a request header.
* The upstream MCP server must send this header for Nginx to respect it.
* - Connection: keep-alive: Forbidden in HTTP/2 (RFC 7540 §8.1.2.2).
* HTTP/2 manages connection persistence differently.
*/
const SSE_REQUEST_HEADERS = {
'Cache-Control': 'no-cache',
};
/**
* Extracts a meaningful error message from SSE transport errors.
* The MCP SDK's SSEClientTransport can produce "SSE error: undefined" when the
* underlying eventsource library encounters connection issues without a specific message.
*
* @returns Object containing:
* - message: Human-readable error description
* - code: HTTP status code if available
* - isProxyHint: Whether this error suggests proxy misconfiguration
* - isTransient: Whether this is likely a transient error that will auto-reconnect
*/
function extractSSEErrorMessage(error: unknown): {
message: string;
code?: number;
isProxyHint: boolean;
isTransient: boolean;
} {
if (!error || typeof error !== 'object') {
return {
message: 'Unknown SSE transport error',
isProxyHint: true,
isTransient: true,
};
}
const errorObj = error as { message?: string; code?: number; event?: unknown };
const rawMessage = errorObj.message ?? '';
const code = errorObj.code;
/**
* Handle the common "SSE error: undefined" case.
* This typically occurs when:
* 1. A reverse proxy buffers the SSE stream (proxy issue)
* 2. The server closes an idle connection (normal SSE behavior)
* 3. Network interruption without specific error details
*
* In all cases, the eventsource library will attempt to reconnect automatically.
*/
if (rawMessage === 'SSE error: undefined' || rawMessage === 'undefined' || !rawMessage) {
return {
message:
'SSE connection closed. This can occur due to: (1) idle connection timeout (normal), ' +
'(2) reverse proxy buffering (check proxy_buffering config), or (3) network interruption.',
code,
isProxyHint: true,
isTransient: true,
};
}
/**
* Check for timeout patterns. Use case-insensitive matching for common timeout error codes:
* - ETIMEDOUT: TCP connection timeout
* - ESOCKETTIMEDOUT: Socket timeout
* - "timed out" / "timeout": Generic timeout messages
*/
const lowerMessage = rawMessage.toLowerCase();
if (
rawMessage.includes('ETIMEDOUT') ||
rawMessage.includes('ESOCKETTIMEDOUT') ||
lowerMessage.includes('timed out') ||
lowerMessage.includes('timeout after') ||
lowerMessage.includes('request timeout')
) {
return {
message: `SSE connection timed out: ${rawMessage}. If behind a reverse proxy, increase proxy_read_timeout.`,
code,
isProxyHint: true,
isTransient: true,
};
}
// Connection reset is often transient (server restart, proxy reload)
if (rawMessage.includes('ECONNRESET')) {
return {
message: `SSE connection reset: ${rawMessage}. The server or proxy may have restarted.`,
code,
isProxyHint: false,
isTransient: true,
};
}
// Connection refused is more serious - server may be down
if (rawMessage.includes('ECONNREFUSED')) {
return {
message: `SSE connection refused: ${rawMessage}. Verify the MCP server is running and accessible.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// DNS failure is usually a configuration issue, not transient
if (rawMessage.includes('ENOTFOUND') || rawMessage.includes('getaddrinfo')) {
return {
message: `SSE DNS resolution failed: ${rawMessage}. Check the server URL is correct.`,
code,
isProxyHint: false,
isTransient: false,
};
}
// Check for HTTP status codes in the message
const statusMatch = rawMessage.match(/\b(4\d{2}|5\d{2})\b/);
if (statusMatch) {
const statusCode = parseInt(statusMatch[1], 10);
// 5xx errors are often transient, 4xx are usually not
const isServerError = statusCode >= 500 && statusCode < 600;
return {
message: rawMessage,
code: statusCode,
isProxyHint: statusCode === 502 || statusCode === 503 || statusCode === 504,
isTransient: isServerError,
};
}
return {
message: rawMessage,
code,
isProxyHint: false,
isTransient: false,
};
}
interface MCPConnectionParams {
serverName: string;
@ -258,18 +397,29 @@ export class MCPConnection extends EventEmitter {
headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
}
const timeoutValue = this.timeout || DEFAULT_TIMEOUT;
/**
* SSE connections need longer timeouts for reliability.
* The connect timeout is extended because proxies may delay initial response.
*/
const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT;
const transport = new SSEClientTransport(url, {
requestInit: {
headers,
/** User/OAuth headers override SSE defaults */
headers: { ...SSE_REQUEST_HEADERS, ...headers },
signal: abortController.signal,
},
eventSourceInit: {
fetch: (url, init) => {
const fetchHeaders = new Headers(Object.assign({}, init?.headers, headers));
/** Merge headers: SSE defaults < init headers < user headers (user wins) */
const fetchHeaders = new Headers(
Object.assign({}, SSE_REQUEST_HEADERS, init?.headers, headers),
);
const agent = new Agent({
bodyTimeout: timeoutValue,
headersTimeout: timeoutValue,
bodyTimeout: sseTimeout,
headersTimeout: sseTimeout,
/** Extended keep-alive for long-lived SSE connections */
keepAliveTimeout: sseTimeout,
keepAliveMaxTimeout: sseTimeout * 2,
});
return undiciFetch(url, {
...init,
@ -280,7 +430,7 @@ export class MCPConnection extends EventEmitter {
},
fetch: this.createFetchFunction(
this.getRequestHeaders.bind(this),
this.timeout,
sseTimeout,
) as unknown as FetchLike,
});
@ -639,26 +789,70 @@ export class MCPConnection extends EventEmitter {
private setupTransportErrorHandlers(transport: Transport): void {
transport.onerror = (error) => {
if (error && typeof error === 'object' && 'code' in error) {
const errorCode = (error as unknown as { code?: number }).code;
// Extract meaningful error information (handles "SSE error: undefined" cases)
const {
message: errorMessage,
code: errorCode,
isProxyHint,
isTransient,
} = extractSSEErrorMessage(error);
// Ignore SSE 404 errors for servers that don't support SSE
if (
errorCode === 404 &&
String(error?.message).toLowerCase().includes('failed to open sse stream')
) {
logger.warn(`${this.getLogPrefix()} SSE stream not available (404). Ignoring.`);
return;
// Ignore SSE 404 errors for servers that don't support SSE
if (errorCode === 404 && errorMessage.toLowerCase().includes('failed to open sse stream')) {
logger.warn(`${this.getLogPrefix()} SSE stream not available (404). Ignoring.`);
return;
}
// Check if it's an OAuth authentication error
if (this.isOAuthError(error)) {
logger.warn(`${this.getLogPrefix()} OAuth authentication error detected`);
this.emit('oauthError', error);
}
/**
* Log with enhanced context for debugging.
* All transport.onerror events are logged as errors to preserve stack traces.
* isTransient indicates whether auto-reconnection is expected to succeed.
*
* The MCP SDK's SseError extends Error and includes:
* - code: HTTP status code or eventsource error code
* - event: The original eventsource ErrorEvent
* - stack: Full stack trace
*/
const errorContext: Record<string, unknown> = {
code: errorCode,
isTransient,
};
if (isProxyHint) {
errorContext.hint = 'Check Nginx/proxy configuration for SSE endpoints';
}
// Extract additional debug info from SseError if available
if (error && typeof error === 'object') {
const sseError = error as { event?: unknown; stack?: string };
// Include the original eventsource event for debugging
if (sseError.event && typeof sseError.event === 'object') {
const event = sseError.event as { code?: number; message?: string; type?: string };
errorContext.eventDetails = {
type: event.type,
code: event.code,
message: event.message,
};
}
// Check if it's an OAuth authentication error
if (this.isOAuthError(error)) {
logger.warn(`${this.getLogPrefix()} OAuth authentication error detected`);
this.emit('oauthError', error);
// Include stack trace if available
if (sseError.stack) {
errorContext.stack = sseError.stack;
}
}
logger.error(`${this.getLogPrefix()} Transport error:`, error);
const errorLabel = isTransient
? 'Transport error (transient, will reconnect)'
: 'Transport error (may require manual intervention)';
logger.error(`${this.getLogPrefix()} ${errorLabel}: ${errorMessage}`, errorContext);
this.emit('connectionChange', 'error');
};
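
A short sketch (illustrative, not part of the diff above) of how the extracted fields drive the new logging; extractSSEErrorMessage is module-internal to connection.ts:

const info = extractSSEErrorMessage({ message: 'SSE error: undefined' });
// info.isTransient === true and info.isProxyHint === true, so the handler logs
// "Transport error (transient, will reconnect)" with a proxy-configuration hint
// and lets the eventsource library attempt its automatic reconnect.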

View file

@ -0,0 +1,61 @@
/**
* MCP-specific error classes
*/
export const MCPErrorCodes = {
DOMAIN_NOT_ALLOWED: 'MCP_DOMAIN_NOT_ALLOWED',
INSPECTION_FAILED: 'MCP_INSPECTION_FAILED',
} as const;
export type MCPErrorCode = (typeof MCPErrorCodes)[keyof typeof MCPErrorCodes];
/**
* Custom error for MCP domain restriction violations.
* Thrown when a user attempts to connect to an MCP server whose domain is not in the allowlist.
*/
export class MCPDomainNotAllowedError extends Error {
public readonly code = MCPErrorCodes.DOMAIN_NOT_ALLOWED;
public readonly statusCode = 403;
public readonly domain: string;
constructor(domain: string) {
super(`Domain "${domain}" is not allowed`);
this.name = 'MCPDomainNotAllowedError';
this.domain = domain;
Object.setPrototypeOf(this, MCPDomainNotAllowedError.prototype);
}
}
/**
* Custom error for MCP server inspection failures.
* Thrown when attempting to connect/inspect an MCP server fails.
*/
export class MCPInspectionFailedError extends Error {
public readonly code = MCPErrorCodes.INSPECTION_FAILED;
public readonly statusCode = 400;
public readonly serverName: string;
constructor(serverName: string, cause?: Error) {
super(`Failed to connect to MCP server "${serverName}"`);
this.name = 'MCPInspectionFailedError';
this.serverName = serverName;
if (cause) {
this.cause = cause;
}
Object.setPrototypeOf(this, MCPInspectionFailedError.prototype);
}
}
/**
* Type guard to check if an error is an MCPDomainNotAllowedError
*/
export function isMCPDomainNotAllowedError(error: unknown): error is MCPDomainNotAllowedError {
return error instanceof MCPDomainNotAllowedError;
}
/**
* Type guard to check if an error is an MCPInspectionFailedError
*/
export function isMCPInspectionFailedError(error: unknown): error is MCPInspectionFailedError {
return error instanceof MCPInspectionFailedError;
}
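
An illustrative sketch (not part of this commit) of consuming the new error classes; the Express-style handler and response object are hypothetical, and MCPServerInspector.inspect refers to the inspector change at the end of this diff:

try {
  await MCPServerInspector.inspect(serverName, rawConfig, undefined, allowedDomains);
} catch (error) {
  if (isMCPDomainNotAllowedError(error)) {
    return res.status(error.statusCode).json({ code: error.code, domain: error.domain }); // 403
  }
  if (isMCPInspectionFailedError(error)) {
    return res.status(error.statusCode).json({ code: error.code, server: error.serverName }); // 400
  }
  throw error;
}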

View file

@ -9,5 +9,7 @@ export const mcpConfig = {
OAUTH_DETECTION_TIMEOUT: math(process.env.MCP_OAUTH_DETECTION_TIMEOUT ?? 5000),
CONNECTION_CHECK_TTL: math(process.env.MCP_CONNECTION_CHECK_TTL ?? 60000),
/** Idle timeout (ms) after which user connections are disconnected. Default: 15 minutes */
USER_CONNECTION_IDLE_TIMEOUT: math(process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000),
USER_CONNECTION_IDLE_TIMEOUT: math(
process.env.MCP_USER_CONNECTION_IDLE_TIMEOUT ?? 15 * 60 * 1000,
),
};

View file

@ -9,7 +9,7 @@ import {
discoverAuthorizationServerMetadata,
discoverOAuthProtectedResourceMetadata,
} from '@modelcontextprotocol/sdk/client/auth.js';
import type { MCPOptions } from 'librechat-data-provider';
import { TokenExchangeMethodEnum, type MCPOptions } from 'librechat-data-provider';
import type { FlowStateManager } from '~/flow/manager';
import type {
OAuthClientInformation,
@ -27,15 +27,117 @@ export class MCPOAuthHandler {
private static readonly FLOW_TYPE = 'mcp_oauth';
private static readonly FLOW_TTL = 10 * 60 * 1000; // 10 minutes
private static getForcedTokenEndpointAuthMethod(
tokenExchangeMethod?: TokenExchangeMethodEnum,
): 'client_secret_basic' | 'client_secret_post' | undefined {
if (tokenExchangeMethod === TokenExchangeMethodEnum.DefaultPost) {
return 'client_secret_post';
}
if (tokenExchangeMethod === TokenExchangeMethodEnum.BasicAuthHeader) {
return 'client_secret_basic';
}
return undefined;
}
private static resolveTokenEndpointAuthMethod(options: {
tokenExchangeMethod?: TokenExchangeMethodEnum;
tokenAuthMethods: string[];
preferredMethod?: string;
}): 'client_secret_basic' | 'client_secret_post' | undefined {
const forcedMethod = this.getForcedTokenEndpointAuthMethod(options.tokenExchangeMethod);
const preferredMethod = forcedMethod ?? options.preferredMethod;
if (preferredMethod === 'client_secret_basic' || preferredMethod === 'client_secret_post') {
return preferredMethod;
}
if (options.tokenAuthMethods.includes('client_secret_basic')) {
return 'client_secret_basic';
}
if (options.tokenAuthMethods.includes('client_secret_post')) {
return 'client_secret_post';
}
return undefined;
}
/**
* Creates a fetch function with custom headers injected
*/
private static createOAuthFetch(headers: Record<string, string>): FetchLike {
private static createOAuthFetch(
headers: Record<string, string>,
clientInfo?: OAuthClientInformation,
): FetchLike {
return async (url: string | URL, init?: RequestInit): Promise<Response> => {
const newHeaders = new Headers(init?.headers ?? {});
for (const [key, value] of Object.entries(headers)) {
newHeaders.set(key, value);
}
const method = (init?.method ?? 'GET').toUpperCase();
const initBody = init?.body;
let params: URLSearchParams | undefined;
if (initBody instanceof URLSearchParams) {
params = initBody;
} else if (typeof initBody === 'string') {
const parsed = new URLSearchParams(initBody);
if (parsed.has('grant_type')) {
params = parsed;
}
}
/**
* FastMCP 2.14+/MCP SDK 1.24+ token endpoints can be strict about:
* - Content-Type (must be application/x-www-form-urlencoded)
* - where client_id/client_secret are supplied (default_post vs basic header)
*/
if (method === 'POST' && params?.has('grant_type')) {
newHeaders.set('Content-Type', 'application/x-www-form-urlencoded');
if (clientInfo?.client_id) {
let authMethod = clientInfo.token_endpoint_auth_method;
if (!authMethod) {
if (newHeaders.has('Authorization')) {
authMethod = 'client_secret_basic';
} else if (params.has('client_id') || params.has('client_secret')) {
authMethod = 'client_secret_post';
} else if (clientInfo.client_secret) {
authMethod = 'client_secret_post';
} else {
authMethod = 'none';
}
}
if (!clientInfo.client_secret || authMethod === 'none') {
newHeaders.delete('Authorization');
if (!params.has('client_id')) {
params.set('client_id', clientInfo.client_id);
}
} else if (authMethod === 'client_secret_post') {
newHeaders.delete('Authorization');
if (!params.has('client_id')) {
params.set('client_id', clientInfo.client_id);
}
if (!params.has('client_secret')) {
params.set('client_secret', clientInfo.client_secret);
}
} else if (authMethod === 'client_secret_basic') {
if (!newHeaders.has('Authorization')) {
const clientAuth = Buffer.from(
`${clientInfo.client_id}:${clientInfo.client_secret}`,
).toString('base64');
newHeaders.set('Authorization', `Basic ${clientAuth}`);
}
}
}
return fetch(url, {
...init,
body: params.toString(),
headers: newHeaders,
});
}
return fetch(url, {
...init,
headers: newHeaders,
@ -157,6 +259,7 @@ export class MCPOAuthHandler {
oauthHeaders: Record<string, string>,
resourceMetadata?: OAuthProtectedResourceMetadata,
redirectUri?: string,
tokenExchangeMethod?: TokenExchangeMethodEnum,
): Promise<OAuthClientInformation> {
logger.debug(
`[MCPOAuth] Starting client registration for ${sanitizeUrlForLogging(serverUrl)}, server metadata:`,
@ -197,7 +300,11 @@ export class MCPOAuthHandler {
clientMetadata.response_types = metadata.response_types_supported || ['code'];
if (metadata.token_endpoint_auth_methods_supported) {
const forcedAuthMethod = this.getForcedTokenEndpointAuthMethod(tokenExchangeMethod);
if (forcedAuthMethod) {
clientMetadata.token_endpoint_auth_method = forcedAuthMethod;
} else if (metadata.token_endpoint_auth_methods_supported) {
// Prefer client_secret_basic if supported, otherwise use the first supported method
if (metadata.token_endpoint_auth_methods_supported.includes('client_secret_basic')) {
clientMetadata.token_endpoint_auth_method = 'client_secret_basic';
@ -227,6 +334,12 @@ export class MCPOAuthHandler {
fetchFn: this.createOAuthFetch(oauthHeaders),
});
if (forcedAuthMethod) {
clientInfo.token_endpoint_auth_method = forcedAuthMethod;
} else if (!clientInfo.token_endpoint_auth_method) {
clientInfo.token_endpoint_auth_method = clientMetadata.token_endpoint_auth_method;
}
logger.debug(
`[MCPOAuth] Client registered successfully for ${sanitizeUrlForLogging(serverUrl)}:`,
{
@ -281,6 +394,26 @@ export class MCPOAuthHandler {
}
/** Metadata based on pre-configured settings */
let tokenEndpointAuthMethod: string;
if (!config.client_secret) {
tokenEndpointAuthMethod = 'none';
} else {
// Use the method forced by token_exchange_method when one is set (DefaultPost -> client_secret_post,
// BasicAuthHeader -> client_secret_basic); otherwise default to client_secret_basic (Basic Auth header).
tokenEndpointAuthMethod =
this.getForcedTokenEndpointAuthMethod(config.token_exchange_method) ??
'client_secret_basic';
}
let defaultTokenAuthMethods: string[];
if (tokenEndpointAuthMethod === 'none') {
defaultTokenAuthMethods = ['none'];
} else if (tokenEndpointAuthMethod === 'client_secret_post') {
defaultTokenAuthMethods = ['client_secret_post', 'client_secret_basic'];
} else {
defaultTokenAuthMethods = ['client_secret_basic', 'client_secret_post'];
}
const metadata: OAuthMetadata = {
authorization_endpoint: config.authorization_url,
token_endpoint: config.token_url,
@ -290,10 +423,8 @@ export class MCPOAuthHandler {
'authorization_code',
'refresh_token',
],
token_endpoint_auth_methods_supported: config?.token_endpoint_auth_methods_supported ?? [
'client_secret_basic',
'client_secret_post',
],
token_endpoint_auth_methods_supported:
config?.token_endpoint_auth_methods_supported ?? defaultTokenAuthMethods,
response_types_supported: config?.response_types_supported ?? ['code'],
code_challenge_methods_supported: codeChallengeMethodsSupported,
};
@ -303,6 +434,7 @@ export class MCPOAuthHandler {
client_secret: config.client_secret,
redirect_uris: [config.redirect_uri || this.getDefaultRedirectUri(serverName)],
scope: config.scope,
token_endpoint_auth_method: tokenEndpointAuthMethod,
};
logger.debug(`[MCPOAuth] Starting authorization with pre-configured settings`);
@ -359,6 +491,7 @@ export class MCPOAuthHandler {
oauthHeaders,
resourceMetadata,
redirectUri,
config?.token_exchange_method,
);
logger.debug(`[MCPOAuth] Client registered with ID: ${clientInfo.client_id}`);
@ -490,7 +623,7 @@ export class MCPOAuthHandler {
codeVerifier: metadata.codeVerifier,
authorizationCode,
resource,
fetchFn: this.createOAuthFetch(oauthHeaders),
fetchFn: this.createOAuthFetch(oauthHeaders, metadata.clientInfo),
});
logger.debug('[MCPOAuth] Token exchange successful', {
@ -663,8 +796,8 @@ export class MCPOAuthHandler {
}
const headers: HeadersInit = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
...oauthHeaders,
};
@ -672,17 +805,20 @@ export class MCPOAuthHandler {
if (metadata.clientInfo.client_secret) {
/** Default to client_secret_basic if no methods specified (per RFC 8414) */
const tokenAuthMethods = authMethods ?? ['client_secret_basic'];
const usesBasicAuth = tokenAuthMethods.includes('client_secret_basic');
const usesClientSecretPost = tokenAuthMethods.includes('client_secret_post');
const authMethod = this.resolveTokenEndpointAuthMethod({
tokenExchangeMethod: config?.token_exchange_method,
tokenAuthMethods,
preferredMethod: metadata.clientInfo.token_endpoint_auth_method,
});
if (usesBasicAuth) {
if (authMethod === 'client_secret_basic') {
/** Use Basic auth */
logger.debug('[MCPOAuth] Using client_secret_basic authentication method');
const clientAuth = Buffer.from(
`${metadata.clientInfo.client_id}:${metadata.clientInfo.client_secret}`,
).toString('base64');
headers['Authorization'] = `Basic ${clientAuth}`;
} else if (usesClientSecretPost) {
} else if (authMethod === 'client_secret_post') {
/** Use client_secret_post */
logger.debug('[MCPOAuth] Using client_secret_post authentication method');
body.append('client_id', metadata.clientInfo.client_id);
@ -739,8 +875,8 @@ export class MCPOAuthHandler {
}
const headers: HeadersInit = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
...oauthHeaders,
};
@ -750,10 +886,12 @@ export class MCPOAuthHandler {
const tokenAuthMethods = config.token_endpoint_auth_methods_supported ?? [
'client_secret_basic',
];
const usesBasicAuth = tokenAuthMethods.includes('client_secret_basic');
const usesClientSecretPost = tokenAuthMethods.includes('client_secret_post');
const authMethod = this.resolveTokenEndpointAuthMethod({
tokenExchangeMethod: config.token_exchange_method,
tokenAuthMethods,
});
if (usesBasicAuth) {
if (authMethod === 'client_secret_basic') {
/** Use Basic auth */
logger.debug(
'[MCPOAuth] Using client_secret_basic authentication method (pre-configured)',
@ -762,7 +900,7 @@ export class MCPOAuthHandler {
'base64',
);
headers['Authorization'] = `Basic ${clientAuth}`;
} else if (usesClientSecretPost) {
} else if (authMethod === 'client_secret_post') {
/** Use client_secret_post */
logger.debug(
'[MCPOAuth] Using client_secret_post authentication method (pre-configured)',
@ -832,8 +970,8 @@ export class MCPOAuthHandler {
});
const headers: HeadersInit = {
'Content-Type': 'application/x-www-form-urlencoded',
Accept: 'application/json',
'Content-Type': 'application/x-www-form-urlencoded',
...oauthHeaders,
};

View file

@ -2,7 +2,9 @@ import { Constants } from 'librechat-data-provider';
import type { JsonSchemaType } from '@librechat/data-schemas';
import type { MCPConnection } from '~/mcp/connection';
import type * as t from '~/mcp/types';
import { isMCPDomainAllowed, extractMCPServerDomain } from '~/auth/domain';
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
import { MCPDomainNotAllowedError } from '~/mcp/errors';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { isEnabled } from '~/utils';
@ -24,13 +26,22 @@ export class MCPServerInspector {
* @param serverName - The name of the server (used for tool function naming)
* @param rawConfig - The raw server configuration
* @param connection - The MCP connection
* @param allowedDomains - Optional list of allowed domains for remote transports
* @returns A fully processed and enriched configuration with server metadata
*/
public static async inspect(
serverName: string,
rawConfig: t.MCPOptions,
connection?: MCPConnection,
allowedDomains?: string[] | null,
): Promise<t.ParsedServerConfig> {
// Validate domain against allowlist BEFORE attempting connection
const isDomainAllowed = await isMCPDomainAllowed(rawConfig, allowedDomains);
if (!isDomainAllowed) {
const domain = extractMCPServerDomain(rawConfig);
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
}
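    // Illustrative only: with allowedDomains = ['api.example.com'] (example value), a remote-transport
    // config pointing anywhere else is rejected here before any connection attempt; the exact matching
    // semantics are defined by isMCPDomainAllowed/extractMCPServerDomain.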
const start = Date.now();
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
await inspector.inspectServer();

View file

@ -30,11 +30,24 @@ export class MCPServersInitializer {
* - Followers wait and poll `statusCache` until the leader finishes, ensuring only one node
* performs the expensive initialization operations.
*/
private static hasInitializedThisProcess = false;
/** Reset the process-level initialization flag. Only used for testing. */
public static resetProcessFlag(): void {
MCPServersInitializer.hasInitializedThisProcess = false;
}
public static async initialize(rawConfigs: t.MCPServers): Promise<void> {
if (await statusCache.isInitialized()) return;
// On first call in this process, always reset and re-initialize
// This ensures we don't use stale Redis data from previous runs
const isFirstCallThisProcess = !MCPServersInitializer.hasInitializedThisProcess;
// Set flag immediately so recursive calls (from followers) use Redis cache for coordination
MCPServersInitializer.hasInitializedThisProcess = true;
if (!isFirstCallThisProcess && (await statusCache.isInitialized())) return;
if (await isLeader()) {
// Leader performs initialization
// Leader performs initialization - always reset on first call
await statusCache.reset();
await MCPServersRegistry.getInstance().reset();
const serverNames = Object.keys(rawConfigs);

View file

@ -1,9 +1,12 @@
import { Keyv } from 'keyv';
import { logger } from '@librechat/data-schemas';
import type { IServerConfigsRepositoryInterface } from './ServerConfigsRepositoryInterface';
import type * as t from '~/mcp/types';
import { MCPInspectionFailedError, isMCPDomainNotAllowedError } from '~/mcp/errors';
import { ServerConfigsCacheFactory } from './cache/ServerConfigsCacheFactory';
import { MCPServerInspector } from './MCPServerInspector';
import { ServerConfigsDB } from './db/ServerConfigsDB';
import { cacheConfig } from '~/cache/cacheConfig';
/**
* Central registry for managing MCP server configurations.
@ -20,14 +23,33 @@ export class MCPServersRegistry {
private readonly dbConfigsRepo: IServerConfigsRepositoryInterface;
private readonly cacheConfigsRepo: IServerConfigsRepositoryInterface;
private readonly allowedDomains?: string[] | null;
private readonly readThroughCache: Keyv<t.ParsedServerConfig>;
private readonly readThroughCacheAll: Keyv<Record<string, t.ParsedServerConfig>>;
constructor(mongoose: typeof import('mongoose')) {
constructor(mongoose: typeof import('mongoose'), allowedDomains?: string[] | null) {
this.dbConfigsRepo = new ServerConfigsDB(mongoose);
this.cacheConfigsRepo = ServerConfigsCacheFactory.create('App', false);
this.allowedDomains = allowedDomains;
const ttl = cacheConfig.MCP_REGISTRY_CACHE_TTL;
this.readThroughCache = new Keyv<t.ParsedServerConfig>({
namespace: 'mcp-registry-read-through',
ttl,
});
this.readThroughCacheAll = new Keyv<Record<string, t.ParsedServerConfig>>({
namespace: 'mcp-registry-read-through-all',
ttl,
});
}
/** Creates and initializes the singleton MCPServersRegistry instance */
public static createInstance(mongoose: typeof import('mongoose')): MCPServersRegistry {
public static createInstance(
mongoose: typeof import('mongoose'),
allowedDomains?: string[] | null,
): MCPServersRegistry {
if (!mongoose) {
throw new Error(
'MCPServersRegistry creation failed: mongoose instance is required for database operations. ' +
@ -39,7 +61,7 @@ export class MCPServersRegistry {
return MCPServersRegistry.instance;
}
logger.info('[MCPServersRegistry] Creating new instance');
MCPServersRegistry.instance = new MCPServersRegistry(mongoose);
MCPServersRegistry.instance = new MCPServersRegistry(mongoose, allowedDomains);
return MCPServersRegistry.instance;
}
@ -55,20 +77,40 @@ export class MCPServersRegistry {
serverName: string,
userId?: string,
): Promise<t.ParsedServerConfig | undefined> {
const cacheKey = this.getReadThroughCacheKey(serverName, userId);
if (await this.readThroughCache.has(cacheKey)) {
return await this.readThroughCache.get(cacheKey);
}
// First, check whether a config exists in the cache repository
// (YAML configs are preloaded into the cache)
const configFromCache = await this.cacheConfigsRepo.get(serverName);
if (configFromCache) return configFromCache;
if (configFromCache) {
await this.readThroughCache.set(cacheKey, configFromCache);
return configFromCache;
}
const configFromDB = await this.dbConfigsRepo.get(serverName, userId);
if (configFromDB) return configFromDB;
return undefined;
await this.readThroughCache.set(cacheKey, configFromDB);
return configFromDB;
}
public async getAllServerConfigs(userId?: string): Promise<Record<string, t.ParsedServerConfig>> {
return {
const cacheKey = userId ?? '__no_user__';
// Check if key exists in read-through cache
if (await this.readThroughCacheAll.has(cacheKey)) {
return (await this.readThroughCacheAll.get(cacheKey)) ?? {};
}
const result = {
...(await this.cacheConfigsRepo.getAll()),
...(await this.dbConfigsRepo.getAll(userId)),
};
await this.readThroughCacheAll.set(cacheKey, result);
return result;
}
public async addServer(
@ -80,10 +122,19 @@ export class MCPServersRegistry {
const configRepo = this.getConfigRepository(storageLocation);
let parsedConfig: t.ParsedServerConfig;
try {
parsedConfig = await MCPServerInspector.inspect(serverName, config);
parsedConfig = await MCPServerInspector.inspect(
serverName,
config,
undefined,
this.allowedDomains,
);
} catch (error) {
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
// Preserve domain-specific error for better error handling
if (isMCPDomainNotAllowedError(error)) {
throw error;
}
throw new MCPInspectionFailedError(serverName, error as Error);
}
return await configRepo.add(serverName, parsedConfig, userId);
}
@ -113,10 +164,19 @@ export class MCPServersRegistry {
let parsedConfig: t.ParsedServerConfig;
try {
parsedConfig = await MCPServerInspector.inspect(serverName, configForInspection);
parsedConfig = await MCPServerInspector.inspect(
serverName,
configForInspection,
undefined,
this.allowedDomains,
);
} catch (error) {
logger.error(`[MCPServersRegistry] Failed to inspect server "${serverName}":`, error);
throw new Error(`MCP_INSPECTION_FAILED: Failed to connect to MCP server "${serverName}"`);
// Preserve domain-specific error for better error handling
if (isMCPDomainNotAllowedError(error)) {
throw error;
}
throw new MCPInspectionFailedError(serverName, error as Error);
}
await configRepo.update(serverName, parsedConfig, userId);
return parsedConfig;
@ -132,6 +192,8 @@ export class MCPServersRegistry {
public async reset(): Promise<void> {
await this.cacheConfigsRepo.reset();
await this.readThroughCache.clear();
await this.readThroughCacheAll.clear();
}
public async removeServer(
@ -155,4 +217,8 @@ export class MCPServersRegistry {
);
}
}
private getReadThroughCacheKey(serverName: string, userId?: string): string {
return userId ? `${serverName}::${userId}` : serverName;
}
}

View file

@ -201,6 +201,11 @@ describe('MCPServersInitializer Redis Integration Tests', () => {
// Mock MCPConnectionFactory
jest.spyOn(MCPConnectionFactory, 'create').mockResolvedValue(mockConnection);
// Reset caches and process flag before each test
await registryStatusCache.reset();
await registry.reset();
MCPServersInitializer.resetProcessFlag();
});
afterEach(async () => {
@ -261,7 +266,7 @@ describe('MCPServersInitializer Redis Integration Tests', () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify inspect was not called again
expect(MCPServerInspector.inspect).not.toHaveBeenCalled();
expect((MCPServerInspector.inspect as jest.Mock).mock.calls.length).toBe(0);
});
it('should initialize all servers to cache repository', async () => {
@ -309,4 +314,118 @@ describe('MCPServersInitializer Redis Integration Tests', () => {
expect(await registryStatusCache.isInitialized()).toBe(true);
});
});
describe('horizontal scaling / app restart behavior', () => {
it('should re-initialize on first call even if Redis says initialized (simulating app restart)', async () => {
// First: run full initialization
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
// Add a stale server directly to Redis to simulate stale data
await registry.addServer(
'stale_server',
{
type: 'stdio',
command: 'node',
args: ['stale.js'],
},
'CACHE',
);
expect(await registry.getServerConfig('stale_server')).toBeDefined();
// Simulate app restart by resetting the process flag (but NOT Redis)
MCPServersInitializer.resetProcessFlag();
// Clear mocks to track new calls
jest.clearAllMocks();
// Re-initialize - should still run initialization because process flag was reset
await MCPServersInitializer.initialize(testConfigs);
// Stale server should be gone because registry.reset() was called
expect(await registry.getServerConfig('stale_server')).toBeUndefined();
// Real servers should be present
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
expect(await registry.getServerConfig('disabled_server')).toBeDefined();
// Inspector should have been called (proving re-initialization happened)
expect((MCPServerInspector.inspect as jest.Mock).mock.calls.length).toBeGreaterThan(0);
});
it('should skip re-initialization on subsequent calls within same process', async () => {
// First initialization
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
// Clear mocks
jest.clearAllMocks();
// Second call in same process should skip
await MCPServersInitializer.initialize(testConfigs);
// Inspector should NOT have been called
expect((MCPServerInspector.inspect as jest.Mock).mock.calls.length).toBe(0);
});
it('should clear stale data from Redis when a new instance becomes leader', async () => {
// Initial setup with testConfigs
await MCPServersInitializer.initialize(testConfigs);
// Add stale data that shouldn't exist after next initialization
await registry.addServer(
'should_be_removed',
{
type: 'stdio',
command: 'node',
args: ['old.js'],
},
'CACHE',
);
// Verify stale data exists
expect(await registry.getServerConfig('should_be_removed')).toBeDefined();
// Simulate new process starting (reset process flag)
MCPServersInitializer.resetProcessFlag();
// Initialize with different configs (fewer servers)
const reducedConfigs: t.MCPServers = {
file_tools_server: testConfigs.file_tools_server,
};
await MCPServersInitializer.initialize(reducedConfigs);
// Stale server from previous config should be gone
expect(await registry.getServerConfig('should_be_removed')).toBeUndefined();
// Server not in new configs should be gone
expect(await registry.getServerConfig('disabled_server')).toBeUndefined();
// Only server in new configs should exist
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
});
it('should work correctly when multiple instances share Redis (leader handles init)', async () => {
// First instance initializes (we are the leader)
await MCPServersInitializer.initialize(testConfigs);
// Verify initialized state is in Redis
expect(await registryStatusCache.isInitialized()).toBe(true);
// Verify servers are in Redis
const fileToolsServer = await registry.getServerConfig('file_tools_server');
expect(fileToolsServer).toBeDefined();
expect(fileToolsServer?.tools).toBe('file_read, file_write');
// Simulate second instance starting (reset process flag but keep Redis data)
MCPServersInitializer.resetProcessFlag();
jest.clearAllMocks();
// Second instance initializes - should still process because isFirstCallThisProcess
await MCPServersInitializer.initialize(testConfigs);
// Redis should still have correct data
expect(await registryStatusCache.isInitialized()).toBe(true);
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
});
});
});

View file

@ -179,9 +179,10 @@ describe('MCPServersInitializer', () => {
} as unknown as t.ParsedServerConfig;
});
// Reset caches before each test
// Reset caches and process flag before each test
await registryStatusCache.reset();
await registry.reset();
MCPServersInitializer.resetProcessFlag();
jest.clearAllMocks();
});
@ -223,18 +224,38 @@ describe('MCPServersInitializer', () => {
it('should process all server configs through inspector', async () => {
await MCPServersInitializer.initialize(testConfigs);
// Verify all configs were processed by inspector (without connection parameter)
// Verify all configs were processed by inspector
// Signature: inspect(serverName, rawConfig, connection?, allowedDomains?)
expect(mockInspect).toHaveBeenCalledTimes(5);
expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server);
expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server);
expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server);
expect(mockInspect).toHaveBeenCalledWith(
'disabled_server',
testConfigs.disabled_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'oauth_server',
testConfigs.oauth_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'file_tools_server',
testConfigs.file_tools_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'search_tools_server',
testConfigs.search_tools_server,
undefined,
undefined,
);
expect(mockInspect).toHaveBeenCalledWith(
'remote_no_oauth_server',
testConfigs.remote_no_oauth_server,
undefined,
undefined,
);
});
@ -309,5 +330,51 @@ describe('MCPServersInitializer', () => {
expect(await registryStatusCache.isInitialized()).toBe(true);
});
it('should re-initialize on first call even if Redis cache says initialized (simulating app restart)', async () => {
// First initialization - populates caches
await MCPServersInitializer.initialize(testConfigs);
expect(await registryStatusCache.isInitialized()).toBe(true);
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
// Simulate stale data: add an extra server that shouldn't be there
await registry.addServer('stale_server', testConfigs.file_tools_server, 'CACHE');
expect(await registry.getServerConfig('stale_server')).toBeDefined();
jest.clearAllMocks();
// Simulate app restart by resetting the process flag
// In real scenario, this happens automatically when process restarts
MCPServersInitializer.resetProcessFlag();
// Re-initialize - should reset caches even though Redis says initialized
await MCPServersInitializer.initialize(testConfigs);
// Verify stale server was removed (cache was reset)
expect(await registry.getServerConfig('stale_server')).toBeUndefined();
// Verify new servers are present
expect(await registry.getServerConfig('file_tools_server')).toBeDefined();
expect(await registry.getServerConfig('oauth_server')).toBeDefined();
// Verify inspector was called again (re-initialization happened)
expect(mockInspect).toHaveBeenCalled();
});
it('should not re-initialize on subsequent calls within same process', async () => {
// First initialization (5 servers in testConfigs)
await MCPServersInitializer.initialize(testConfigs);
expect(mockInspect).toHaveBeenCalledTimes(5);
jest.clearAllMocks();
// Second call - should skip because process flag is set and Redis says initialized
await MCPServersInitializer.initialize(testConfigs);
expect(mockInspect).not.toHaveBeenCalled();
// Third call - still skips
await MCPServersInitializer.initialize(testConfigs);
expect(mockInspect).not.toHaveBeenCalled();
});
});
});

View file

@ -192,15 +192,14 @@ describe('MCPServersRegistry Redis Integration Tests', () => {
// Add server
await registry.addServer(serverName, testRawConfig, 'CACHE');
// Verify server exists
const configBefore = await registry.getServerConfig(serverName);
expect(configBefore).toBeDefined();
// Verify server exists in underlying cache repository (not via getServerConfig to avoid populating read-through cache)
expect(await registry['cacheConfigsRepo'].get(serverName)).toBeDefined();
// Remove server
await registry.removeServer(serverName, 'CACHE');
// Verify server was removed
const configAfter = await registry.getServerConfig(serverName);
// Verify server was removed from underlying cache repository
const configAfter = await registry['cacheConfigsRepo'].get(serverName);
expect(configAfter).toBeUndefined();
});
});

View file

@ -158,11 +158,13 @@ describe('MCPServersRegistry', () => {
it('should route removeServer to cache repository', async () => {
await registry.addServer('cache_server', testParsedConfig, 'CACHE');
expect(await registry.getServerConfig('cache_server')).toBeDefined();
// Verify server exists in underlying cache repository (not via getServerConfig to avoid populating read-through cache)
expect(await registry['cacheConfigsRepo'].get('cache_server')).toBeDefined();
await registry.removeServer('cache_server', 'CACHE');
const config = await registry.getServerConfig('cache_server');
// Verify server is removed from underlying cache repository
const config = await registry['cacheConfigsRepo'].get('cache_server');
expect(config).toBeUndefined();
});
});
@ -190,4 +192,114 @@ describe('MCPServersRegistry', () => {
});
});
});
describe('Read-through cache', () => {
describe('getServerConfig', () => {
it('should cache repeated calls for the same server', async () => {
// Add a server to the cache repository
await registry['cacheConfigsRepo'].add('test_server', testParsedConfig);
// Spy on the cache repository get method
const cacheRepoGetSpy = jest.spyOn(registry['cacheConfigsRepo'], 'get');
// First call should hit the cache repository
const config1 = await registry.getServerConfig('test_server');
expect(config1).toEqual(testParsedConfig);
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1);
// Second call should hit the read-through cache, not the repository
const config2 = await registry.getServerConfig('test_server');
expect(config2).toEqual(testParsedConfig);
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1); // Still 1, not 2
// Third call should also hit the read-through cache
const config3 = await registry.getServerConfig('test_server');
expect(config3).toEqual(testParsedConfig);
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1); // Still 1
});
it('should cache "not found" results to avoid repeated DB lookups', async () => {
// Spy on the DB repository get method
const dbRepoGetSpy = jest.spyOn(registry['dbConfigsRepo'], 'get');
// First call - server doesn't exist, should hit DB
const config1 = await registry.getServerConfig('nonexistent_server');
expect(config1).toBeUndefined();
expect(dbRepoGetSpy).toHaveBeenCalledTimes(1);
// Second call - should hit read-through cache, not DB
const config2 = await registry.getServerConfig('nonexistent_server');
expect(config2).toBeUndefined();
expect(dbRepoGetSpy).toHaveBeenCalledTimes(1); // Still 1, not 2
});
it('should use different cache keys for different userIds', async () => {
// Spy on the cache repository get method
const cacheRepoGetSpy = jest.spyOn(registry['cacheConfigsRepo'], 'get');
// First call without userId
await registry.getServerConfig('test_server');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(1);
// Call with userId - should be a different cache key, so hits repository again
await registry.getServerConfig('test_server', 'user123');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2);
// Repeat call with same userId - should hit read-through cache
await registry.getServerConfig('test_server', 'user123');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(2); // Still 2
// Call with different userId - should hit repository
await registry.getServerConfig('test_server', 'user456');
expect(cacheRepoGetSpy).toHaveBeenCalledTimes(3);
});
});
describe('getAllServerConfigs', () => {
it('should cache repeated calls', async () => {
// Add servers to cache
await registry['cacheConfigsRepo'].add('server1', testParsedConfig);
await registry['cacheConfigsRepo'].add('server2', testParsedConfig);
// Spy on the cache repository getAll method
const cacheRepoGetAllSpy = jest.spyOn(registry['cacheConfigsRepo'], 'getAll');
// First call should hit the repository
const configs1 = await registry.getAllServerConfigs();
expect(Object.keys(configs1)).toHaveLength(2);
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1);
// Second call should hit the read-through cache
const configs2 = await registry.getAllServerConfigs();
expect(Object.keys(configs2)).toHaveLength(2);
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1); // Still 1
// Third call should also hit the read-through cache
const configs3 = await registry.getAllServerConfigs();
expect(Object.keys(configs3)).toHaveLength(2);
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1); // Still 1
});
it('should use different cache keys for different userIds', async () => {
// Spy on the cache repository getAll method
const cacheRepoGetAllSpy = jest.spyOn(registry['cacheConfigsRepo'], 'getAll');
// First call without userId
await registry.getAllServerConfigs();
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(1);
// Call with userId - should be a different cache key
await registry.getAllServerConfigs('user123');
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(2);
// Repeat call with same userId - should hit read-through cache
await registry.getAllServerConfigs('user123');
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(2); // Still 2
// Call with different userId - should hit repository
await registry.getAllServerConfigs('user456');
expect(cacheRepoGetAllSpy).toHaveBeenCalledTimes(3);
});
});
});
});

View file

@ -0,0 +1,940 @@
import { logger } from '@librechat/data-schemas';
import type { StandardGraph } from '@librechat/agents';
import type { Agents } from 'librechat-data-provider';
import type {
SerializableJobData,
IEventTransport,
AbortResult,
IJobStore,
} from './interfaces/IJobStore';
import type * as t from '~/types';
import { InMemoryEventTransport } from './implementations/InMemoryEventTransport';
import { InMemoryJobStore } from './implementations/InMemoryJobStore';
/**
* Configuration options for GenerationJobManager
*/
export interface GenerationJobManagerOptions {
jobStore?: IJobStore;
eventTransport?: IEventTransport;
/**
* If true, cleans up event transport immediately when job completes.
* If false, keeps EventEmitters until periodic cleanup for late reconnections.
* Default: true (immediate cleanup to save memory)
*/
cleanupOnComplete?: boolean;
}
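// Example (illustrative only): keep completed jobs around until periodic cleanup so that
// late subscribers can still receive the cached final event:
//   new GenerationJobManagerClass({ cleanupOnComplete: false });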
/**
* Runtime state for active jobs - not serializable, kept in-memory per instance.
* Contains AbortController, ready promise, and other non-serializable state.
*
* @property abortController - Controller to abort the generation
* @property readyPromise - Resolves immediately (legacy, kept for API compatibility)
* @property resolveReady - Function to resolve readyPromise
* @property finalEvent - Cached final event for late subscribers
* @property syncSent - Whether sync event was sent (reset when all subscribers leave)
* @property earlyEventBuffer - Buffer for events emitted before first subscriber connects
* @property hasSubscriber - Whether at least one subscriber has connected
* @property allSubscribersLeftHandlers - Internal handlers for disconnect events.
* These are stored separately from eventTransport subscribers to avoid being counted
* in subscriber count. This is critical: if these were registered via subscribe(),
* they would count as subscribers, causing isFirstSubscriber() to return false
* when the real client connects, which would prevent readyPromise from resolving.
*/
interface RuntimeJobState {
abortController: AbortController;
readyPromise: Promise<void>;
resolveReady: () => void;
finalEvent?: t.ServerSentEvent;
syncSent: boolean;
earlyEventBuffer: t.ServerSentEvent[];
hasSubscriber: boolean;
allSubscribersLeftHandlers?: Array<(...args: unknown[]) => void>;
}
/**
* Manages generation jobs for resumable LLM streams.
*
* Architecture: Composes two pluggable services via dependency injection:
 * - jobStore: Job metadata + content state (in-memory, or Redis for horizontal scaling)
 * - eventTransport: Pub/sub events (in-memory, or Redis Pub/Sub for horizontal scaling)
*
* Content state is tied to jobs:
* - In-memory: jobStore holds WeakRef to graph for live content/run steps access
* - Redis: jobStore persists chunks, reconstructs content on demand
*
* All storage methods are async to support both in-memory and external stores (Redis, etc.).
*
* @example Redis injection:
* ```ts
* const manager = new GenerationJobManagerClass({
* jobStore: new RedisJobStore(redisClient),
* eventTransport: new RedisPubSubTransport(redisClient),
* });
* ```
*/
class GenerationJobManagerClass {
/** Job metadata + content state storage - swappable for Redis, etc. */
private jobStore: IJobStore;
/** Event pub/sub transport - swappable for Redis Pub/Sub, etc. */
private eventTransport: IEventTransport;
/** Runtime state - always in-memory, not serializable */
private runtimeState = new Map<string, RuntimeJobState>();
private cleanupInterval: NodeJS.Timeout | null = null;
/** Whether we're using Redis stores */
private _isRedis = false;
/** Whether to cleanup event transport immediately on job completion */
private _cleanupOnComplete = true;
constructor(options?: GenerationJobManagerOptions) {
this.jobStore =
options?.jobStore ?? new InMemoryJobStore({ ttlAfterComplete: 0, maxJobs: 1000 });
this.eventTransport = options?.eventTransport ?? new InMemoryEventTransport();
this._cleanupOnComplete = options?.cleanupOnComplete ?? true;
}
/**
* Initialize the job manager with periodic cleanup.
* Call this once at application startup.
*/
initialize(): void {
if (this.cleanupInterval) {
return;
}
this.jobStore.initialize();
this.cleanupInterval = setInterval(() => {
this.cleanup();
}, 60000);
if (this.cleanupInterval.unref) {
this.cleanupInterval.unref();
}
logger.debug('[GenerationJobManager] Initialized');
}
/**
* Configure the manager with custom stores.
* Call this BEFORE initialize() to use Redis or other stores.
*
* @example Using Redis
* ```ts
* import { createStreamServicesFromCache } from '~/stream/createStreamServices';
* import { cacheConfig, ioredisClient } from '~/cache';
*
* const services = createStreamServicesFromCache({ cacheConfig, ioredisClient });
* GenerationJobManager.configure(services);
* GenerationJobManager.initialize();
* ```
*/
configure(services: {
jobStore: IJobStore;
eventTransport: IEventTransport;
isRedis?: boolean;
cleanupOnComplete?: boolean;
}): void {
if (this.cleanupInterval) {
logger.warn(
'[GenerationJobManager] Reconfiguring after initialization - destroying existing services',
);
this.destroy();
}
this.jobStore = services.jobStore;
this.eventTransport = services.eventTransport;
this._isRedis = services.isRedis ?? false;
this._cleanupOnComplete = services.cleanupOnComplete ?? true;
logger.info(
`[GenerationJobManager] Configured with ${this._isRedis ? 'Redis' : 'in-memory'} stores`,
);
}
/**
* Check if using Redis stores.
*/
get isRedis(): boolean {
return this._isRedis;
}
/**
* Get the job store instance (for advanced use cases).
*/
getJobStore(): IJobStore {
return this.jobStore;
}
/**
* Create a new generation job.
*
* This sets up:
* 1. Serializable job data in the job store
* 2. Runtime state including readyPromise (resolves when first SSE client connects)
* 3. allSubscribersLeft callback for handling client disconnections
*
* The readyPromise mechanism ensures generation doesn't start before the client
* is ready to receive events. The controller awaits this promise (with a short timeout)
* before starting LLM generation.
*
* @param streamId - Unique identifier for this stream
* @param userId - User who initiated the request
* @param conversationId - Optional conversation ID for lookup
* @returns A facade object for the GenerationJob
*/
async createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<t.GenerationJob> {
const jobData = await this.jobStore.createJob(streamId, userId, conversationId);
/**
* Create runtime state with readyPromise.
*
* With the resumable stream architecture, we no longer need to wait for the
* first subscriber before starting generation:
* - Redis mode: Events are persisted and can be replayed via sync
* - In-memory mode: Content is aggregated and sent via sync on connect
*
* We resolve readyPromise immediately to eliminate startup latency.
* The sync mechanism handles late-connecting clients.
*/
let resolveReady: () => void;
const readyPromise = new Promise<void>((resolve) => {
resolveReady = resolve;
});
const runtime: RuntimeJobState = {
abortController: new AbortController(),
readyPromise,
resolveReady: resolveReady!,
syncSent: false,
earlyEventBuffer: [],
hasSubscriber: false,
};
this.runtimeState.set(streamId, runtime);
// Resolve immediately - early event buffer handles late subscribers
resolveReady!();
/**
* Set up all-subscribers-left callback.
* When all SSE clients disconnect, this:
* 1. Resets syncSent so reconnecting clients get sync event
* 2. Calls any registered allSubscribersLeft handlers (e.g., to save partial responses)
*/
this.eventTransport.onAllSubscribersLeft(streamId, () => {
const currentRuntime = this.runtimeState.get(streamId);
if (currentRuntime) {
currentRuntime.syncSent = false;
// Call registered handlers (from job.emitter.on('allSubscribersLeft', ...))
if (currentRuntime.allSubscribersLeftHandlers) {
this.jobStore
.getContentParts(streamId)
.then((result) => {
const parts = result?.content ?? [];
for (const handler of currentRuntime.allSubscribersLeftHandlers ?? []) {
try {
handler(parts);
} catch (err) {
logger.error(`[GenerationJobManager] Error in allSubscribersLeft handler:`, err);
}
}
})
.catch((err) => {
logger.error(
`[GenerationJobManager] Failed to get content parts for allSubscribersLeft handlers:`,
err,
);
});
}
}
});
logger.debug(`[GenerationJobManager] Created job: ${streamId}`);
// Return facade for backwards compatibility
return this.buildJobFacade(streamId, jobData, runtime);
}
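  /**
   * Illustrative controller flow (a sketch, not part of this change; `delta`, `finalEvent`,
   * and the model value are placeholders):
   *
   *   const job = await GenerationJobManager.createJob(streamId, userId, conversationId);
   *   await GenerationJobManager.updateMetadata(streamId, { sender: 'Agent', model: 'example-model' });
   *   // run generation with job.abortController.signal, emitting events as they arrive:
   *   GenerationJobManager.emitChunk(streamId, { event: 'on_message_delta', data: delta });
   *   // on finish:
   *   GenerationJobManager.emitDone(streamId, finalEvent);
   *   await GenerationJobManager.completeJob(streamId);
   */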
/**
* Build a GenerationJob facade from composed services.
*
* This facade provides a unified API (job.emitter, job.abortController, etc.)
   * while internally delegating to the injected services (jobStore, eventTransport).
   * This allows swapping implementations (e.g., Redis) without changing consumer code.
*
* IMPORTANT: The emitterProxy.on('allSubscribersLeft') handler registration
* does NOT use eventTransport.subscribe(). This is intentional:
*
* If we used subscribe() for internal handlers, those handlers would count
* as subscribers. When the real SSE client connects, isFirstSubscriber()
* would return false (because internal handler was "first"), and readyPromise
* would never resolve - causing a 5-second timeout delay before generation starts.
*
* Instead, allSubscribersLeft handlers are stored in runtime.allSubscribersLeftHandlers
* and called directly from the onAllSubscribersLeft callback in createJob().
*
* @param streamId - The stream identifier
* @param jobData - Serializable job metadata from job store
* @param runtime - Non-serializable runtime state (abort controller, promises, etc.)
* @returns A GenerationJob facade object
*/
private buildJobFacade(
streamId: string,
jobData: SerializableJobData,
runtime: RuntimeJobState,
): t.GenerationJob {
/**
* Proxy emitter that delegates to eventTransport for most operations.
* Exception: allSubscribersLeft handlers are stored separately to avoid
* incrementing subscriber count (see class JSDoc above).
*/
const emitterProxy = {
on: (event: string, handler: (...args: unknown[]) => void) => {
if (event === 'allSubscribersLeft') {
// Store handler for internal callback - don't use subscribe() to avoid counting as a subscriber
if (!runtime.allSubscribersLeftHandlers) {
runtime.allSubscribersLeftHandlers = [];
}
runtime.allSubscribersLeftHandlers.push(handler);
}
},
emit: () => {
/* handled via eventTransport */
},
listenerCount: () => this.eventTransport.getSubscriberCount(streamId),
setMaxListeners: () => {
/* no-op for proxy */
},
removeAllListeners: () => this.eventTransport.cleanup(streamId),
off: () => {
/* handled via unsubscribe */
},
};
return {
streamId,
emitter: emitterProxy as unknown as t.GenerationJob['emitter'],
status: jobData.status as t.GenerationJobStatus,
createdAt: jobData.createdAt,
completedAt: jobData.completedAt,
abortController: runtime.abortController,
error: jobData.error,
metadata: {
userId: jobData.userId,
conversationId: jobData.conversationId,
userMessage: jobData.userMessage,
responseMessageId: jobData.responseMessageId,
sender: jobData.sender,
},
readyPromise: runtime.readyPromise,
resolveReady: runtime.resolveReady,
finalEvent: runtime.finalEvent,
syncSent: runtime.syncSent,
};
}
/**
* Get a job by streamId.
*/
async getJob(streamId: string): Promise<t.GenerationJob | undefined> {
const jobData = await this.jobStore.getJob(streamId);
const runtime = this.runtimeState.get(streamId);
if (!jobData || !runtime) {
return undefined;
}
return this.buildJobFacade(streamId, jobData, runtime);
}
/**
* Check if a job exists.
*/
async hasJob(streamId: string): Promise<boolean> {
return this.jobStore.hasJob(streamId);
}
/**
* Get job status.
*/
async getJobStatus(streamId: string): Promise<t.GenerationJobStatus | undefined> {
const jobData = await this.jobStore.getJob(streamId);
return jobData?.status as t.GenerationJobStatus | undefined;
}
/**
* Mark job as complete.
* If cleanupOnComplete is true (default), immediately cleans up job resources.
* Note: eventTransport is NOT cleaned up here to allow the final event to be
* fully transmitted. It will be cleaned up when subscribers disconnect or
* by the periodic cleanup job.
*/
async completeJob(streamId: string, error?: string): Promise<void> {
const runtime = this.runtimeState.get(streamId);
// Abort the controller to signal all pending operations (e.g., OAuth flow monitors)
// that the job is done and they should clean up
if (runtime) {
runtime.abortController.abort();
}
    // Clear content state and the run step buffer (the buffer is only used in Redis mode)
this.jobStore.clearContentState(streamId);
this.runStepBuffers?.delete(streamId);
// Immediate cleanup if configured (default: true)
if (this._cleanupOnComplete) {
this.runtimeState.delete(streamId);
// Don't cleanup eventTransport here - let the done event fully transmit first.
// EventTransport will be cleaned up when subscribers disconnect or by periodic cleanup.
await this.jobStore.deleteJob(streamId);
} else {
// Only update status if keeping the job around
await this.jobStore.updateJob(streamId, {
status: error ? 'error' : 'complete',
completedAt: Date.now(),
error,
});
}
logger.debug(`[GenerationJobManager] Job completed: ${streamId}`);
}
/**
* Abort a job (user-initiated).
* Returns all data needed for token spending and message saving.
*/
async abortJob(streamId: string): Promise<AbortResult> {
const jobData = await this.jobStore.getJob(streamId);
const runtime = this.runtimeState.get(streamId);
if (!jobData) {
logger.warn(`[GenerationJobManager] Cannot abort - job not found: ${streamId}`);
return { success: false, jobData: null, content: [], finalEvent: null };
}
if (runtime) {
runtime.abortController.abort();
}
// Get content before clearing state
const result = await this.jobStore.getContentParts(streamId);
const content = result?.content ?? [];
// Detect "early abort" - aborted before any generation happened (e.g., during tool loading)
// In this case, no messages were saved to DB, so frontend shouldn't navigate to conversation
const isEarlyAbort = content.length === 0 && !jobData.responseMessageId;
// Create final event for abort
const userMessageId = jobData.userMessage?.messageId;
const abortFinalEvent: t.ServerSentEvent = {
final: true,
// Don't include conversation for early aborts - it doesn't exist in DB
conversation: isEarlyAbort ? null : { conversationId: jobData.conversationId },
title: 'New Chat',
requestMessage: jobData.userMessage
? {
messageId: userMessageId,
parentMessageId: jobData.userMessage.parentMessageId,
conversationId: jobData.conversationId,
text: jobData.userMessage.text ?? '',
isCreatedByUser: true,
}
: null,
responseMessage: isEarlyAbort
? null
: {
messageId: jobData.responseMessageId ?? `${userMessageId ?? 'aborted'}_`,
parentMessageId: userMessageId,
conversationId: jobData.conversationId,
content,
sender: jobData.sender ?? 'AI',
unfinished: true,
error: false,
isCreatedByUser: false,
},
aborted: true,
// Flag for early abort - no messages saved, frontend should go to new chat
earlyAbort: isEarlyAbort,
} as unknown as t.ServerSentEvent;
if (runtime) {
runtime.finalEvent = abortFinalEvent;
}
this.eventTransport.emitDone(streamId, abortFinalEvent);
this.jobStore.clearContentState(streamId);
this.runStepBuffers?.delete(streamId);
// Immediate cleanup if configured (default: true)
if (this._cleanupOnComplete) {
this.runtimeState.delete(streamId);
// Don't cleanup eventTransport here - let the abort event fully transmit first.
await this.jobStore.deleteJob(streamId);
} else {
// Only update status if keeping the job around
await this.jobStore.updateJob(streamId, {
status: 'aborted',
completedAt: Date.now(),
});
}
logger.debug(`[GenerationJobManager] Job aborted: ${streamId}`);
return {
success: true,
jobData,
content,
finalEvent: abortFinalEvent,
};
}
/**
* Subscribe to a job's event stream.
*
* This is called when an SSE client connects to /chat/stream/:streamId.
* On first subscription:
* - Resolves readyPromise (legacy, for API compatibility)
* - Replays any buffered early events (e.g., 'created' event)
*
* @param streamId - The stream to subscribe to
* @param onChunk - Handler for chunk events (streamed tokens, run steps, etc.)
* @param onDone - Handler for completion event (includes final message)
* @param onError - Handler for error events
* @returns Subscription object with unsubscribe function, or null if job not found
*/
async subscribe(
streamId: string,
onChunk: t.ChunkHandler,
onDone?: t.DoneHandler,
onError?: t.ErrorHandler,
): Promise<{ unsubscribe: t.UnsubscribeFn } | null> {
const runtime = this.runtimeState.get(streamId);
if (!runtime) {
return null;
}
const jobData = await this.jobStore.getJob(streamId);
// If job already complete, send final event
setImmediate(() => {
if (
runtime.finalEvent &&
jobData &&
['complete', 'error', 'aborted'].includes(jobData.status)
) {
onDone?.(runtime.finalEvent);
}
});
const subscription = this.eventTransport.subscribe(streamId, {
onChunk: (event) => {
const e = event as t.ServerSentEvent;
// Filter out internal events
if (!(e as Record<string, unknown>)._internal) {
onChunk(e);
}
},
onDone: (event) => onDone?.(event as t.ServerSentEvent),
onError,
});
// Check if this is the first subscriber
const isFirst = this.eventTransport.isFirstSubscriber(streamId);
// First subscriber: replay buffered events and mark as connected
if (!runtime.hasSubscriber) {
runtime.hasSubscriber = true;
// Replay any events that were emitted before subscriber connected
if (runtime.earlyEventBuffer.length > 0) {
logger.debug(
`[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`,
);
for (const bufferedEvent of runtime.earlyEventBuffer) {
onChunk(bufferedEvent);
}
// Clear buffer after replay
runtime.earlyEventBuffer = [];
}
}
if (isFirst) {
runtime.resolveReady();
logger.debug(
`[GenerationJobManager] First subscriber ready, resolving promise for ${streamId}`,
);
}
return subscription;
}
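  /**
   * Illustrative SSE endpoint usage (a sketch under assumptions; the Express handler and the
   * exact wire format are not defined by this module):
   *
   *   const subscription = await GenerationJobManager.subscribe(
   *     streamId,
   *     (event) => res.write(`data: ${JSON.stringify(event)}\n\n`),
   *     (finalEvent) => {
   *       res.write(`data: ${JSON.stringify(finalEvent)}\n\n`);
   *       res.end();
   *     },
   *     (error) => res.write(`event: error\ndata: ${JSON.stringify({ error })}\n\n`),
   *   );
   *   req.on('close', () => subscription?.unsubscribe());
   */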
/**
* Emit a chunk event to all subscribers.
* Uses runtime state check for performance (avoids async job store lookup per token).
*
* If no subscriber has connected yet, buffers the event for replay when they do.
* This ensures early events (like 'created') aren't lost due to race conditions.
*/
emitChunk(streamId: string, event: t.ServerSentEvent): void {
const runtime = this.runtimeState.get(streamId);
if (!runtime || runtime.abortController.signal.aborted) {
return;
}
// Track user message from created event
this.trackUserMessage(streamId, event);
// For Redis mode, persist chunk for later reconstruction
if (this._isRedis) {
// The SSE event structure is { event: string, data: unknown, ... }
// The aggregator expects { event: string, data: unknown } where data is the payload
const eventObj = event as Record<string, unknown>;
const eventType = eventObj.event as string | undefined;
const eventData = eventObj.data;
if (eventType && eventData !== undefined) {
// Store in format expected by aggregateContent: { event, data }
this.jobStore.appendChunk(streamId, { event: eventType, data: eventData }).catch((err) => {
logger.error(`[GenerationJobManager] Failed to append chunk:`, err);
});
// For run step events, also save to run steps key for quick retrieval
if (eventType === 'on_run_step' || eventType === 'on_run_step_completed') {
this.saveRunStepFromEvent(streamId, eventData as Record<string, unknown>);
}
}
}
// Buffer early events if no subscriber yet (replay when first subscriber connects)
if (!runtime.hasSubscriber) {
runtime.earlyEventBuffer.push(event);
// Also emit to transport in case subscriber connects mid-flight
}
this.eventTransport.emitChunk(streamId, event);
}
/**
* Extract and save run step from event data.
* The data is already the run step object from the event payload.
*/
private saveRunStepFromEvent(streamId: string, data: Record<string, unknown>): void {
// The data IS the run step object
const runStep = data as Agents.RunStep;
if (!runStep.id) {
return;
}
// Fire and forget - accumulate run steps
this.accumulateRunStep(streamId, runStep);
}
/**
* Accumulate run steps for a stream (Redis mode only).
* Uses a simple in-memory buffer that gets flushed to Redis.
* Not used in in-memory mode - run steps come from live graph via WeakRef.
*/
private runStepBuffers: Map<string, Agents.RunStep[]> | null = null;
private accumulateRunStep(streamId: string, runStep: Agents.RunStep): void {
// Lazy initialization - only create map when first used (Redis mode)
if (!this.runStepBuffers) {
this.runStepBuffers = new Map();
}
let buffer = this.runStepBuffers.get(streamId);
if (!buffer) {
buffer = [];
this.runStepBuffers.set(streamId, buffer);
}
// Update or add run step
const existingIdx = buffer.findIndex((rs) => rs.id === runStep.id);
if (existingIdx >= 0) {
buffer[existingIdx] = runStep;
} else {
buffer.push(runStep);
}
// Save to Redis
if (this.jobStore.saveRunSteps) {
this.jobStore.saveRunSteps(streamId, buffer).catch((err) => {
logger.error(`[GenerationJobManager] Failed to save run steps:`, err);
});
}
}
/**
* Track user message from created event.
*/
private trackUserMessage(streamId: string, event: t.ServerSentEvent): void {
const data = event as Record<string, unknown>;
if (!data.created || !data.message) {
return;
}
const message = data.message as Record<string, unknown>;
const updates: Partial<SerializableJobData> = {
userMessage: {
messageId: message.messageId as string,
parentMessageId: message.parentMessageId as string | undefined,
conversationId: message.conversationId as string | undefined,
text: message.text as string | undefined,
},
};
if (message.conversationId) {
updates.conversationId = message.conversationId as string;
}
this.jobStore.updateJob(streamId, updates);
}
/**
* Update job metadata.
*/
async updateMetadata(
streamId: string,
metadata: Partial<t.GenerationJobMetadata>,
): Promise<void> {
const updates: Partial<SerializableJobData> = {};
if (metadata.responseMessageId) {
updates.responseMessageId = metadata.responseMessageId;
}
if (metadata.sender) {
updates.sender = metadata.sender;
}
if (metadata.conversationId) {
updates.conversationId = metadata.conversationId;
}
if (metadata.userMessage) {
updates.userMessage = metadata.userMessage;
}
if (metadata.endpoint) {
updates.endpoint = metadata.endpoint;
}
if (metadata.iconURL) {
updates.iconURL = metadata.iconURL;
}
if (metadata.model) {
updates.model = metadata.model;
}
if (metadata.promptTokens !== undefined) {
updates.promptTokens = metadata.promptTokens;
}
await this.jobStore.updateJob(streamId, updates);
}
/**
* Set reference to the graph's contentParts array.
*/
setContentParts(streamId: string, contentParts: Agents.MessageContentComplex[]): void {
// Use runtime state check for performance (sync check)
if (!this.runtimeState.has(streamId)) {
return;
}
this.jobStore.setContentParts(streamId, contentParts);
}
/**
* Set reference to the graph instance.
*/
setGraph(streamId: string, graph: StandardGraph): void {
// Use runtime state check for performance (sync check)
if (!this.runtimeState.has(streamId)) {
return;
}
this.jobStore.setGraph(streamId, graph);
}
/**
* Get resume state for reconnecting clients.
*/
async getResumeState(streamId: string): Promise<t.ResumeState | null> {
const jobData = await this.jobStore.getJob(streamId);
if (!jobData) {
return null;
}
const result = await this.jobStore.getContentParts(streamId);
const aggregatedContent = result?.content ?? [];
const runSteps = await this.jobStore.getRunSteps(streamId);
logger.debug(`[GenerationJobManager] getResumeState:`, {
streamId,
runStepsLength: runSteps.length,
aggregatedContentLength: aggregatedContent.length,
});
return {
runSteps,
aggregatedContent,
userMessage: jobData.userMessage,
responseMessageId: jobData.responseMessageId,
conversationId: jobData.conversationId,
sender: jobData.sender,
};
}
/**
* Mark that sync has been sent.
*/
markSyncSent(streamId: string): void {
const runtime = this.runtimeState.get(streamId);
if (runtime) {
runtime.syncSent = true;
}
}
/**
* Check if sync has been sent.
*/
wasSyncSent(streamId: string): boolean {
return this.runtimeState.get(streamId)?.syncSent ?? false;
}
/**
* Emit a done event.
*/
emitDone(streamId: string, event: t.ServerSentEvent): void {
const runtime = this.runtimeState.get(streamId);
if (runtime) {
runtime.finalEvent = event;
}
this.eventTransport.emitDone(streamId, event);
}
/**
* Emit an error event.
*/
emitError(streamId: string, error: string): void {
this.eventTransport.emitError(streamId, error);
}
/**
* Cleanup expired jobs.
* Also cleans up any orphaned runtime state, buffers, and event transport entries.
*/
private async cleanup(): Promise<void> {
const count = await this.jobStore.cleanup();
// Cleanup runtime state for deleted jobs
for (const streamId of this.runtimeState.keys()) {
if (!(await this.jobStore.hasJob(streamId))) {
this.runtimeState.delete(streamId);
this.runStepBuffers?.delete(streamId);
this.jobStore.clearContentState(streamId);
this.eventTransport.cleanup(streamId);
}
}
// Also check runStepBuffers for any orphaned entries (Redis mode only)
if (this.runStepBuffers) {
for (const streamId of this.runStepBuffers.keys()) {
if (!(await this.jobStore.hasJob(streamId))) {
this.runStepBuffers.delete(streamId);
}
}
}
// Check eventTransport for orphaned streams (e.g., connections dropped without clean close)
// These are streams that exist in eventTransport but have no corresponding job
for (const streamId of this.eventTransport.getTrackedStreamIds()) {
if (!(await this.jobStore.hasJob(streamId)) && !this.runtimeState.has(streamId)) {
this.eventTransport.cleanup(streamId);
}
}
if (count > 0) {
logger.debug(`[GenerationJobManager] Cleaned up ${count} expired jobs`);
}
}
/**
* Get stream info for status endpoint.
*/
async getStreamInfo(streamId: string): Promise<{
active: boolean;
status: t.GenerationJobStatus;
aggregatedContent?: Agents.MessageContentComplex[];
createdAt: number;
} | null> {
const jobData = await this.jobStore.getJob(streamId);
if (!jobData) {
return null;
}
const result = await this.jobStore.getContentParts(streamId);
const aggregatedContent = result?.content ?? [];
return {
active: jobData.status === 'running',
status: jobData.status as t.GenerationJobStatus,
aggregatedContent,
createdAt: jobData.createdAt,
};
}
/**
* Get total job count.
*/
async getJobCount(): Promise<number> {
return this.jobStore.getJobCount();
}
/**
* Get job count by status.
*/
async getJobCountByStatus(): Promise<Record<t.GenerationJobStatus, number>> {
const [running, complete, error, aborted] = await Promise.all([
this.jobStore.getJobCountByStatus('running'),
this.jobStore.getJobCountByStatus('complete'),
this.jobStore.getJobCountByStatus('error'),
this.jobStore.getJobCountByStatus('aborted'),
]);
return { running, complete, error, aborted };
}
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Performs self-healing cleanup of stale entries.
*
* @param userId - The user ID to query
* @returns Array of conversation IDs with active jobs
*/
async getActiveJobIdsForUser(userId: string): Promise<string[]> {
return this.jobStore.getActiveJobIdsByUser(userId);
}
/**
* Destroy the manager.
* Cleans up all resources including runtime state, buffers, and stores.
*/
async destroy(): Promise<void> {
if (this.cleanupInterval) {
clearInterval(this.cleanupInterval);
this.cleanupInterval = null;
}
await this.jobStore.destroy();
this.eventTransport.destroy();
this.runtimeState.clear();
this.runStepBuffers?.clear();
logger.debug('[GenerationJobManager] Destroyed');
}
}
export const GenerationJobManager = new GenerationJobManagerClass();
export { GenerationJobManagerClass };

View file

@ -0,0 +1,415 @@
import type { Redis, Cluster } from 'ioredis';
/**
* Integration tests for GenerationJobManager.
*
* Tests the job manager with both in-memory and Redis backends
* to ensure consistent behavior across deployment modes.
*
* Run with: USE_REDIS=true npx jest GenerationJobManager.stream_integration
*/
describe('GenerationJobManager Integration Tests', () => {
let originalEnv: NodeJS.ProcessEnv;
let ioredisClient: Redis | Cluster | null = null;
const testPrefix = 'JobManager-Integration-Test';
beforeAll(async () => {
originalEnv = { ...process.env };
// Set up test environment
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
jest.resetModules();
const { ioredisClient: client } = await import('../../cache/redisClients');
ioredisClient = client;
});
afterEach(async () => {
// Clean up module state
jest.resetModules();
// Clean up Redis keys (delete individually for cluster compatibility)
if (ioredisClient) {
try {
const keys = await ioredisClient.keys(`${testPrefix}*`);
const streamKeys = await ioredisClient.keys(`stream:*`);
const allKeys = [...keys, ...streamKeys];
await Promise.all(allKeys.map((key) => ioredisClient!.del(key)));
} catch {
// Ignore cleanup errors
}
}
});
afterAll(async () => {
if (ioredisClient) {
try {
// Use quit() to gracefully close - waits for pending commands
await ioredisClient.quit();
} catch {
// Fall back to disconnect if quit fails
try {
ioredisClient.disconnect();
} catch {
// Ignore
}
}
}
process.env = originalEnv;
});
describe('In-Memory Mode', () => {
test('should create and manage jobs', async () => {
const { GenerationJobManager } = await import('../GenerationJobManager');
const { InMemoryJobStore } = await import('../implementations/InMemoryJobStore');
const { InMemoryEventTransport } = await import('../implementations/InMemoryEventTransport');
// Configure with in-memory
// cleanupOnComplete: false so we can verify completed status
GenerationJobManager.configure({
jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }),
eventTransport: new InMemoryEventTransport(),
isRedis: false,
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
const streamId = `inmem-job-${Date.now()}`;
const userId = 'test-user-1';
// Create job (async)
const job = await GenerationJobManager.createJob(streamId, userId);
expect(job.streamId).toBe(streamId);
expect(job.status).toBe('running');
// Check job exists
const hasJob = await GenerationJobManager.hasJob(streamId);
expect(hasJob).toBe(true);
// Get job
const retrieved = await GenerationJobManager.getJob(streamId);
expect(retrieved?.streamId).toBe(streamId);
// Update job
await GenerationJobManager.updateMetadata(streamId, { sender: 'TestAgent' });
const updated = await GenerationJobManager.getJob(streamId);
expect(updated?.metadata?.sender).toBe('TestAgent');
// Complete job
await GenerationJobManager.completeJob(streamId);
const completed = await GenerationJobManager.getJob(streamId);
expect(completed?.status).toBe('complete');
await GenerationJobManager.destroy();
});
test('should handle event streaming', async () => {
const { GenerationJobManager } = await import('../GenerationJobManager');
const { InMemoryJobStore } = await import('../implementations/InMemoryJobStore');
const { InMemoryEventTransport } = await import('../implementations/InMemoryEventTransport');
GenerationJobManager.configure({
jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }),
eventTransport: new InMemoryEventTransport(),
isRedis: false,
});
await GenerationJobManager.initialize();
const streamId = `inmem-events-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
const receivedChunks: unknown[] = [];
// Subscribe to events (subscribe takes separate args, not an object)
const subscription = await GenerationJobManager.subscribe(streamId, (event) =>
receivedChunks.push(event),
);
const { unsubscribe } = subscription!;
// Wait for first subscriber to be registered
await new Promise((resolve) => setTimeout(resolve, 10));
// Emit chunks (emitChunk takes { event, data } format)
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Hello' },
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: ' world' },
});
// Give time for events to propagate
await new Promise((resolve) => setTimeout(resolve, 50));
// Verify chunks were received
expect(receivedChunks.length).toBeGreaterThan(0);
// Complete the job (this cleans up resources)
await GenerationJobManager.completeJob(streamId);
unsubscribe();
await GenerationJobManager.destroy();
});
});
describe('Redis Mode', () => {
test('should create and manage jobs via Redis', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { GenerationJobManager } = await import('../GenerationJobManager');
const { createStreamServices } = await import('../createStreamServices');
// Create Redis services
const services = createStreamServices({
useRedis: true,
redisClient: ioredisClient,
});
expect(services.isRedis).toBe(true);
GenerationJobManager.configure(services);
await GenerationJobManager.initialize();
const streamId = `redis-job-${Date.now()}`;
const userId = 'test-user-redis';
// Create job (async)
const job = await GenerationJobManager.createJob(streamId, userId);
expect(job.streamId).toBe(streamId);
// Verify in Redis
const hasJob = await GenerationJobManager.hasJob(streamId);
expect(hasJob).toBe(true);
// Update and verify
await GenerationJobManager.updateMetadata(streamId, { sender: 'RedisAgent' });
const updated = await GenerationJobManager.getJob(streamId);
expect(updated?.metadata?.sender).toBe('RedisAgent');
await GenerationJobManager.destroy();
});
test('should persist chunks for cross-instance resume', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { GenerationJobManager } = await import('../GenerationJobManager');
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices({
useRedis: true,
redisClient: ioredisClient,
});
GenerationJobManager.configure(services);
await GenerationJobManager.initialize();
const streamId = `redis-chunks-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
// Emit chunks (these should be persisted to Redis)
// emitChunk takes { event, data } format
GenerationJobManager.emitChunk(streamId, {
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: {
id: 'step-1',
delta: { content: { type: 'text', text: 'Persisted ' } },
},
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: {
id: 'step-1',
delta: { content: { type: 'text', text: 'content' } },
},
});
// Wait for async operations
await new Promise((resolve) => setTimeout(resolve, 100));
// Simulate getting resume state (as if from different instance)
const resumeState = await GenerationJobManager.getResumeState(streamId);
expect(resumeState).not.toBeNull();
expect(resumeState!.aggregatedContent?.length).toBeGreaterThan(0);
await GenerationJobManager.destroy();
});
test('should handle abort and return content', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { GenerationJobManager } = await import('../GenerationJobManager');
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices({
useRedis: true,
redisClient: ioredisClient,
});
GenerationJobManager.configure(services);
await GenerationJobManager.initialize();
const streamId = `redis-abort-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
// Emit some content (emitChunk takes { event, data } format)
GenerationJobManager.emitChunk(streamId, {
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
});
GenerationJobManager.emitChunk(streamId, {
event: 'on_message_delta',
data: {
id: 'step-1',
delta: { content: { type: 'text', text: 'Partial response...' } },
},
});
await new Promise((resolve) => setTimeout(resolve, 100));
// Abort the job
const abortResult = await GenerationJobManager.abortJob(streamId);
expect(abortResult.success).toBe(true);
expect(abortResult.content.length).toBeGreaterThan(0);
await GenerationJobManager.destroy();
});
});
describe('Cross-Mode Consistency', () => {
test('should have consistent API between in-memory and Redis modes', async () => {
// This test verifies that the same operations work identically
// regardless of backend mode
const runTestWithMode = async (isRedis: boolean) => {
jest.resetModules();
const { GenerationJobManager } = await import('../GenerationJobManager');
if (isRedis && ioredisClient) {
const { createStreamServices } = await import('../createStreamServices');
GenerationJobManager.configure({
...createStreamServices({
useRedis: true,
redisClient: ioredisClient,
}),
cleanupOnComplete: false, // Keep job for verification
});
} else {
const { InMemoryJobStore } = await import('../implementations/InMemoryJobStore');
const { InMemoryEventTransport } = await import(
'../implementations/InMemoryEventTransport'
);
GenerationJobManager.configure({
jobStore: new InMemoryJobStore({ ttlAfterComplete: 60000 }),
eventTransport: new InMemoryEventTransport(),
isRedis: false,
cleanupOnComplete: false,
});
}
await GenerationJobManager.initialize();
const streamId = `consistency-${isRedis ? 'redis' : 'inmem'}-${Date.now()}`;
// Test sequence
const job = await GenerationJobManager.createJob(streamId, 'user-1');
expect(job.streamId).toBe(streamId);
expect(job.status).toBe('running');
const hasJob = await GenerationJobManager.hasJob(streamId);
expect(hasJob).toBe(true);
await GenerationJobManager.updateMetadata(streamId, {
sender: 'ConsistencyAgent',
responseMessageId: 'resp-123',
});
const updated = await GenerationJobManager.getJob(streamId);
expect(updated?.metadata?.sender).toBe('ConsistencyAgent');
expect(updated?.metadata?.responseMessageId).toBe('resp-123');
await GenerationJobManager.completeJob(streamId);
const completed = await GenerationJobManager.getJob(streamId);
expect(completed?.status).toBe('complete');
await GenerationJobManager.destroy();
};
// Test in-memory mode
await runTestWithMode(false);
// Test Redis mode if available
if (ioredisClient) {
await runTestWithMode(true);
}
});
});
describe('createStreamServices Auto-Detection', () => {
test('should auto-detect Redis when USE_REDIS is true', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
// Force USE_REDIS to true
process.env.USE_REDIS = 'true';
jest.resetModules();
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices();
// Should detect Redis
expect(services.isRedis).toBe(true);
});
test('should fall back to in-memory when USE_REDIS is false', async () => {
process.env.USE_REDIS = 'false';
jest.resetModules();
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices();
expect(services.isRedis).toBe(false);
});
test('should allow forcing in-memory via config override', async () => {
const { createStreamServices } = await import('../createStreamServices');
const services = createStreamServices({ useRedis: false });
expect(services.isRedis).toBe(false);
});
});
});

View file

@@ -0,0 +1,326 @@
import type { Redis, Cluster } from 'ioredis';
/**
* Integration tests for RedisEventTransport.
*
* Tests Redis Pub/Sub functionality:
* - Cross-instance event delivery
* - Subscriber management
* - Error handling
*
* Run with: USE_REDIS=true npx jest RedisEventTransport.stream_integration
*/
describe('RedisEventTransport Integration Tests', () => {
let originalEnv: NodeJS.ProcessEnv;
let ioredisClient: Redis | Cluster | null = null;
const testPrefix = 'EventTransport-Integration-Test';
beforeAll(async () => {
originalEnv = { ...process.env };
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
jest.resetModules();
const { ioredisClient: client } = await import('../../cache/redisClients');
ioredisClient = client;
});
afterAll(async () => {
if (ioredisClient) {
try {
// Use quit() to gracefully close - waits for pending commands
await ioredisClient.quit();
} catch {
// Fall back to disconnect if quit fails
try {
ioredisClient.disconnect();
} catch {
// Ignore
}
}
}
process.env = originalEnv;
});
describe('Pub/Sub Event Delivery', () => {
test('should deliver events to subscribers on same instance', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
// Create subscriber client (Redis pub/sub requires dedicated connection)
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `pubsub-same-${Date.now()}`;
const receivedChunks: unknown[] = [];
let doneEvent: unknown = null;
// Subscribe
const { unsubscribe } = transport.subscribe(streamId, {
onChunk: (event) => receivedChunks.push(event),
onDone: (event) => {
doneEvent = event;
},
});
// Wait for subscription to be established
await new Promise((resolve) => setTimeout(resolve, 100));
// Emit events
transport.emitChunk(streamId, { type: 'text', text: 'Hello' });
transport.emitChunk(streamId, { type: 'text', text: ' World' });
transport.emitDone(streamId, { finished: true });
// Wait for events to propagate
await new Promise((resolve) => setTimeout(resolve, 200));
expect(receivedChunks.length).toBe(2);
expect(doneEvent).toEqual({ finished: true });
unsubscribe();
transport.destroy();
subscriber.disconnect();
});
test('should deliver events across transport instances (simulating different servers)', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
// Create two separate transport instances (simulating two servers)
const subscriber1 = (ioredisClient as Redis).duplicate();
const subscriber2 = (ioredisClient as Redis).duplicate();
const transport1 = new RedisEventTransport(ioredisClient, subscriber1);
const transport2 = new RedisEventTransport(ioredisClient, subscriber2);
const streamId = `pubsub-cross-${Date.now()}`;
const instance2Chunks: unknown[] = [];
// Subscribe on transport 2 (consumer)
const sub2 = transport2.subscribe(streamId, {
onChunk: (event) => instance2Chunks.push(event),
});
// Wait for subscription
await new Promise((resolve) => setTimeout(resolve, 100));
// Emit from transport 1 (producer on different instance)
transport1.emitChunk(streamId, { data: 'from-instance-1' });
// Wait for cross-instance delivery
await new Promise((resolve) => setTimeout(resolve, 200));
// Transport 2 should receive the event
expect(instance2Chunks.length).toBe(1);
expect(instance2Chunks[0]).toEqual({ data: 'from-instance-1' });
sub2.unsubscribe();
transport1.destroy();
transport2.destroy();
subscriber1.disconnect();
subscriber2.disconnect();
});
test('should handle multiple subscribers to same stream', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `pubsub-multi-${Date.now()}`;
const subscriber1Chunks: unknown[] = [];
const subscriber2Chunks: unknown[] = [];
// Two subscribers
const sub1 = transport.subscribe(streamId, {
onChunk: (event) => subscriber1Chunks.push(event),
});
const sub2 = transport.subscribe(streamId, {
onChunk: (event) => subscriber2Chunks.push(event),
});
await new Promise((resolve) => setTimeout(resolve, 100));
transport.emitChunk(streamId, { data: 'broadcast' });
await new Promise((resolve) => setTimeout(resolve, 200));
// Both should receive
expect(subscriber1Chunks.length).toBe(1);
expect(subscriber2Chunks.length).toBe(1);
sub1.unsubscribe();
sub2.unsubscribe();
transport.destroy();
subscriber.disconnect();
});
});
describe('Subscriber Management', () => {
test('should track first subscriber correctly', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `first-sub-${Date.now()}`;
// Before any subscribers, the count is 0 (no subscriber can be "first" yet)
expect(transport.getSubscriberCount(streamId)).toBe(0);
// First subscriber
const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
// Now there's a subscriber - isFirstSubscriber returns true when count is 1
expect(transport.getSubscriberCount(streamId)).toBe(1);
expect(transport.isFirstSubscriber(streamId)).toBe(true);
// Second subscriber - not first anymore
const sub2temp = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
expect(transport.isFirstSubscriber(streamId)).toBe(false);
sub2temp.unsubscribe();
const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
expect(transport.getSubscriberCount(streamId)).toBe(2);
sub1.unsubscribe();
sub2.unsubscribe();
transport.destroy();
subscriber.disconnect();
});
test('should fire onAllSubscribersLeft when last subscriber leaves', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `all-left-${Date.now()}`;
let allLeftCalled = false;
transport.onAllSubscribersLeft(streamId, () => {
allLeftCalled = true;
});
const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
// Unsubscribe first
sub1.unsubscribe();
await new Promise((resolve) => setTimeout(resolve, 50));
// Still have one subscriber
expect(allLeftCalled).toBe(false);
// Unsubscribe last
sub2.unsubscribe();
await new Promise((resolve) => setTimeout(resolve, 50));
// Now all left
expect(allLeftCalled).toBe(true);
transport.destroy();
subscriber.disconnect();
});
});
describe('Error Handling', () => {
test('should deliver error events to subscribers', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `error-${Date.now()}`;
let receivedError: string | null = null;
transport.subscribe(streamId, {
onChunk: () => {},
onError: (err) => {
receivedError = err;
},
});
await new Promise((resolve) => setTimeout(resolve, 100));
transport.emitError(streamId, 'Test error message');
await new Promise((resolve) => setTimeout(resolve, 200));
expect(receivedError).toBe('Test error message');
transport.destroy();
subscriber.disconnect();
});
});
describe('Cleanup', () => {
test('should clean up stream resources', async () => {
if (!ioredisClient) {
console.warn('Redis not available, skipping test');
return;
}
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const subscriber = (ioredisClient as Redis).duplicate();
const transport = new RedisEventTransport(ioredisClient, subscriber);
const streamId = `cleanup-${Date.now()}`;
transport.subscribe(streamId, { onChunk: () => {} });
await new Promise((resolve) => setTimeout(resolve, 50));
expect(transport.getSubscriberCount(streamId)).toBe(1);
// Cleanup the stream
transport.cleanup(streamId);
// Subscriber count should be 0
expect(transport.getSubscriberCount(streamId)).toBe(0);
transport.destroy();
subscriber.disconnect();
});
});
});

View file

@@ -0,0 +1,975 @@
import { StepTypes } from 'librechat-data-provider';
import type { Agents } from 'librechat-data-provider';
import type { Redis, Cluster } from 'ioredis';
import { StandardGraph } from '@librechat/agents';
/**
* Integration tests for RedisJobStore.
*
* Tests horizontal scaling scenarios:
* - Multi-instance job access
* - Content reconstruction from chunks
* - Consumer groups for resumable streams
* - TTL and cleanup behavior
*
* Run with: USE_REDIS=true npx jest RedisJobStore.stream_integration
*/
describe('RedisJobStore Integration Tests', () => {
let originalEnv: NodeJS.ProcessEnv;
let ioredisClient: Redis | Cluster | null = null;
const testPrefix = 'Stream-Integration-Test';
beforeAll(async () => {
originalEnv = { ...process.env };
// Set up test environment
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
jest.resetModules();
// Import Redis client
const { ioredisClient: client } = await import('../../cache/redisClients');
ioredisClient = client;
if (!ioredisClient) {
console.warn('Redis not available, skipping integration tests');
}
});
afterEach(async () => {
if (!ioredisClient) {
return;
}
// Clean up all test keys (delete individually for cluster compatibility)
try {
const keys = await ioredisClient.keys(`${testPrefix}*`);
// Also clean up stream keys which use hash tags
const streamKeys = await ioredisClient.keys(`stream:*`);
const allKeys = [...keys, ...streamKeys];
// Delete individually to avoid CROSSSLOT errors in cluster mode
await Promise.all(allKeys.map((key) => ioredisClient!.del(key)));
} catch (error) {
console.warn('Error cleaning up test keys:', error);
}
});
afterAll(async () => {
if (ioredisClient) {
try {
// Use quit() to gracefully close - waits for pending commands
await ioredisClient.quit();
} catch {
// Fall back to disconnect if quit fails
try {
ioredisClient.disconnect();
} catch {
// Ignore
}
}
}
process.env = originalEnv;
});
describe('Job CRUD Operations', () => {
test('should create and retrieve a job', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `test-stream-${Date.now()}`;
const userId = 'test-user-123';
const job = await store.createJob(streamId, userId, streamId);
expect(job).toMatchObject({
streamId,
userId,
status: 'running',
conversationId: streamId,
syncSent: false,
});
const retrieved = await store.getJob(streamId);
expect(retrieved).toMatchObject({
streamId,
userId,
status: 'running',
});
await store.destroy();
});
test('should update job status', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `test-stream-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
await store.updateJob(streamId, { status: 'complete', completedAt: Date.now() });
const job = await store.getJob(streamId);
expect(job?.status).toBe('complete');
expect(job?.completedAt).toBeDefined();
await store.destroy();
});
test('should delete job and related data', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `test-stream-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Add some chunks
await store.appendChunk(streamId, { event: 'on_message_delta', data: { text: 'Hello' } });
await store.deleteJob(streamId);
const job = await store.getJob(streamId);
expect(job).toBeNull();
await store.destroy();
});
});
describe('Horizontal Scaling - Multi-Instance Simulation', () => {
test('should share job state between two store instances', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Simulate two server instances with separate store instances
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const streamId = `multi-instance-${Date.now()}`;
// Instance 1 creates job
await instance1.createJob(streamId, 'user-1', streamId);
// Instance 2 should see the job
const jobFromInstance2 = await instance2.getJob(streamId);
expect(jobFromInstance2).not.toBeNull();
expect(jobFromInstance2?.streamId).toBe(streamId);
// Instance 1 updates job
await instance1.updateJob(streamId, { sender: 'TestAgent', syncSent: true });
// Instance 2 should see the update
const updatedJob = await instance2.getJob(streamId);
expect(updatedJob?.sender).toBe('TestAgent');
expect(updatedJob?.syncSent).toBe(true);
await instance1.destroy();
await instance2.destroy();
});
test('should share chunks between instances for content reconstruction', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const streamId = `chunk-sharing-${Date.now()}`;
await instance1.createJob(streamId, 'user-1', streamId);
// Instance 1 emits chunks (simulating stream generation)
// Format must match what aggregateContent expects:
// - on_run_step: { id, index, stepDetails: { type } }
// - on_message_delta: { id, delta: { content: { type, text } } }
const chunks = [
{
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'Hello, ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'world!' } } },
},
];
for (const chunk of chunks) {
await instance1.appendChunk(streamId, chunk);
}
// Instance 2 reconstructs content (simulating reconnect to different instance)
const result = await instance2.getContentParts(streamId);
// Should have reconstructed content
expect(result).not.toBeNull();
expect(result!.content.length).toBeGreaterThan(0);
await instance1.destroy();
await instance2.destroy();
});
test('should share run steps between instances', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const streamId = `runsteps-sharing-${Date.now()}`;
await instance1.createJob(streamId, 'user-1', streamId);
// Instance 1 saves run steps
const runSteps: Partial<Agents.RunStep>[] = [
{ id: 'step-1', runId: 'run-1', type: StepTypes.MESSAGE_CREATION, index: 0 },
{ id: 'step-2', runId: 'run-1', type: StepTypes.TOOL_CALLS, index: 1 },
];
await instance1.saveRunSteps!(streamId, runSteps as Agents.RunStep[]);
// Instance 2 retrieves run steps
const retrievedSteps = await instance2.getRunSteps(streamId);
expect(retrievedSteps).toHaveLength(2);
expect(retrievedSteps[0].id).toBe('step-1');
expect(retrievedSteps[1].id).toBe('step-2');
await instance1.destroy();
await instance2.destroy();
});
});
describe('Content Reconstruction', () => {
test('should reconstruct text content from message deltas', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `text-reconstruction-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Simulate a streaming response with correct event format
const chunks = [
{
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'The ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'quick ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'brown ' } } },
},
{
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'fox.' } } },
},
];
for (const chunk of chunks) {
await store.appendChunk(streamId, chunk);
}
const result = await store.getContentParts(streamId);
expect(result).not.toBeNull();
// Content aggregator combines text deltas
const textPart = result!.content.find((p) => p.type === 'text');
expect(textPart).toBeDefined();
await store.destroy();
});
test('should reconstruct thinking content from reasoning deltas', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `think-reconstruction-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// on_reasoning_delta events need id and delta.content format
const chunks = [
{
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_reasoning_delta',
data: { id: 'step-1', delta: { content: { type: 'think', think: 'Let me think...' } } },
},
{
event: 'on_reasoning_delta',
data: {
id: 'step-1',
delta: { content: { type: 'think', think: ' about this problem.' } },
},
},
{
event: 'on_run_step',
data: {
id: 'step-2',
runId: 'run-1',
index: 1,
stepDetails: { type: 'message_creation' },
},
},
{
event: 'on_message_delta',
data: { id: 'step-2', delta: { content: { type: 'text', text: 'The answer is 42.' } } },
},
];
for (const chunk of chunks) {
await store.appendChunk(streamId, chunk);
}
const result = await store.getContentParts(streamId);
expect(result).not.toBeNull();
// Should have both think and text parts
const thinkPart = result!.content.find((p) => p.type === 'think');
const textPart = result!.content.find((p) => p.type === 'text');
expect(thinkPart).toBeDefined();
expect(textPart).toBeDefined();
await store.destroy();
});
test('should return null for empty chunks', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `empty-chunks-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// No chunks appended
const content = await store.getContentParts(streamId);
expect(content).toBeNull();
await store.destroy();
});
});
describe('Consumer Groups', () => {
test('should create consumer group and read chunks', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `consumer-group-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Add some chunks
const chunks = [
{ event: 'on_message_delta', data: { type: 'text', text: 'Chunk 1' } },
{ event: 'on_message_delta', data: { type: 'text', text: 'Chunk 2' } },
{ event: 'on_message_delta', data: { type: 'text', text: 'Chunk 3' } },
];
for (const chunk of chunks) {
await store.appendChunk(streamId, chunk);
}
// Wait for Redis to sync
await new Promise((resolve) => setTimeout(resolve, 50));
// Create consumer group starting from beginning
const groupName = `client-${Date.now()}`;
await store.createConsumerGroup(streamId, groupName, '0');
// Read chunks from group
// Note: With '0' as lastId, we need to use getPendingChunks or read with '0' instead of '>'
// The '>' only gives new messages after group creation
const readChunks = await store.getPendingChunks(streamId, groupName, 'consumer-1');
// If pending is empty, the messages haven't been delivered to this consumer yet,
// so fall through to the skip path below
if (readChunks.length === 0) {
// Consumer groups created at '0' should have access to all messages
// but they need to be "claimed" first. Skip this test as consumer groups
// require more complex setup for historical messages.
console.log(
'Skipping consumer group test - requires claim mechanism for historical messages',
);
await store.deleteConsumerGroup(streamId, groupName);
await store.destroy();
return;
}
expect(readChunks.length).toBe(3);
// Acknowledge chunks
const ids = readChunks.map((c) => c.id);
await store.acknowledgeChunks(streamId, groupName, ids);
// Reading again should return empty (all acknowledged)
const moreChunks = await store.readChunksFromGroup(streamId, groupName, 'consumer-1');
expect(moreChunks.length).toBe(0);
// Cleanup
await store.deleteConsumerGroup(streamId, groupName);
await store.destroy();
});
// TODO: Debug consumer group timing with Redis Streams
test.skip('should resume from where client left off', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `resume-test-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Create consumer group FIRST (before adding chunks) to track delivery
const groupName = `client-resume-${Date.now()}`;
await store.createConsumerGroup(streamId, groupName, '$'); // Start from end (only new messages)
// Add initial chunks (these will be "new" to the consumer group)
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 1' },
});
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 2' },
});
// Wait for Redis to sync
await new Promise((resolve) => setTimeout(resolve, 50));
// Client reads first batch
const firstRead = await store.readChunksFromGroup(streamId, groupName, 'consumer-1');
expect(firstRead.length).toBe(2);
// ACK the chunks
await store.acknowledgeChunks(
streamId,
groupName,
firstRead.map((c) => c.id),
);
// More chunks arrive while client is away
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 3' },
});
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { type: 'text', text: 'Part 4' },
});
// Wait for Redis to sync
await new Promise((resolve) => setTimeout(resolve, 50));
// Client reconnects - should only get new chunks
const secondRead = await store.readChunksFromGroup(streamId, groupName, 'consumer-1');
expect(secondRead.length).toBe(2);
await store.deleteConsumerGroup(streamId, groupName);
await store.destroy();
});
});
describe('TTL and Cleanup', () => {
test('should set running TTL on chunk stream', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient, { runningTtl: 60 });
await store.initialize();
const streamId = `ttl-test-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
await store.appendChunk(streamId, {
event: 'on_message_delta',
data: { id: 'step-1', type: 'text', text: 'test' },
});
// Check that TTL was set on the stream key
// Note: ioredis client has keyPrefix, so we use the key WITHOUT the prefix
// Key uses hash tag format: stream:{streamId}:chunks
const ttl = await ioredisClient.ttl(`stream:{${streamId}}:chunks`);
expect(ttl).toBeGreaterThan(0);
expect(ttl).toBeLessThanOrEqual(60);
await store.destroy();
});
test('should clean up stale jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Very short TTL for testing
const store = new RedisJobStore(ioredisClient, { runningTtl: 1 });
await store.initialize();
const streamId = `stale-job-${Date.now()}`;
// Manually create a job that looks old
// Note: ioredis client has keyPrefix, so we use the key WITHOUT the prefix
// Key uses hash tag format: stream:{streamId}:job
const jobKey = `stream:{${streamId}}:job`;
const veryOldTimestamp = Date.now() - 10000; // 10 seconds ago
await ioredisClient.hmset(jobKey, {
streamId,
userId: 'user-1',
status: 'running',
createdAt: veryOldTimestamp.toString(),
syncSent: '0',
});
await ioredisClient.sadd(`stream:running`, streamId);
// Run cleanup
const cleaned = await store.cleanup();
// Should have cleaned the stale job
expect(cleaned).toBeGreaterThanOrEqual(1);
await store.destroy();
});
});
describe('Active Jobs by User', () => {
test('should return active job IDs for a user', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId1 = `stream-1-${Date.now()}`;
const streamId2 = `stream-2-${Date.now()}`;
// Create two jobs for the same user
await store.createJob(streamId1, userId, streamId1);
await store.createJob(streamId2, userId, streamId2);
// Get active jobs for user
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(2);
expect(activeJobs).toContain(streamId1);
expect(activeJobs).toContain(streamId2);
await store.destroy();
});
test('should return empty array for user with no jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `nonexistent-user-${Date.now()}`;
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(0);
await store.destroy();
});
test('should not return completed jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId1 = `stream-1-${Date.now()}`;
const streamId2 = `stream-2-${Date.now()}`;
// Create two jobs
await store.createJob(streamId1, userId, streamId1);
await store.createJob(streamId2, userId, streamId2);
// Complete one job
await store.updateJob(streamId1, { status: 'complete', completedAt: Date.now() });
// Get active jobs - should only return the running one
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(1);
expect(activeJobs).toContain(streamId2);
expect(activeJobs).not.toContain(streamId1);
await store.destroy();
});
test('should not return aborted jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Create a job and abort it
await store.createJob(streamId, userId, streamId);
await store.updateJob(streamId, { status: 'aborted', completedAt: Date.now() });
// Get active jobs - should be empty
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(0);
await store.destroy();
});
test('should not return error jobs', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Create a job with error status
await store.createJob(streamId, userId, streamId);
await store.updateJob(streamId, {
status: 'error',
error: 'Test error',
completedAt: Date.now(),
});
// Get active jobs - should be empty
const activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(0);
await store.destroy();
});
test('should perform self-healing cleanup of stale entries', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
const staleStreamId = `stale-stream-${Date.now()}`;
// Create a real job
await store.createJob(streamId, userId, streamId);
// Manually add a stale entry to the user's job set (simulating orphaned data)
const userJobsKey = `stream:user:{${userId}}:jobs`;
await ioredisClient.sadd(userJobsKey, staleStreamId);
// Verify both entries exist in the set
const beforeCleanup = await ioredisClient.smembers(userJobsKey);
expect(beforeCleanup).toContain(streamId);
expect(beforeCleanup).toContain(staleStreamId);
// Get active jobs - should trigger self-healing
const activeJobs = await store.getActiveJobIdsByUser(userId);
// Should only return the real job
expect(activeJobs).toHaveLength(1);
expect(activeJobs).toContain(streamId);
// Verify stale entry was removed
const afterCleanup = await ioredisClient.smembers(userJobsKey);
expect(afterCleanup).toContain(streamId);
expect(afterCleanup).not.toContain(staleStreamId);
await store.destroy();
});
test('should isolate jobs between different users', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId1 = `user-1-${Date.now()}`;
const userId2 = `user-2-${Date.now()}`;
const streamId1 = `stream-1-${Date.now()}`;
const streamId2 = `stream-2-${Date.now()}`;
// Create jobs for different users
await store.createJob(streamId1, userId1, streamId1);
await store.createJob(streamId2, userId2, streamId2);
// Get active jobs for user 1
const user1Jobs = await store.getActiveJobIdsByUser(userId1);
expect(user1Jobs).toHaveLength(1);
expect(user1Jobs).toContain(streamId1);
expect(user1Jobs).not.toContain(streamId2);
// Get active jobs for user 2
const user2Jobs = await store.getActiveJobIdsByUser(userId2);
expect(user2Jobs).toHaveLength(1);
expect(user2Jobs).toContain(streamId2);
expect(user2Jobs).not.toContain(streamId1);
await store.destroy();
});
test('should work across multiple store instances (horizontal scaling)', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Simulate two server instances
const instance1 = new RedisJobStore(ioredisClient);
const instance2 = new RedisJobStore(ioredisClient);
await instance1.initialize();
await instance2.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Instance 1 creates a job
await instance1.createJob(streamId, userId, streamId);
// Instance 2 should see the active job
const activeJobs = await instance2.getActiveJobIdsByUser(userId);
expect(activeJobs).toHaveLength(1);
expect(activeJobs).toContain(streamId);
// Instance 1 completes the job
await instance1.updateJob(streamId, { status: 'complete', completedAt: Date.now() });
// Instance 2 should no longer see the job as active
const activeJobsAfter = await instance2.getActiveJobIdsByUser(userId);
expect(activeJobsAfter).toHaveLength(0);
await instance1.destroy();
await instance2.destroy();
});
test('should clean up user jobs set when job is deleted', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const userId = `test-user-${Date.now()}`;
const streamId = `stream-${Date.now()}`;
// Create a job
await store.createJob(streamId, userId, streamId);
// Verify job is in active list
let activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).toContain(streamId);
// Delete the job
await store.deleteJob(streamId);
// Job should no longer be in active list
activeJobs = await store.getActiveJobIdsByUser(userId);
expect(activeJobs).not.toContain(streamId);
await store.destroy();
});
});
describe('Local Graph Cache Optimization', () => {
test('should use local cache when available', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
const store = new RedisJobStore(ioredisClient);
await store.initialize();
const streamId = `local-cache-${Date.now()}`;
await store.createJob(streamId, 'user-1', streamId);
// Create a mock graph
const mockContentParts = [{ type: 'text', text: 'From local cache' }];
const mockRunSteps = [{ id: 'step-1', type: 'message_creation', status: 'completed' }];
const mockGraph = {
getContentParts: () => mockContentParts,
getRunSteps: () => mockRunSteps,
};
// Set graph reference (will be cached locally)
store.setGraph(streamId, mockGraph as unknown as StandardGraph);
// Get content - should come from local cache, not Redis
const result = await store.getContentParts(streamId);
expect(result!.content).toEqual(mockContentParts);
// Get run steps - should come from local cache
const runSteps = await store.getRunSteps(streamId);
expect(runSteps).toEqual(mockRunSteps);
await store.destroy();
});
test('should fall back to Redis when local cache not available', async () => {
if (!ioredisClient) {
return;
}
const { RedisJobStore } = await import('../implementations/RedisJobStore');
// Instance 1 creates and populates data
const instance1 = new RedisJobStore(ioredisClient);
await instance1.initialize();
const streamId = `fallback-test-${Date.now()}`;
await instance1.createJob(streamId, 'user-1', streamId);
// Add chunks to Redis with correct format
await instance1.appendChunk(streamId, {
event: 'on_run_step',
data: {
id: 'step-1',
runId: 'run-1',
index: 0,
stepDetails: { type: 'message_creation' },
},
});
await instance1.appendChunk(streamId, {
event: 'on_message_delta',
data: { id: 'step-1', delta: { content: { type: 'text', text: 'From Redis' } } },
});
// Save run steps to Redis
await instance1.saveRunSteps!(streamId, [
{
id: 'step-1',
runId: 'run-1',
type: StepTypes.MESSAGE_CREATION,
index: 0,
} as unknown as Agents.RunStep,
]);
// Instance 2 has NO local cache - should fall back to Redis
const instance2 = new RedisJobStore(ioredisClient);
await instance2.initialize();
// Get content - should reconstruct from Redis chunks
const result = await instance2.getContentParts(streamId);
expect(result).not.toBeNull();
expect(result!.content.length).toBeGreaterThan(0);
// Get run steps - should fetch from Redis
const runSteps = await instance2.getRunSteps(streamId);
expect(runSteps).toHaveLength(1);
expect(runSteps[0].id).toBe('step-1');
await instance1.destroy();
await instance2.destroy();
});
});
});

View file

@@ -0,0 +1,133 @@
import type { Redis, Cluster } from 'ioredis';
import { logger } from '@librechat/data-schemas';
import type { IJobStore, IEventTransport } from './interfaces/IJobStore';
import { InMemoryJobStore } from './implementations/InMemoryJobStore';
import { InMemoryEventTransport } from './implementations/InMemoryEventTransport';
import { RedisJobStore } from './implementations/RedisJobStore';
import { RedisEventTransport } from './implementations/RedisEventTransport';
import { cacheConfig } from '~/cache/cacheConfig';
import { ioredisClient } from '~/cache/redisClients';
/**
* Configuration for stream services (optional overrides)
*/
export interface StreamServicesConfig {
/**
 * Override Redis detection. If not provided, uses cacheConfig.USE_REDIS_STREAMS.
*/
useRedis?: boolean;
/**
* Override Redis client. If not provided, uses ioredisClient from cache.
*/
redisClient?: Redis | Cluster | null;
/**
* Dedicated Redis client for pub/sub subscribing.
* If not provided, will duplicate the main client.
*/
redisSubscriber?: Redis | Cluster | null;
/**
* Options for in-memory job store
*/
inMemoryOptions?: {
ttlAfterComplete?: number;
maxJobs?: number;
};
}
/**
* Stream services result
*/
export interface StreamServices {
jobStore: IJobStore;
eventTransport: IEventTransport;
isRedis: boolean;
}
/**
* Create stream services (job store + event transport).
*
* Automatically detects Redis from cacheConfig.USE_REDIS_STREAMS and uses
* the existing ioredisClient. Falls back to in-memory if Redis
* is not configured or not available.
*
* USE_REDIS_STREAMS defaults to USE_REDIS if not explicitly set,
* allowing users to disable Redis for streams while keeping it for other caches.
*
* @example Auto-detect (uses cacheConfig)
* ```ts
* const services = createStreamServices();
* // Uses Redis if USE_REDIS_STREAMS=true (defaults to USE_REDIS), otherwise in-memory
* ```
*
* @example Force in-memory
* ```ts
* const services = createStreamServices({ useRedis: false });
* ```
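 *
 * @example Wire into the generation job manager (a sketch mirroring the integration tests)
 * ```ts
 * const services = createStreamServices();
 * GenerationJobManager.configure(services);
 * await GenerationJobManager.initialize();
 * ```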
*/
export function createStreamServices(config: StreamServicesConfig = {}): StreamServices {
// Use provided config or fall back to cache config (USE_REDIS_STREAMS for stream-specific override)
const useRedis = config.useRedis ?? cacheConfig.USE_REDIS_STREAMS;
const redisClient = config.redisClient ?? ioredisClient;
const { redisSubscriber, inMemoryOptions } = config;
// Check if we should and can use Redis
if (useRedis && redisClient) {
try {
// For subscribing, we need a dedicated connection
// If subscriber not provided, duplicate the main client
let subscriber = redisSubscriber;
if (!subscriber && 'duplicate' in redisClient) {
subscriber = (redisClient as Redis).duplicate();
logger.info('[StreamServices] Duplicated Redis client for subscriber');
}
if (!subscriber) {
logger.warn('[StreamServices] No subscriber client available, falling back to in-memory');
return createInMemoryServices(inMemoryOptions);
}
const jobStore = new RedisJobStore(redisClient);
const eventTransport = new RedisEventTransport(redisClient, subscriber);
logger.info('[StreamServices] Created Redis-backed stream services');
return {
jobStore,
eventTransport,
isRedis: true,
};
} catch (err) {
logger.error(
'[StreamServices] Failed to create Redis services, falling back to in-memory:',
err,
);
return createInMemoryServices(inMemoryOptions);
}
}
return createInMemoryServices(inMemoryOptions);
}
/**
* Create in-memory stream services
*/
function createInMemoryServices(options?: StreamServicesConfig['inMemoryOptions']): StreamServices {
const jobStore = new InMemoryJobStore({
ttlAfterComplete: options?.ttlAfterComplete ?? 300000, // 5 minutes
maxJobs: options?.maxJobs ?? 1000,
});
const eventTransport = new InMemoryEventTransport();
logger.info('[StreamServices] Created in-memory stream services');
return {
jobStore,
eventTransport,
isRedis: false,
};
}

View file

@@ -0,0 +1,137 @@
import { EventEmitter } from 'events';
import { logger } from '@librechat/data-schemas';
import type { IEventTransport } from '../interfaces/IJobStore';
interface StreamState {
emitter: EventEmitter;
allSubscribersLeftCallback?: () => void;
}
/**
* In-memory event transport using Node.js EventEmitter.
* For horizontal scaling, replace with RedisEventTransport.
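 *
 * A minimal usage sketch (the payload shapes are illustrative; any serializable event works):
 *
 * @example
 * ```ts
 * const transport = new InMemoryEventTransport();
 * const { unsubscribe } = transport.subscribe('stream-1', {
 *   onChunk: (event) => console.log('chunk', event),
 *   onDone: () => console.log('done'),
 * });
 * transport.emitChunk('stream-1', { type: 'text', text: 'Hello' });
 * transport.emitDone('stream-1', { finished: true });
 * unsubscribe();
 * ```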
*/
export class InMemoryEventTransport implements IEventTransport {
private streams = new Map<string, StreamState>();
private getOrCreateStream(streamId: string): StreamState {
let state = this.streams.get(streamId);
if (!state) {
const emitter = new EventEmitter();
emitter.setMaxListeners(100);
state = { emitter };
this.streams.set(streamId, state);
}
return state;
}
subscribe(
streamId: string,
handlers: {
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void } {
const state = this.getOrCreateStream(streamId);
const chunkHandler = (event: unknown) => handlers.onChunk(event);
const doneHandler = (event: unknown) => handlers.onDone?.(event);
const errorHandler = (error: string) => handlers.onError?.(error);
state.emitter.on('chunk', chunkHandler);
state.emitter.on('done', doneHandler);
state.emitter.on('error', errorHandler);
logger.debug(
`[InMemoryEventTransport] subscribe ${streamId}: listeners=${state.emitter.listenerCount('chunk')}`,
);
return {
unsubscribe: () => {
const currentState = this.streams.get(streamId);
if (currentState) {
currentState.emitter.off('chunk', chunkHandler);
currentState.emitter.off('done', doneHandler);
currentState.emitter.off('error', errorHandler);
// Check if all subscribers left - cleanup and notify
if (currentState.emitter.listenerCount('chunk') === 0) {
currentState.allSubscribersLeftCallback?.();
// Auto-cleanup the stream entry when no subscribers remain
currentState.emitter.removeAllListeners();
this.streams.delete(streamId);
}
}
},
};
}
emitChunk(streamId: string, event: unknown): void {
const state = this.streams.get(streamId);
state?.emitter.emit('chunk', event);
}
emitDone(streamId: string, event: unknown): void {
const state = this.streams.get(streamId);
state?.emitter.emit('done', event);
}
emitError(streamId: string, error: string): void {
const state = this.streams.get(streamId);
state?.emitter.emit('error', error);
}
getSubscriberCount(streamId: string): number {
const state = this.streams.get(streamId);
return state?.emitter.listenerCount('chunk') ?? 0;
}
onAllSubscribersLeft(streamId: string, callback: () => void): void {
const state = this.getOrCreateStream(streamId);
state.allSubscribersLeftCallback = callback;
}
/**
* Check if this is the first subscriber (for ready signaling)
*/
isFirstSubscriber(streamId: string): boolean {
const state = this.streams.get(streamId);
const count = state?.emitter.listenerCount('chunk') ?? 0;
logger.debug(`[InMemoryEventTransport] isFirstSubscriber ${streamId}: count=${count}`);
return count === 1;
}
/**
* Cleanup a stream's event emitter
*/
cleanup(streamId: string): void {
const state = this.streams.get(streamId);
if (state) {
state.emitter.removeAllListeners();
this.streams.delete(streamId);
}
}
/**
* Get count of tracked streams (for monitoring)
*/
getStreamCount(): number {
return this.streams.size;
}
/**
* Get all tracked stream IDs (for orphan cleanup)
*/
getTrackedStreamIds(): string[] {
return Array.from(this.streams.keys());
}
destroy(): void {
for (const state of this.streams.values()) {
state.emitter.removeAllListeners();
}
this.streams.clear();
logger.debug('[InMemoryEventTransport] Destroyed');
}
}

View file

@@ -0,0 +1,303 @@
import { logger } from '@librechat/data-schemas';
import type { StandardGraph } from '@librechat/agents';
import type { Agents } from 'librechat-data-provider';
import type { IJobStore, SerializableJobData, JobStatus } from '~/stream/interfaces/IJobStore';
/**
* Content state for a job - volatile, in-memory only.
* Uses WeakRef to allow garbage collection of graph when no longer needed.
*/
interface ContentState {
contentParts: Agents.MessageContentComplex[];
graphRef: WeakRef<StandardGraph> | null;
}
/**
* In-memory implementation of IJobStore.
* Suitable for single-instance deployments.
* For horizontal scaling, use RedisJobStore.
*
* Content state is tied to jobs:
* - Uses WeakRef to graph for live access to contentParts and contentData (run steps)
* - No chunk persistence needed - same instance handles generation and reconnects
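 *
 * A minimal lifecycle sketch (method names come from this class; values are illustrative):
 *
 * @example
 * ```ts
 * const store = new InMemoryJobStore({ ttlAfterComplete: 60000 });
 * await store.initialize();
 * await store.createJob('stream-1', 'user-1', 'convo-1');
 * store.setContentParts('stream-1', [{ type: 'text', text: 'Hello' }]);
 * const parts = await store.getContentParts('stream-1'); // { content: [{ type: 'text', ... }] }
 * await store.updateJob('stream-1', { status: 'complete', completedAt: Date.now() });
 * await store.destroy();
 * ```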
*/
export class InMemoryJobStore implements IJobStore {
private jobs = new Map<string, SerializableJobData>();
private contentState = new Map<string, ContentState>();
private cleanupInterval: NodeJS.Timeout | null = null;
/** Maps userId -> Set of streamIds (conversationIds) for active jobs */
private userJobMap = new Map<string, Set<string>>();
/** Time to keep completed jobs before cleanup (0 = immediate) */
private ttlAfterComplete = 0;
/** Maximum number of concurrent jobs */
private maxJobs = 1000;
constructor(options?: { ttlAfterComplete?: number; maxJobs?: number }) {
if (options?.ttlAfterComplete) {
this.ttlAfterComplete = options.ttlAfterComplete;
}
if (options?.maxJobs) {
this.maxJobs = options.maxJobs;
}
}
async initialize(): Promise<void> {
if (this.cleanupInterval) {
return;
}
this.cleanupInterval = setInterval(() => {
this.cleanup();
}, 60000);
if (this.cleanupInterval.unref) {
this.cleanupInterval.unref();
}
logger.debug('[InMemoryJobStore] Initialized with cleanup interval');
}
async createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<SerializableJobData> {
if (this.jobs.size >= this.maxJobs) {
await this.evictOldest();
}
const job: SerializableJobData = {
streamId,
userId,
status: 'running',
createdAt: Date.now(),
conversationId,
syncSent: false,
};
this.jobs.set(streamId, job);
// Track job by userId for efficient user-scoped queries
let userJobs = this.userJobMap.get(userId);
if (!userJobs) {
userJobs = new Set();
this.userJobMap.set(userId, userJobs);
}
userJobs.add(streamId);
logger.debug(`[InMemoryJobStore] Created job: ${streamId}`);
return job;
}
async getJob(streamId: string): Promise<SerializableJobData | null> {
return this.jobs.get(streamId) ?? null;
}
async updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void> {
const job = this.jobs.get(streamId);
if (!job) {
return;
}
Object.assign(job, updates);
}
async deleteJob(streamId: string): Promise<void> {
this.jobs.delete(streamId);
this.contentState.delete(streamId);
logger.debug(`[InMemoryJobStore] Deleted job: ${streamId}`);
}
async hasJob(streamId: string): Promise<boolean> {
return this.jobs.has(streamId);
}
async getRunningJobs(): Promise<SerializableJobData[]> {
const running: SerializableJobData[] = [];
for (const job of this.jobs.values()) {
if (job.status === 'running') {
running.push(job);
}
}
return running;
}
async cleanup(): Promise<number> {
const now = Date.now();
const toDelete: string[] = [];
for (const [streamId, job] of this.jobs) {
const isFinished = ['complete', 'error', 'aborted'].includes(job.status);
if (isFinished && job.completedAt) {
// TTL of 0 means immediate cleanup, otherwise wait for TTL to expire
if (this.ttlAfterComplete === 0 || now - job.completedAt > this.ttlAfterComplete) {
toDelete.push(streamId);
}
}
}
for (const id of toDelete) {
await this.deleteJob(id);
}
if (toDelete.length > 0) {
logger.debug(`[InMemoryJobStore] Cleaned up ${toDelete.length} expired jobs`);
}
return toDelete.length;
}
private async evictOldest(): Promise<void> {
let oldestId: string | null = null;
let oldestTime = Infinity;
for (const [streamId, job] of this.jobs) {
if (job.createdAt < oldestTime) {
oldestTime = job.createdAt;
oldestId = streamId;
}
}
if (oldestId) {
logger.warn(`[InMemoryJobStore] Evicting oldest job: ${oldestId}`);
await this.deleteJob(oldestId);
}
}
/** Get job count (for monitoring) */
async getJobCount(): Promise<number> {
return this.jobs.size;
}
/** Get job count by status (for monitoring) */
async getJobCountByStatus(status: JobStatus): Promise<number> {
let count = 0;
for (const job of this.jobs.values()) {
if (job.status === status) {
count++;
}
}
return count;
}
async destroy(): Promise<void> {
if (this.cleanupInterval) {
clearInterval(this.cleanupInterval);
this.cleanupInterval = null;
}
this.jobs.clear();
this.contentState.clear();
this.userJobMap.clear();
logger.debug('[InMemoryJobStore] Destroyed');
}
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Also performs self-healing cleanup: removes stale entries for jobs that no longer exist.
*/
async getActiveJobIdsByUser(userId: string): Promise<string[]> {
const trackedIds = this.userJobMap.get(userId);
if (!trackedIds || trackedIds.size === 0) {
return [];
}
const activeIds: string[] = [];
for (const streamId of trackedIds) {
const job = this.jobs.get(streamId);
// Only include if job exists AND is still running
if (job && job.status === 'running') {
activeIds.push(streamId);
} else {
// Self-healing: job completed/deleted but mapping wasn't cleaned - fix it now
trackedIds.delete(streamId);
}
}
// Clean up empty set
if (trackedIds.size === 0) {
this.userJobMap.delete(userId);
}
return activeIds;
}
// ===== Content State Methods =====
/**
* Set the graph reference for a job.
* Uses WeakRef to allow garbage collection when graph is no longer needed.
*/
setGraph(streamId: string, graph: StandardGraph): void {
const existing = this.contentState.get(streamId);
if (existing) {
existing.graphRef = new WeakRef(graph);
} else {
this.contentState.set(streamId, {
contentParts: [],
graphRef: new WeakRef(graph),
});
}
}
/**
* Set content parts reference for a job.
*/
setContentParts(streamId: string, contentParts: Agents.MessageContentComplex[]): void {
const existing = this.contentState.get(streamId);
if (existing) {
existing.contentParts = contentParts;
} else {
this.contentState.set(streamId, { contentParts, graphRef: null });
}
}
/**
* Get content parts for a job.
* Returns live content from stored reference.
*/
async getContentParts(streamId: string): Promise<{
content: Agents.MessageContentComplex[];
} | null> {
const state = this.contentState.get(streamId);
if (!state?.contentParts) {
return null;
}
return {
content: state.contentParts,
};
}
/**
* Get run steps for a job from graph.contentData.
* Uses WeakRef - may return empty if graph has been GC'd.
*/
async getRunSteps(streamId: string): Promise<Agents.RunStep[]> {
const state = this.contentState.get(streamId);
if (!state?.graphRef) {
return [];
}
// Dereference WeakRef - may return undefined if GC'd
const graph = state.graphRef.deref();
return graph?.contentData ?? [];
}
/**
* No-op for in-memory - content available via graph reference.
*/
async appendChunk(): Promise<void> {
// No-op: content available via graph reference
}
/**
* Clear content state for a job.
*/
clearContentState(streamId: string): void {
this.contentState.delete(streamId);
}
}

View file

@ -0,0 +1,318 @@
import type { Redis, Cluster } from 'ioredis';
import { logger } from '@librechat/data-schemas';
import type { IEventTransport } from '~/stream/interfaces/IJobStore';
/**
* Redis key prefixes for pub/sub channels
*/
const CHANNELS = {
/** Main event channel: stream:{streamId}:events (hash tag for cluster compatibility) */
events: (streamId: string) => `stream:{${streamId}}:events`,
};
/**
* Event types for pub/sub messages
*/
const EventTypes = {
CHUNK: 'chunk',
DONE: 'done',
ERROR: 'error',
} as const;
interface PubSubMessage {
type: (typeof EventTypes)[keyof typeof EventTypes];
data?: unknown;
error?: string;
}
/**
* Subscriber state for a stream
*/
interface StreamSubscribers {
count: number;
handlers: Map<
string,
{
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
}
>;
allSubscribersLeftCallbacks: Array<() => void>;
}
/**
* Redis Pub/Sub implementation of IEventTransport.
* Enables real-time event delivery across multiple instances.
*
* Architecture (inspired by https://upstash.com/blog/resumable-llm-streams):
* - Publisher: Emits events to Redis channel when chunks arrive
* - Subscriber: Listens to Redis channel and forwards to SSE clients
* - Decoupled: Generator and consumer don't need direct connection
*
* Note: Requires TWO Redis connections - one for publishing, one for subscribing.
* This is a Redis limitation: a client in subscribe mode can't publish.
*
* @example
* ```ts
* const transport = new RedisEventTransport(publisherClient, subscriberClient);
* transport.subscribe(streamId, { onChunk: (e) => res.write(e) });
* transport.emitChunk(streamId, { text: 'Hello' });
* ```
*/
export class RedisEventTransport implements IEventTransport {
/** Redis client for publishing events */
private publisher: Redis | Cluster;
/** Redis client for subscribing to events (separate connection required) */
private subscriber: Redis | Cluster;
/** Track subscribers per stream */
private streams = new Map<string, StreamSubscribers>();
/** Track which channels we're subscribed to */
private subscribedChannels = new Set<string>();
/** Counter for generating unique subscriber IDs */
private subscriberIdCounter = 0;
/**
* Create a new Redis event transport.
*
* @param publisher - Redis client for publishing (can be shared)
* @param subscriber - Redis client for subscribing (must be dedicated)
*/
constructor(publisher: Redis | Cluster, subscriber: Redis | Cluster) {
this.publisher = publisher;
this.subscriber = subscriber;
// Set up message handler for all subscriptions
this.subscriber.on('message', (channel: string, message: string) => {
this.handleMessage(channel, message);
});
}
/**
* Handle incoming pub/sub message
*/
private handleMessage(channel: string, message: string): void {
// Extract streamId from channel name: stream:{streamId}:events
// Use regex to extract the hash tag content
const match = channel.match(/^stream:\{([^}]+)\}:events$/);
if (!match) {
return;
}
const streamId = match[1];
const streamState = this.streams.get(streamId);
if (!streamState) {
return;
}
try {
const parsed = JSON.parse(message) as PubSubMessage;
for (const [, handlers] of streamState.handlers) {
switch (parsed.type) {
case EventTypes.CHUNK:
handlers.onChunk(parsed.data);
break;
case EventTypes.DONE:
handlers.onDone?.(parsed.data);
break;
case EventTypes.ERROR:
handlers.onError?.(parsed.error ?? 'Unknown error');
break;
}
}
} catch (err) {
logger.error(`[RedisEventTransport] Failed to parse message:`, err);
}
}
/**
* Subscribe to events for a stream.
*
* On first subscriber for a stream, subscribes to the Redis channel.
* Returns unsubscribe function that cleans up when last subscriber leaves.
*/
subscribe(
streamId: string,
handlers: {
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void } {
const channel = CHANNELS.events(streamId);
const subscriberId = `sub_${++this.subscriberIdCounter}`;
// Initialize stream state if needed
if (!this.streams.has(streamId)) {
this.streams.set(streamId, {
count: 0,
handlers: new Map(),
allSubscribersLeftCallbacks: [],
});
}
const streamState = this.streams.get(streamId)!;
streamState.count++;
streamState.handlers.set(subscriberId, handlers);
// Subscribe to Redis channel if this is first subscriber
if (!this.subscribedChannels.has(channel)) {
this.subscribedChannels.add(channel);
this.subscriber.subscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
});
}
// Return unsubscribe function
return {
unsubscribe: () => {
const state = this.streams.get(streamId);
if (!state) {
return;
}
state.handlers.delete(subscriberId);
state.count--;
// If last subscriber left, unsubscribe from Redis and notify
if (state.count === 0) {
this.subscriber.unsubscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to unsubscribe from ${channel}:`, err);
});
this.subscribedChannels.delete(channel);
// Call all-subscribers-left callbacks
for (const callback of state.allSubscribersLeftCallbacks) {
try {
callback();
} catch (err) {
logger.error(`[RedisEventTransport] Error in allSubscribersLeft callback:`, err);
}
}
this.streams.delete(streamId);
}
},
};
}
/**
* Publish a chunk event to all subscribers across all instances.
*/
emitChunk(streamId: string, event: unknown): void {
const channel = CHANNELS.events(streamId);
const message: PubSubMessage = { type: EventTypes.CHUNK, data: event };
this.publisher.publish(channel, JSON.stringify(message)).catch((err) => {
logger.error(`[RedisEventTransport] Failed to publish chunk:`, err);
});
}
/**
* Publish a done event to all subscribers.
*/
emitDone(streamId: string, event: unknown): void {
const channel = CHANNELS.events(streamId);
const message: PubSubMessage = { type: EventTypes.DONE, data: event };
this.publisher.publish(channel, JSON.stringify(message)).catch((err) => {
logger.error(`[RedisEventTransport] Failed to publish done:`, err);
});
}
/**
* Publish an error event to all subscribers.
*/
emitError(streamId: string, error: string): void {
const channel = CHANNELS.events(streamId);
const message: PubSubMessage = { type: EventTypes.ERROR, error };
this.publisher.publish(channel, JSON.stringify(message)).catch((err) => {
logger.error(`[RedisEventTransport] Failed to publish error:`, err);
});
}
/**
* Get subscriber count for a stream (local instance only).
*
* Note: In a multi-instance setup, this only returns local subscriber count.
* A global count would need to be tracked in Redis (e.g., with a counter key).
*/
getSubscriberCount(streamId: string): number {
return this.streams.get(streamId)?.count ?? 0;
}
/**
* Check if this is the first subscriber (local instance only).
*/
isFirstSubscriber(streamId: string): boolean {
return this.getSubscriberCount(streamId) === 1;
}
/**
* Register callback for when all subscribers leave.
*/
onAllSubscribersLeft(streamId: string, callback: () => void): void {
const state = this.streams.get(streamId);
if (state) {
state.allSubscribersLeftCallbacks.push(callback);
} else {
// Create state just for the callback
this.streams.set(streamId, {
count: 0,
handlers: new Map(),
allSubscribersLeftCallbacks: [callback],
});
}
}
/**
* Get all tracked stream IDs (for orphan cleanup)
*/
getTrackedStreamIds(): string[] {
return Array.from(this.streams.keys());
}
/**
* Cleanup resources for a specific stream.
*/
cleanup(streamId: string): void {
const channel = CHANNELS.events(streamId);
const state = this.streams.get(streamId);
if (state) {
// Clear all handlers
state.handlers.clear();
state.allSubscribersLeftCallbacks = [];
}
// Unsubscribe from Redis channel
if (this.subscribedChannels.has(channel)) {
this.subscriber.unsubscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to cleanup ${channel}:`, err);
});
this.subscribedChannels.delete(channel);
}
this.streams.delete(streamId);
}
/**
* Destroy all resources.
*/
destroy(): void {
// Unsubscribe from all channels
for (const channel of this.subscribedChannels) {
this.subscriber.unsubscribe(channel).catch(() => {
// Ignore errors during shutdown
});
}
this.subscribedChannels.clear();
this.streams.clear();
// Note: Don't close Redis connections - they may be shared
logger.info('[RedisEventTransport] Destroyed');
}
}

View file

@ -0,0 +1,835 @@
import { logger } from '@librechat/data-schemas';
import { createContentAggregator } from '@librechat/agents';
import type { IJobStore, SerializableJobData, JobStatus } from '~/stream/interfaces/IJobStore';
import type { StandardGraph } from '@librechat/agents';
import type { Agents } from 'librechat-data-provider';
import type { Redis, Cluster } from 'ioredis';
/**
* Key prefixes for Redis storage.
* All keys include the streamId for easy cleanup.
* Note: streamId === conversationId, so no separate mapping needed.
*
* IMPORTANT: Uses hash tags {streamId} for Redis Cluster compatibility.
* All keys for the same stream hash to the same slot, enabling:
* - Pipeline operations across related keys
* - Atomic multi-key operations
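*
* For example, streamId "abc" yields the keys `stream:{abc}:job`, `stream:{abc}:chunks`,
* and `stream:{abc}:runsteps`, which all hash to the same cluster slot.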
*/
const KEYS = {
/** Job metadata: stream:{streamId}:job */
job: (streamId: string) => `stream:{${streamId}}:job`,
/** Chunk stream (Redis Streams): stream:{streamId}:chunks */
chunks: (streamId: string) => `stream:{${streamId}}:chunks`,
/** Run steps: stream:{streamId}:runsteps */
runSteps: (streamId: string) => `stream:{${streamId}}:runsteps`,
/** Running jobs set for cleanup (global set - single slot) */
runningJobs: 'stream:running',
/** User's active jobs set: stream:user:{userId}:jobs */
userJobs: (userId: string) => `stream:user:{${userId}}:jobs`,
};
/**
* Default TTL values in seconds.
* Can be overridden via constructor options.
*/
const DEFAULT_TTL = {
/** TTL for completed jobs (5 minutes) */
completed: 300,
/** TTL for running jobs/chunks (20 minutes - failsafe for crashed jobs) */
running: 1200,
/** TTL for chunks after completion (0 = delete immediately) */
chunksAfterComplete: 0,
/** TTL for run steps after completion (0 = delete immediately) */
runStepsAfterComplete: 0,
};
/**
* Redis implementation of IJobStore.
* Enables horizontal scaling with multi-instance deployments.
*
* Storage strategy:
* - Job metadata: Redis Hash (fast field access)
* - Chunks: Redis Streams (append-only, efficient for streaming)
* - Run steps: Redis String (JSON serialized)
*
* Note: streamId === conversationId, so getJob(conversationId) works directly.
*
* @example
* ```ts
* import { ioredisClient } from '~/cache';
* const store = new RedisJobStore(ioredisClient);
* await store.initialize();
* ```
*/
/**
* Configuration options for RedisJobStore
*/
export interface RedisJobStoreOptions {
/** TTL for completed jobs in seconds (default: 300 = 5 minutes) */
completedTtl?: number;
/** TTL for running jobs/chunks in seconds (default: 1200 = 20 minutes) */
runningTtl?: number;
/** TTL for chunks after completion in seconds (default: 0 = delete immediately) */
chunksAfterCompleteTtl?: number;
/** TTL for run steps after completion in seconds (default: 0 = delete immediately) */
runStepsAfterCompleteTtl?: number;
}
export class RedisJobStore implements IJobStore {
private redis: Redis | Cluster;
private cleanupInterval: NodeJS.Timeout | null = null;
private ttl: typeof DEFAULT_TTL;
/** Whether Redis client is in cluster mode (affects pipeline usage) */
private isCluster: boolean;
/**
* Local cache for graph references on THIS instance.
* Enables fast reconnects when client returns to the same server.
* Uses WeakRef to allow garbage collection when graph is no longer needed.
*/
private localGraphCache = new Map<string, WeakRef<StandardGraph>>();
/** Cleanup interval in ms (1 minute) */
private cleanupIntervalMs = 60000;
constructor(redis: Redis | Cluster, options?: RedisJobStoreOptions) {
this.redis = redis;
this.ttl = {
completed: options?.completedTtl ?? DEFAULT_TTL.completed,
running: options?.runningTtl ?? DEFAULT_TTL.running,
chunksAfterComplete: options?.chunksAfterCompleteTtl ?? DEFAULT_TTL.chunksAfterComplete,
runStepsAfterComplete: options?.runStepsAfterCompleteTtl ?? DEFAULT_TTL.runStepsAfterComplete,
};
// Detect cluster mode using ioredis's isCluster property
this.isCluster = (redis as Cluster).isCluster === true;
}
async initialize(): Promise<void> {
if (this.cleanupInterval) {
return;
}
// Start periodic cleanup
this.cleanupInterval = setInterval(() => {
this.cleanup().catch((err) => {
logger.error('[RedisJobStore] Cleanup error:', err);
});
}, this.cleanupIntervalMs);
if (this.cleanupInterval.unref) {
this.cleanupInterval.unref();
}
logger.info('[RedisJobStore] Initialized with cleanup interval');
}
async createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<SerializableJobData> {
const job: SerializableJobData = {
streamId,
userId,
status: 'running',
createdAt: Date.now(),
conversationId,
syncSent: false,
};
const key = KEYS.job(streamId);
const userJobsKey = KEYS.userJobs(userId);
// For cluster mode, we can't pipeline keys on different slots
// The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots
if (this.isCluster) {
await this.redis.hmset(key, this.serializeJob(job));
await this.redis.expire(key, this.ttl.running);
await this.redis.sadd(KEYS.runningJobs, streamId);
await this.redis.sadd(userJobsKey, streamId);
} else {
const pipeline = this.redis.pipeline();
pipeline.hmset(key, this.serializeJob(job));
pipeline.expire(key, this.ttl.running);
pipeline.sadd(KEYS.runningJobs, streamId);
pipeline.sadd(userJobsKey, streamId);
await pipeline.exec();
}
logger.debug(`[RedisJobStore] Created job: ${streamId}`);
return job;
}
async getJob(streamId: string): Promise<SerializableJobData | null> {
const data = await this.redis.hgetall(KEYS.job(streamId));
if (!data || Object.keys(data).length === 0) {
return null;
}
return this.deserializeJob(data);
}
async updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void> {
const key = KEYS.job(streamId);
const exists = await this.redis.exists(key);
if (!exists) {
return;
}
const serialized = this.serializeJob(updates as SerializableJobData);
if (Object.keys(serialized).length === 0) {
return;
}
await this.redis.hmset(key, serialized);
// If status changed to complete/error/aborted, update TTL and remove from running set
// Note: userJobs cleanup is handled lazily via self-healing in getActiveJobIdsByUser
if (updates.status && ['complete', 'error', 'aborted'].includes(updates.status)) {
// In cluster mode, separate runningJobs (global) from stream-specific keys
if (this.isCluster) {
await this.redis.expire(key, this.ttl.completed);
await this.redis.srem(KEYS.runningJobs, streamId);
if (this.ttl.chunksAfterComplete === 0) {
await this.redis.del(KEYS.chunks(streamId));
} else {
await this.redis.expire(KEYS.chunks(streamId), this.ttl.chunksAfterComplete);
}
if (this.ttl.runStepsAfterComplete === 0) {
await this.redis.del(KEYS.runSteps(streamId));
} else {
await this.redis.expire(KEYS.runSteps(streamId), this.ttl.runStepsAfterComplete);
}
} else {
const pipeline = this.redis.pipeline();
pipeline.expire(key, this.ttl.completed);
pipeline.srem(KEYS.runningJobs, streamId);
if (this.ttl.chunksAfterComplete === 0) {
pipeline.del(KEYS.chunks(streamId));
} else {
pipeline.expire(KEYS.chunks(streamId), this.ttl.chunksAfterComplete);
}
if (this.ttl.runStepsAfterComplete === 0) {
pipeline.del(KEYS.runSteps(streamId));
} else {
pipeline.expire(KEYS.runSteps(streamId), this.ttl.runStepsAfterComplete);
}
await pipeline.exec();
}
}
}
async deleteJob(streamId: string): Promise<void> {
// Clear local caches
this.localGraphCache.delete(streamId);
// Note: userJobs cleanup is handled lazily via self-healing in getActiveJobIdsByUser
// In cluster mode, separate runningJobs (global) from stream-specific keys (same slot)
if (this.isCluster) {
// Stream-specific keys all hash to same slot due to {streamId}
const pipeline = this.redis.pipeline();
pipeline.del(KEYS.job(streamId));
pipeline.del(KEYS.chunks(streamId));
pipeline.del(KEYS.runSteps(streamId));
await pipeline.exec();
// Global set is on different slot - execute separately
await this.redis.srem(KEYS.runningJobs, streamId);
} else {
const pipeline = this.redis.pipeline();
pipeline.del(KEYS.job(streamId));
pipeline.del(KEYS.chunks(streamId));
pipeline.del(KEYS.runSteps(streamId));
pipeline.srem(KEYS.runningJobs, streamId);
await pipeline.exec();
}
logger.debug(`[RedisJobStore] Deleted job: ${streamId}`);
}
async hasJob(streamId: string): Promise<boolean> {
const exists = await this.redis.exists(KEYS.job(streamId));
return exists === 1;
}
async getRunningJobs(): Promise<SerializableJobData[]> {
const streamIds = await this.redis.smembers(KEYS.runningJobs);
if (streamIds.length === 0) {
return [];
}
const jobs: SerializableJobData[] = [];
for (const streamId of streamIds) {
const job = await this.getJob(streamId);
if (job && job.status === 'running') {
jobs.push(job);
}
}
return jobs;
}
async cleanup(): Promise<number> {
const now = Date.now();
const streamIds = await this.redis.smembers(KEYS.runningJobs);
let cleaned = 0;
// Clean up stale local graph cache entries (WeakRefs that were collected)
for (const [streamId, graphRef] of this.localGraphCache) {
if (!graphRef.deref()) {
this.localGraphCache.delete(streamId);
}
}
for (const streamId of streamIds) {
const job = await this.getJob(streamId);
// Job no longer exists (TTL expired) - remove from set
if (!job) {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
cleaned++;
continue;
}
// Job completed but still in running set (shouldn't happen, but handle it)
if (job.status !== 'running') {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
cleaned++;
continue;
}
// Stale running job (failsafe - running for > configured TTL)
if (now - job.createdAt > this.ttl.running * 1000) {
logger.warn(`[RedisJobStore] Cleaning up stale job: ${streamId}`);
await this.deleteJob(streamId);
cleaned++;
}
}
if (cleaned > 0) {
logger.debug(`[RedisJobStore] Cleaned up ${cleaned} jobs`);
}
return cleaned;
}
async getJobCount(): Promise<number> {
// Approximate: only counts jobs currently in the running set.
// An exact total would require scanning every per-stream job key.
const runningCount = await this.redis.scard(KEYS.runningJobs);
return runningCount;
}
async getJobCountByStatus(status: JobStatus): Promise<number> {
if (status === 'running') {
return this.redis.scard(KEYS.runningJobs);
}
// For other statuses, we'd need to scan - return 0 for now
// In production, consider maintaining separate sets per status if needed
return 0;
}
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Also performs self-healing cleanup: removes stale entries for jobs that no longer exist.
*
* @param userId - The user ID to query
* @returns Array of conversation IDs with active jobs
*/
async getActiveJobIdsByUser(userId: string): Promise<string[]> {
const userJobsKey = KEYS.userJobs(userId);
const trackedIds = await this.redis.smembers(userJobsKey);
if (trackedIds.length === 0) {
return [];
}
const activeIds: string[] = [];
const staleIds: string[] = [];
for (const streamId of trackedIds) {
const job = await this.getJob(streamId);
// Only include if job exists AND is still running
if (job && job.status === 'running') {
activeIds.push(streamId);
} else {
// Self-healing: job completed/deleted but mapping wasn't cleaned - mark for removal
staleIds.push(streamId);
}
}
// Clean up stale entries
if (staleIds.length > 0) {
await this.redis.srem(userJobsKey, ...staleIds);
logger.debug(
`[RedisJobStore] Self-healed ${staleIds.length} stale job entries for user ${userId}`,
);
}
return activeIds;
}
async destroy(): Promise<void> {
if (this.cleanupInterval) {
clearInterval(this.cleanupInterval);
this.cleanupInterval = null;
}
// Clear local caches
this.localGraphCache.clear();
// Don't close the Redis connection - it's shared
logger.info('[RedisJobStore] Destroyed');
}
// ===== Content State Methods =====
// For Redis, content is primarily reconstructed from chunks.
// However, we keep a LOCAL graph cache for fast same-instance reconnects.
/**
* Store graph reference in local cache.
* This enables fast reconnects when client returns to the same instance.
* Falls back to Redis chunk reconstruction for cross-instance reconnects.
*
* @param streamId - The stream identifier
* @param graph - The graph instance (stored as WeakRef)
*/
setGraph(streamId: string, graph: StandardGraph): void {
this.localGraphCache.set(streamId, new WeakRef(graph));
}
/**
* No-op for Redis - content parts are reconstructed from chunks.
* Metadata (agentId, groupId) is embedded directly on content parts by the agent runtime.
*/
setContentParts(_streamId: string, _contentParts: Agents.MessageContentComplex[]): void {
// Content parts are reconstructed from chunks during getContentParts
// No separate storage needed
}
/**
* Get aggregated content - tries local cache first, falls back to Redis reconstruction.
*
* Optimization: If this instance has the live graph (same-instance reconnect),
* we return the content directly without Redis round-trip.
* For cross-instance reconnects, we reconstruct from Redis Streams.
*
* @param streamId - The stream identifier
* @returns Content parts array or null if not found
*/
async getContentParts(streamId: string): Promise<{
content: Agents.MessageContentComplex[];
} | null> {
// 1. Try local graph cache first (fast path for same-instance reconnect)
const graphRef = this.localGraphCache.get(streamId);
if (graphRef) {
const graph = graphRef.deref();
if (graph) {
const localParts = graph.getContentParts();
if (localParts && localParts.length > 0) {
return {
content: localParts,
};
}
} else {
// WeakRef was collected, remove from cache
this.localGraphCache.delete(streamId);
}
}
// 2. Fall back to Redis chunk reconstruction (cross-instance reconnect)
const chunks = await this.getChunks(streamId);
if (chunks.length === 0) {
return null;
}
// Use the same content aggregator as live streaming
const { contentParts, aggregateContent } = createContentAggregator();
// Valid event types for content aggregation
const validEvents = new Set([
'on_run_step',
'on_message_delta',
'on_reasoning_delta',
'on_run_step_delta',
'on_run_step_completed',
'on_agent_update',
]);
for (const chunk of chunks) {
const event = chunk as { event?: string; data?: unknown };
if (!event.event || !event.data || !validEvents.has(event.event)) {
continue;
}
// Pass event string directly - GraphEvents values are lowercase strings
// eslint-disable-next-line @typescript-eslint/no-explicit-any
aggregateContent({ event: event.event as any, data: event.data as any });
}
// Filter out undefined entries
const filtered: Agents.MessageContentComplex[] = [];
for (const part of contentParts) {
if (part !== undefined) {
filtered.push(part);
}
}
return {
content: filtered,
};
}
/**
* Get run steps - tries local cache first, falls back to Redis.
*
* Optimization: If this instance has the live graph, we get run steps
* directly without Redis round-trip.
*
* @param streamId - The stream identifier
* @returns Run steps array
*/
async getRunSteps(streamId: string): Promise<Agents.RunStep[]> {
// 1. Try local graph cache first (fast path for same-instance reconnect)
const graphRef = this.localGraphCache.get(streamId);
if (graphRef) {
const graph = graphRef.deref();
if (graph) {
const localSteps = graph.getRunSteps();
if (localSteps && localSteps.length > 0) {
return localSteps;
}
}
// Note: don't delete from the cache here - the graph may still be valid,
// just without any run steps yet
}
// 2. Fall back to Redis (cross-instance reconnect)
const key = KEYS.runSteps(streamId);
const data = await this.redis.get(key);
if (!data) {
return [];
}
try {
return JSON.parse(data);
} catch {
return [];
}
}
/**
* Clear content state for a job.
* Removes both local cache and Redis data.
*/
clearContentState(streamId: string): void {
// Clear local caches immediately
this.localGraphCache.delete(streamId);
// Fire and forget - async cleanup for Redis
this.clearContentStateAsync(streamId).catch((err) => {
logger.error(`[RedisJobStore] Failed to clear content state for ${streamId}:`, err);
});
}
/**
* Clear content state async.
*/
private async clearContentStateAsync(streamId: string): Promise<void> {
const pipeline = this.redis.pipeline();
pipeline.del(KEYS.chunks(streamId));
pipeline.del(KEYS.runSteps(streamId));
await pipeline.exec();
}
/**
* Append a streaming chunk to Redis Stream.
* Uses XADD for efficient append-only storage.
* Sets TTL on first chunk to ensure cleanup if job crashes.
*/
async appendChunk(streamId: string, event: unknown): Promise<void> {
const key = KEYS.chunks(streamId);
const added = await this.redis.xadd(key, '*', 'event', JSON.stringify(event));
// Set TTL on first chunk (when stream is created)
// Subsequent chunks inherit the stream's TTL
if (added) {
const len = await this.redis.xlen(key);
if (len === 1) {
await this.redis.expire(key, this.ttl.running);
}
}
}
/**
* Get all chunks from Redis Stream.
*/
private async getChunks(streamId: string): Promise<unknown[]> {
const key = KEYS.chunks(streamId);
const entries = await this.redis.xrange(key, '-', '+');
return entries
.map(([, fields]) => {
const eventIdx = fields.indexOf('event');
if (eventIdx >= 0 && eventIdx + 1 < fields.length) {
try {
return JSON.parse(fields[eventIdx + 1]);
} catch {
return null;
}
}
return null;
})
.filter(Boolean);
}
/**
* Save run steps for resume state.
*/
async saveRunSteps(streamId: string, runSteps: Agents.RunStep[]): Promise<void> {
const key = KEYS.runSteps(streamId);
await this.redis.set(key, JSON.stringify(runSteps), 'EX', this.ttl.running);
}
// ===== Consumer Group Methods =====
// These enable tracking which chunks each client has seen.
// Based on https://upstash.com/blog/resumable-llm-streams
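//
// Illustrative flow (sketch only; `store`, `sessionId`, and `res` are assumed names):
//   await store.createConsumerGroup(streamId, sessionId, '0');
//   const chunks = await store.readChunksFromGroup(streamId, sessionId);
//   chunks.forEach(({ event }) => res.write(`data: ${JSON.stringify(event)}\n\n`));
//   await store.acknowledgeChunks(streamId, sessionId, chunks.map((c) => c.id));
//   // ...and on final disconnect:
//   await store.deleteConsumerGroup(streamId, sessionId);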
/**
* Create a consumer group for a stream.
* Used to track which chunks a client has already received.
*
* @param streamId - The stream identifier
* @param groupName - Unique name for the consumer group (e.g., session ID)
* @param startFrom - Where to start reading ('0' = from beginning, '$' = only new)
*/
async createConsumerGroup(
streamId: string,
groupName: string,
startFrom: '0' | '$' = '0',
): Promise<void> {
const key = KEYS.chunks(streamId);
try {
await this.redis.xgroup('CREATE', key, groupName, startFrom, 'MKSTREAM');
logger.debug(`[RedisJobStore] Created consumer group ${groupName} for ${streamId}`);
} catch (err) {
// BUSYGROUP error means group already exists - that's fine
const error = err as Error;
if (!error.message?.includes('BUSYGROUP')) {
throw err;
}
}
}
/**
* Read chunks from a consumer group (only unseen chunks).
* This is the key to the resumable stream pattern.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name
* @param consumerName - Name of the consumer within the group
* @param count - Maximum number of chunks to read (default: all available)
* @returns Array of { id, event } where id is the Redis stream entry ID
*/
async readChunksFromGroup(
streamId: string,
groupName: string,
consumerName: string = 'consumer-1',
count?: number,
): Promise<Array<{ id: string; event: unknown }>> {
const key = KEYS.chunks(streamId);
try {
// XREADGROUP GROUP groupName consumerName [COUNT count] STREAMS key >
// The '>' means only read new messages not yet delivered to this consumer
let result;
if (count) {
result = await this.redis.xreadgroup(
'GROUP',
groupName,
consumerName,
'COUNT',
count,
'STREAMS',
key,
'>',
);
} else {
result = await this.redis.xreadgroup('GROUP', groupName, consumerName, 'STREAMS', key, '>');
}
if (!result || result.length === 0) {
return [];
}
// Result format: [[streamKey, [[id, [field, value, ...]], ...]]]
const [, messages] = result[0] as [string, Array<[string, string[]]>];
const chunks: Array<{ id: string; event: unknown }> = [];
for (const [id, fields] of messages) {
const eventIdx = fields.indexOf('event');
if (eventIdx >= 0 && eventIdx + 1 < fields.length) {
try {
chunks.push({
id,
event: JSON.parse(fields[eventIdx + 1]),
});
} catch {
// Skip malformed entries
}
}
}
return chunks;
} catch (err) {
const error = err as Error;
// NOGROUP error means the group doesn't exist yet
if (error.message?.includes('NOGROUP')) {
return [];
}
throw err;
}
}
/**
* Acknowledge that chunks have been processed.
* This tells Redis we've successfully delivered these chunks to the client.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name
* @param messageIds - Array of Redis stream entry IDs to acknowledge
*/
async acknowledgeChunks(
streamId: string,
groupName: string,
messageIds: string[],
): Promise<void> {
if (messageIds.length === 0) {
return;
}
const key = KEYS.chunks(streamId);
await this.redis.xack(key, groupName, ...messageIds);
}
/**
* Delete a consumer group.
* Called when a client disconnects and won't reconnect.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name to delete
*/
async deleteConsumerGroup(streamId: string, groupName: string): Promise<void> {
const key = KEYS.chunks(streamId);
try {
await this.redis.xgroup('DESTROY', key, groupName);
logger.debug(`[RedisJobStore] Deleted consumer group ${groupName} for ${streamId}`);
} catch {
// Ignore errors - group may not exist
}
}
/**
* Get pending chunks for a consumer (chunks delivered but not acknowledged).
* Useful for recovering from crashes.
*
* @param streamId - The stream identifier
* @param groupName - Consumer group name
* @param consumerName - Consumer name
*/
async getPendingChunks(
streamId: string,
groupName: string,
consumerName: string = 'consumer-1',
): Promise<Array<{ id: string; event: unknown }>> {
const key = KEYS.chunks(streamId);
try {
// Read pending messages (delivered but not acked) by using '0' instead of '>'
const result = await this.redis.xreadgroup(
'GROUP',
groupName,
consumerName,
'STREAMS',
key,
'0',
);
if (!result || result.length === 0) {
return [];
}
const [, messages] = result[0] as [string, Array<[string, string[]]>];
const chunks: Array<{ id: string; event: unknown }> = [];
for (const [id, fields] of messages) {
const eventIdx = fields.indexOf('event');
if (eventIdx >= 0 && eventIdx + 1 < fields.length) {
try {
chunks.push({
id,
event: JSON.parse(fields[eventIdx + 1]),
});
} catch {
// Skip malformed entries
}
}
}
return chunks;
} catch {
return [];
}
}
/**
* Serialize job data for Redis hash storage.
* Converts complex types to strings.
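*
* For example, `{ syncSent: true, userMessage: { messageId: 'm1' } }` becomes
* `{ syncSent: '1', userMessage: '{"messageId":"m1"}' }`.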
*/
private serializeJob(job: Partial<SerializableJobData>): Record<string, string> {
const result: Record<string, string> = {};
for (const [key, value] of Object.entries(job)) {
if (value === undefined) {
continue;
}
if (typeof value === 'object') {
result[key] = JSON.stringify(value);
} else if (typeof value === 'boolean') {
result[key] = value ? '1' : '0';
} else {
result[key] = String(value);
}
}
return result;
}
/**
* Deserialize job data from Redis hash.
*/
private deserializeJob(data: Record<string, string>): SerializableJobData {
return {
streamId: data.streamId,
userId: data.userId,
status: data.status as JobStatus,
createdAt: parseInt(data.createdAt, 10),
completedAt: data.completedAt ? parseInt(data.completedAt, 10) : undefined,
conversationId: data.conversationId || undefined,
error: data.error || undefined,
userMessage: data.userMessage ? JSON.parse(data.userMessage) : undefined,
responseMessageId: data.responseMessageId || undefined,
sender: data.sender || undefined,
syncSent: data.syncSent === '1',
finalEvent: data.finalEvent || undefined,
endpoint: data.endpoint || undefined,
iconURL: data.iconURL || undefined,
model: data.model || undefined,
promptTokens: data.promptTokens ? parseInt(data.promptTokens, 10) : undefined,
};
}
}

View file

@ -0,0 +1,4 @@
export * from './InMemoryJobStore';
export * from './InMemoryEventTransport';
export * from './RedisJobStore';
export * from './RedisEventTransport';

View file

@ -0,0 +1,22 @@
export {
GenerationJobManager,
GenerationJobManagerClass,
type GenerationJobManagerOptions,
} from './GenerationJobManager';
export type {
AbortResult,
SerializableJobData,
JobStatus,
IJobStore,
IEventTransport,
} from './interfaces/IJobStore';
export { createStreamServices } from './createStreamServices';
export type { StreamServicesConfig, StreamServices } from './createStreamServices';
// Implementations (for advanced use cases)
export { InMemoryJobStore } from './implementations/InMemoryJobStore';
export { InMemoryEventTransport } from './implementations/InMemoryEventTransport';
export { RedisJobStore } from './implementations/RedisJobStore';
export { RedisEventTransport } from './implementations/RedisEventTransport';

View file

@ -0,0 +1,256 @@
import type { Agents } from 'librechat-data-provider';
import type { StandardGraph } from '@librechat/agents';
/**
* Job status enum
*/
export type JobStatus = 'running' | 'complete' | 'error' | 'aborted';
/**
* Serializable job data - no object references, suitable for Redis/external storage
*/
export interface SerializableJobData {
streamId: string;
userId: string;
status: JobStatus;
createdAt: number;
completedAt?: number;
conversationId?: string;
error?: string;
/** User message metadata */
userMessage?: {
messageId: string;
parentMessageId?: string;
conversationId?: string;
text?: string;
};
/** Response message ID for reconnection */
responseMessageId?: string;
/** Sender name for UI display */
sender?: string;
/** Whether sync has been sent to a client */
syncSent: boolean;
/** Serialized final event for replay */
finalEvent?: string;
/** Endpoint metadata for abort handling - avoids storing functions */
endpoint?: string;
iconURL?: string;
model?: string;
promptTokens?: number;
}
/**
* Result returned from aborting a job - contains all data needed
* for token spending and message saving without storing callbacks
*/
export interface AbortResult {
/** Whether the abort was successful */
success: boolean;
/** The job data at time of abort */
jobData: SerializableJobData | null;
/** Aggregated content from the stream */
content: Agents.MessageContentComplex[];
/** Final event to send to client */
finalEvent: unknown;
}
/**
* Resume state for reconnecting clients
*/
export interface ResumeState {
runSteps: Agents.RunStep[];
aggregatedContent: Agents.MessageContentComplex[];
userMessage?: SerializableJobData['userMessage'];
responseMessageId?: string;
conversationId?: string;
sender?: string;
}
/**
* Interface for job storage backend.
* Implementations can use in-memory Map, Redis, KV store, etc.
*
* Content state is tied to jobs:
* - In-memory: Holds WeakRef to graph for live content/run steps access
* - Redis: Persists chunks, reconstructs content on reconnect
*
* This consolidates job metadata + content state into a single interface.
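*
* @example
* ```ts
* // Minimal sketch with any IJobStore implementation (`store` and the identifiers are assumed)
* const job = await store.createJob(streamId, userId, conversationId);
* await store.appendChunk(streamId, { event: 'on_message_delta', data: delta });
* await store.updateJob(streamId, { status: 'complete', completedAt: Date.now() });
* ```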
*/
export interface IJobStore {
/** Initialize the store (e.g., connect to Redis, start cleanup intervals) */
initialize(): Promise<void>;
/** Create a new job */
createJob(
streamId: string,
userId: string,
conversationId?: string,
): Promise<SerializableJobData>;
/** Get a job by streamId (streamId === conversationId) */
getJob(streamId: string): Promise<SerializableJobData | null>;
/** Update job data */
updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void>;
/** Delete a job */
deleteJob(streamId: string): Promise<void>;
/** Check if job exists */
hasJob(streamId: string): Promise<boolean>;
/** Get all running jobs (for cleanup) */
getRunningJobs(): Promise<SerializableJobData[]>;
/** Cleanup expired jobs */
cleanup(): Promise<number>;
/** Get total job count */
getJobCount(): Promise<number>;
/** Get job count by status */
getJobCountByStatus(status: JobStatus): Promise<number>;
/** Destroy the store and release resources */
destroy(): Promise<void>;
/**
* Get active job IDs for a user.
* Returns conversation IDs of running jobs belonging to the user.
* Also performs self-healing cleanup of stale entries.
*
* @param userId - The user ID to query
* @returns Array of conversation IDs with active jobs
*/
getActiveJobIdsByUser(userId: string): Promise<string[]>;
// ===== Content State Methods =====
// These methods manage volatile content state tied to each job.
// In-memory: Uses WeakRef to graph for live access
// Redis: Persists chunks and reconstructs on demand
/**
* Set the graph reference for a job (in-memory only).
* The graph provides live access to contentParts and contentData (run steps).
*
* In-memory: Stores WeakRef to graph
* Redis: No-op (graph not transferable, uses chunks instead)
*
* @param streamId - The stream identifier
* @param graph - The StandardGraph instance
*/
setGraph(streamId: string, graph: StandardGraph): void;
/**
* Set content parts reference for a job.
*
* In-memory: Stores direct reference to content array
* Redis: No-op (content built from chunks)
*
* @param streamId - The stream identifier
* @param contentParts - The content parts array
*/
setContentParts(streamId: string, contentParts: Agents.MessageContentComplex[]): void;
/**
* Get aggregated content for a job.
*
* In-memory: Returns live content from graph.contentParts or stored reference
* Redis: Reconstructs from stored chunks
*
* @param streamId - The stream identifier
* @returns Content parts or null if not available
*/
getContentParts(streamId: string): Promise<{
content: Agents.MessageContentComplex[];
} | null>;
/**
* Get run steps for a job (for resume state).
*
* In-memory: Returns live run steps from graph.contentData
* Redis: Fetches from persistent storage
*
* @param streamId - The stream identifier
* @returns Run steps or empty array
*/
getRunSteps(streamId: string): Promise<Agents.RunStep[]>;
/**
* Append a streaming chunk for later reconstruction.
*
* In-memory: No-op (content available via graph reference)
* Redis: Uses XADD for append-only log efficiency
*
* @param streamId - The stream identifier
* @param event - The SSE event to append
*/
appendChunk(streamId: string, event: unknown): Promise<void>;
/**
* Clear all content state for a job.
* Called on job completion/cleanup.
*
* @param streamId - The stream identifier
*/
clearContentState(streamId: string): void;
/**
* Save run steps to persistent storage.
* In-memory: No-op (run steps accessed via graph reference)
* Redis: Persists for resume across instances
*
* @param streamId - The stream identifier
* @param runSteps - Run steps to save
*/
saveRunSteps?(streamId: string, runSteps: Agents.RunStep[]): Promise<void>;
}
/**
* Interface for pub/sub event transport.
* Implementations can use EventEmitter, Redis Pub/Sub, etc.
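*
* @example
* ```ts
* // Sketch: wiring a transport to an SSE response (`transport` and `res` are assumed)
* const { unsubscribe } = transport.subscribe(streamId, {
*   onChunk: (event) => res.write(`data: ${JSON.stringify(event)}\n\n`),
*   onDone: () => res.end(),
* });
* transport.emitChunk(streamId, { text: 'Hello' });
* ```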
*/
export interface IEventTransport {
/** Subscribe to events for a stream */
subscribe(
streamId: string,
handlers: {
onChunk: (event: unknown) => void;
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void };
/** Publish a chunk event */
emitChunk(streamId: string, event: unknown): void;
/** Publish a done event */
emitDone(streamId: string, event: unknown): void;
/** Publish an error event */
emitError(streamId: string, error: string): void;
/** Get subscriber count for a stream */
getSubscriberCount(streamId: string): number;
/** Check if this is the first subscriber (for ready signaling) */
isFirstSubscriber(streamId: string): boolean;
/** Listen for all subscribers leaving */
onAllSubscribersLeft(streamId: string, callback: () => void): void;
/** Cleanup transport resources for a specific stream */
cleanup(streamId: string): void;
/** Get all tracked stream IDs (for orphan cleanup) */
getTrackedStreamIds(): string[];
/** Destroy all transport resources */
destroy(): void;
}

View file

@ -0,0 +1 @@
export * from './IJobStore';

View file

@ -29,12 +29,25 @@ export interface AudioProcessingResult {
bytes: number;
}
/** Google video block format */
export interface GoogleVideoBlock {
type: 'media';
mimeType: string;
data: string;
}
/** OpenRouter video block format */
export interface OpenRouterVideoBlock {
type: 'video_url';
video_url: {
url: string;
};
}
export type VideoBlock = GoogleVideoBlock | OpenRouterVideoBlock;
export interface VideoResult {
videos: Array<{
type: string;
mimeType: string;
data: string;
}>;
videos: VideoBlock[];
files: Array<{
file_id?: string;
temp_file_id?: string;
@ -100,12 +113,26 @@ export interface DocumentResult {
}>;
}
export interface AudioResult {
audios: Array<{
type: string;
mimeType: string;
/** Google audio block format */
export interface GoogleAudioBlock {
type: 'media';
mimeType: string;
data: string;
}
/** OpenRouter audio block format */
export interface OpenRouterAudioBlock {
type: 'input_audio';
input_audio: {
data: string;
}>;
format: string;
};
}
export type AudioBlock = GoogleAudioBlock | OpenRouterAudioBlock;
export interface AudioResult {
audios: AudioBlock[];
files: Array<{
file_id?: string;
temp_file_id?: string;

View file

@ -13,3 +13,4 @@ export type * from './openai';
export * from './prompts';
export * from './run';
export * from './tokens';
export * from './stream';

View file

@ -0,0 +1,49 @@
import type { EventEmitter } from 'events';
import type { Agents } from 'librechat-data-provider';
import type { ServerSentEvent } from '~/types';
export interface GenerationJobMetadata {
userId: string;
conversationId?: string;
/** User message data for rebuilding submission on reconnect */
userMessage?: Agents.UserMessageMeta;
/** Response message ID for tracking */
responseMessageId?: string;
/** Sender label for the response (e.g., "GPT-4.1", "Claude") */
sender?: string;
/** Endpoint identifier for abort handling */
endpoint?: string;
/** Icon URL for UI display */
iconURL?: string;
/** Model name for token tracking */
model?: string;
/** Prompt token count for abort token spending */
promptTokens?: number;
}
export type GenerationJobStatus = 'running' | 'complete' | 'error' | 'aborted';
export interface GenerationJob {
streamId: string;
emitter: EventEmitter;
status: GenerationJobStatus;
createdAt: number;
completedAt?: number;
abortController: AbortController;
error?: string;
metadata: GenerationJobMetadata;
readyPromise: Promise<void>;
resolveReady: () => void;
/** Final event when job completes */
finalEvent?: ServerSentEvent;
/** Flag to indicate if a sync event was already sent (prevent duplicate replays) */
syncSent?: boolean;
}
export type ContentPart = Agents.ContentPart;
export type ResumeState = Agents.ResumeState;
export type ChunkHandler = (event: ServerSentEvent) => void;
export type DoneHandler = (event: ServerSentEvent) => void;
export type ErrorHandler = (error: string) => void;
export type UnsubscribeFn = () => void;

View file

@ -0,0 +1,196 @@
/**
* Integration tests for math function with actual config patterns.
* These tests verify that real environment variable expressions from .env.example
* are correctly evaluated by the math function.
*/
import { math } from './math';
describe('math - integration with real config patterns', () => {
describe('SESSION_EXPIRY patterns', () => {
test('should evaluate default SESSION_EXPIRY (15 minutes)', () => {
const result = math('1000 * 60 * 15');
expect(result).toBe(900000); // 15 minutes in ms
});
test('should evaluate 30 minute session', () => {
const result = math('1000 * 60 * 30');
expect(result).toBe(1800000); // 30 minutes in ms
});
test('should evaluate 1 hour session', () => {
const result = math('1000 * 60 * 60');
expect(result).toBe(3600000); // 1 hour in ms
});
});
describe('REFRESH_TOKEN_EXPIRY patterns', () => {
test('should evaluate default REFRESH_TOKEN_EXPIRY (7 days)', () => {
const result = math('(1000 * 60 * 60 * 24) * 7');
expect(result).toBe(604800000); // 7 days in ms
});
test('should evaluate 1 day refresh token', () => {
const result = math('1000 * 60 * 60 * 24');
expect(result).toBe(86400000); // 1 day in ms
});
test('should evaluate 30 day refresh token', () => {
const result = math('(1000 * 60 * 60 * 24) * 30');
expect(result).toBe(2592000000); // 30 days in ms
});
});
describe('BAN_DURATION patterns', () => {
test('should evaluate default BAN_DURATION (2 hours)', () => {
const result = math('1000 * 60 * 60 * 2');
expect(result).toBe(7200000); // 2 hours in ms
});
test('should evaluate 24 hour ban', () => {
const result = math('1000 * 60 * 60 * 24');
expect(result).toBe(86400000); // 24 hours in ms
});
});
describe('Redis config patterns', () => {
test('should evaluate REDIS_RETRY_MAX_DELAY', () => {
expect(math('3000')).toBe(3000);
});
test('should evaluate REDIS_RETRY_MAX_ATTEMPTS', () => {
expect(math('10')).toBe(10);
});
test('should evaluate REDIS_CONNECT_TIMEOUT', () => {
expect(math('10000')).toBe(10000);
});
test('should evaluate REDIS_MAX_LISTENERS', () => {
expect(math('40')).toBe(40);
});
test('should evaluate REDIS_DELETE_CHUNK_SIZE', () => {
expect(math('1000')).toBe(1000);
});
});
describe('MCP config patterns', () => {
test('should evaluate MCP_OAUTH_DETECTION_TIMEOUT', () => {
expect(math('5000')).toBe(5000);
});
test('should evaluate MCP_CONNECTION_CHECK_TTL', () => {
expect(math('60000')).toBe(60000); // 1 minute
});
test('should evaluate MCP_USER_CONNECTION_IDLE_TIMEOUT (15 minutes)', () => {
const result = math('15 * 60 * 1000');
expect(result).toBe(900000); // 15 minutes in ms
});
test('should evaluate MCP_REGISTRY_CACHE_TTL', () => {
expect(math('5000')).toBe(5000); // 5 seconds
});
});
describe('Leader election config patterns', () => {
test('should evaluate LEADER_LEASE_DURATION (25 seconds)', () => {
expect(math('25')).toBe(25);
});
test('should evaluate LEADER_RENEW_INTERVAL (10 seconds)', () => {
expect(math('10')).toBe(10);
});
test('should evaluate LEADER_RENEW_ATTEMPTS', () => {
expect(math('3')).toBe(3);
});
test('should evaluate LEADER_RENEW_RETRY_DELAY (0.5 seconds)', () => {
expect(math('0.5')).toBe(0.5);
});
});
describe('OpenID config patterns', () => {
test('should evaluate OPENID_JWKS_URL_CACHE_TIME (10 minutes)', () => {
const result = math('600000');
expect(result).toBe(600000); // 10 minutes in ms
});
test('should evaluate custom cache time expression', () => {
const result = math('1000 * 60 * 10');
expect(result).toBe(600000); // 10 minutes in ms
});
});
describe('simulated process.env usage', () => {
const originalEnv = process.env;
beforeEach(() => {
process.env = { ...originalEnv };
});
afterEach(() => {
process.env = originalEnv;
});
test('should work with SESSION_EXPIRY from env', () => {
process.env.SESSION_EXPIRY = '1000 * 60 * 15';
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
test('should work with REFRESH_TOKEN_EXPIRY from env', () => {
process.env.REFRESH_TOKEN_EXPIRY = '(1000 * 60 * 60 * 24) * 7';
const result = math(process.env.REFRESH_TOKEN_EXPIRY, 604800000);
expect(result).toBe(604800000);
});
test('should work with BAN_DURATION from env', () => {
process.env.BAN_DURATION = '1000 * 60 * 60 * 2';
const result = math(process.env.BAN_DURATION, 7200000);
expect(result).toBe(7200000);
});
test('should use fallback when env var is undefined', () => {
delete process.env.SESSION_EXPIRY;
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
test('should use fallback when env var is empty string', () => {
process.env.SESSION_EXPIRY = '';
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
test('should use fallback when env var has invalid expression', () => {
process.env.SESSION_EXPIRY = 'invalid';
const result = math(process.env.SESSION_EXPIRY, 900000);
expect(result).toBe(900000);
});
});
describe('time calculation helpers', () => {
// Helper functions to make time calculations more readable
const seconds = (n: number) => n * 1000;
const minutes = (n: number) => seconds(n * 60);
const hours = (n: number) => minutes(n * 60);
const days = (n: number) => hours(n * 24);
test('should match helper calculations', () => {
// Verify our math function produces same results as programmatic calculations
expect(math('1000 * 60 * 15')).toBe(minutes(15));
expect(math('1000 * 60 * 60 * 2')).toBe(hours(2));
expect(math('(1000 * 60 * 60 * 24) * 7')).toBe(days(7));
});
test('should handle complex expressions', () => {
// 2 hours + 30 minutes
expect(math('(1000 * 60 * 60 * 2) + (1000 * 60 * 30)')).toBe(hours(2) + minutes(30));
// Half a day
expect(math('(1000 * 60 * 60 * 24) / 2')).toBe(days(1) / 2);
});
});
});

View file

@ -0,0 +1,326 @@
import { math } from './math';
describe('math', () => {
describe('number input passthrough', () => {
test('should return number as-is when input is a number', () => {
expect(math(42)).toBe(42);
});
test('should return zero when input is 0', () => {
expect(math(0)).toBe(0);
});
test('should return negative numbers as-is', () => {
expect(math(-10)).toBe(-10);
});
test('should return decimal numbers as-is', () => {
expect(math(0.5)).toBe(0.5);
});
test('should return very large numbers as-is', () => {
expect(math(Number.MAX_SAFE_INTEGER)).toBe(Number.MAX_SAFE_INTEGER);
});
});
describe('simple string number parsing', () => {
test('should parse simple integer string', () => {
expect(math('42')).toBe(42);
});
test('should parse zero string', () => {
expect(math('0')).toBe(0);
});
test('should parse negative number string', () => {
expect(math('-10')).toBe(-10);
});
test('should parse decimal string', () => {
expect(math('0.5')).toBe(0.5);
});
test('should parse string with leading/trailing spaces', () => {
expect(math(' 42 ')).toBe(42);
});
test('should parse large number string', () => {
expect(math('9007199254740991')).toBe(Number.MAX_SAFE_INTEGER);
});
});
describe('mathematical expressions - multiplication', () => {
test('should evaluate simple multiplication', () => {
expect(math('2 * 3')).toBe(6);
});
test('should evaluate chained multiplication (BAN_DURATION pattern: 1000 * 60 * 60 * 2)', () => {
// 2 hours in milliseconds
expect(math('1000 * 60 * 60 * 2')).toBe(7200000);
});
test('should evaluate SESSION_EXPIRY pattern (1000 * 60 * 15)', () => {
// 15 minutes in milliseconds
expect(math('1000 * 60 * 15')).toBe(900000);
});
test('should evaluate multiplication without spaces', () => {
expect(math('2*3')).toBe(6);
});
});
describe('mathematical expressions - addition and subtraction', () => {
test('should evaluate simple addition', () => {
expect(math('2 + 3')).toBe(5);
});
test('should evaluate simple subtraction', () => {
expect(math('10 - 3')).toBe(7);
});
test('should evaluate mixed addition and subtraction', () => {
expect(math('10 + 5 - 3')).toBe(12);
});
test('should handle negative results', () => {
expect(math('3 - 10')).toBe(-7);
});
});
describe('mathematical expressions - division', () => {
test('should evaluate simple division', () => {
expect(math('10 / 2')).toBe(5);
});
test('should evaluate division resulting in decimal', () => {
expect(math('7 / 2')).toBe(3.5);
});
});
describe('mathematical expressions - parentheses', () => {
test('should evaluate expression with parentheses (REFRESH_TOKEN_EXPIRY pattern)', () => {
// 7 days in milliseconds: (1000 * 60 * 60 * 24) * 7
expect(math('(1000 * 60 * 60 * 24) * 7')).toBe(604800000);
});
test('should evaluate nested parentheses', () => {
expect(math('((2 + 3) * 4)')).toBe(20);
});
test('should respect operator precedence with parentheses', () => {
expect(math('2 * (3 + 4)')).toBe(14);
});
});
describe('mathematical expressions - modulo', () => {
test('should evaluate modulo operation', () => {
expect(math('10 % 3')).toBe(1);
});
test('should evaluate modulo with larger numbers', () => {
expect(math('100 % 7')).toBe(2);
});
});
describe('complex real-world expressions', () => {
test('should evaluate MCP_USER_CONNECTION_IDLE_TIMEOUT pattern (15 * 60 * 1000)', () => {
// 15 minutes in milliseconds
expect(math('15 * 60 * 1000')).toBe(900000);
});
test('should evaluate Redis default TTL (5000)', () => {
expect(math('5000')).toBe(5000);
});
test('should evaluate LEADER_RENEW_RETRY_DELAY decimal (0.5)', () => {
expect(math('0.5')).toBe(0.5);
});
test('should evaluate BAN_DURATION default (7200000)', () => {
// 2 hours in milliseconds
expect(math('7200000')).toBe(7200000);
});
test('should evaluate expression with mixed operators and parentheses', () => {
// (1 hour + 30 min) in ms
expect(math('(1000 * 60 * 60) + (1000 * 60 * 30)')).toBe(5400000);
});
});
describe('fallback value behavior', () => {
test('should return fallback when input is undefined', () => {
expect(math(undefined, 100)).toBe(100);
});
test('should return fallback when input is null', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math(null, 100)).toBe(100);
});
test('should return fallback when input contains invalid characters', () => {
expect(math('abc', 100)).toBe(100);
});
test('should return fallback when input has SQL injection attempt', () => {
expect(math('1; DROP TABLE users;', 100)).toBe(100);
});
test('should return fallback when input has function call attempt', () => {
expect(math('console.log("hacked")', 100)).toBe(100);
});
test('should return fallback when input is empty string', () => {
expect(math('', 100)).toBe(100);
});
test('should return zero fallback when specified', () => {
expect(math(undefined, 0)).toBe(0);
});
test('should use number input even when fallback is provided', () => {
expect(math(42, 100)).toBe(42);
});
test('should use valid string even when fallback is provided', () => {
expect(math('42', 100)).toBe(42);
});
});
describe('error cases without fallback', () => {
test('should throw error when input is undefined without fallback', () => {
expect(() => math(undefined)).toThrow('str is undefined, but should be a string');
});
test('should throw error when input is null without fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(() => math(null)).toThrow('str is object, but should be a string');
});
test('should throw error when input contains invalid characters without fallback', () => {
expect(() => math('abc')).toThrow('Invalid characters in string');
});
test('should throw error when input has letter characters', () => {
expect(() => math('10x')).toThrow('Invalid characters in string');
});
test('should throw error when input has special characters', () => {
expect(() => math('10!')).toThrow('Invalid characters in string');
});
test('should throw error for malicious code injection', () => {
expect(() => math('process.exit(1)')).toThrow('Invalid characters in string');
});
test('should throw error for require injection', () => {
expect(() => math('require("fs")')).toThrow('Invalid characters in string');
});
});
describe('security - input validation', () => {
test('should reject strings with alphabetic characters', () => {
expect(() => math('Math.PI')).toThrow('Invalid characters in string');
});
test('should reject strings with brackets', () => {
expect(() => math('[1,2,3]')).toThrow('Invalid characters in string');
});
test('should reject strings with curly braces', () => {
expect(() => math('{}')).toThrow('Invalid characters in string');
});
test('should reject strings with semicolons', () => {
expect(() => math('1;2')).toThrow('Invalid characters in string');
});
test('should reject strings with quotes', () => {
expect(() => math('"test"')).toThrow('Invalid characters in string');
});
test('should reject strings with backticks', () => {
expect(() => math('`test`')).toThrow('Invalid characters in string');
});
test('should reject strings with equals sign', () => {
expect(() => math('x=1')).toThrow('Invalid characters in string');
});
test('should reject strings with ampersand', () => {
expect(() => math('1 && 2')).toThrow('Invalid characters in string');
});
test('should reject strings with pipe', () => {
expect(() => math('1 || 2')).toThrow('Invalid characters in string');
});
});
describe('edge cases', () => {
    test('should return Infinity for division by zero', () => {
// Division by zero returns Infinity, which is technically a number
expect(math('1 / 0')).toBe(Infinity);
});
test('should handle very small decimals', () => {
expect(math('0.001')).toBe(0.001);
});
test('should handle scientific notation format', () => {
// Note: 'e' is not in the allowed character set, so this should fail
expect(() => math('1e3')).toThrow('Invalid characters in string');
});
test('should handle expression with only whitespace with fallback', () => {
expect(math(' ', 100)).toBe(100);
});
test('should handle +number syntax', () => {
expect(math('+42')).toBe(42);
});
test('should handle expression starting with negative', () => {
expect(math('-5 + 10')).toBe(5);
});
test('should handle multiple decimal points with fallback', () => {
// Invalid syntax should return fallback value
expect(math('1.2.3', 100)).toBe(100);
});
test('should throw for multiple decimal points without fallback', () => {
expect(() => math('1.2.3')).toThrow();
});
});
describe('type coercion edge cases', () => {
test('should handle object input with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math({}, 100)).toBe(100);
});
test('should handle array input with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math([], 100)).toBe(100);
});
test('should handle boolean true with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math(true, 100)).toBe(100);
});
test('should handle boolean false with fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(math(false, 100)).toBe(100);
});
test('should throw for object input without fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(() => math({})).toThrow('str is object, but should be a string');
});
test('should throw for array input without fallback', () => {
// @ts-expect-error - testing runtime behavior with invalid input
expect(() => math([])).toThrow('str is object, but should be a string');
});
});
});

View file

@ -1,3 +1,5 @@
import { evaluate } from 'mathjs';
/**
* Evaluates a mathematical expression provided as a string and returns the result.
*
@ -5,6 +7,8 @@
* If the input is not a string or contains invalid characters, an error is thrown.
* If the evaluated result is not a number, an error is thrown.
*
* Uses mathjs for safe expression evaluation instead of eval().
*
* @param str - The mathematical expression to evaluate, or a number.
* @param fallbackValue - The default value to return if the input is not a string or number, or if the evaluated result is not a number.
*
@ -32,14 +36,22 @@ export function math(str: string | number | undefined, fallbackValue?: number):
    throw new Error('Invalid characters in string');
  }
  const value = eval(str);
  try {
    const value = evaluate(str);

    if (typeof value !== 'number') {
      if (fallback) {
        return fallbackValue;
      }
      throw new Error(`[math] str did not evaluate to a number but to a ${typeof value}`);
    }

    return value;
  } catch (error) {
    if (fallback) {
      return fallbackValue;
    }
    const originalMessage = error instanceof Error ? error.message : String(error);
    throw new Error(`[math] Error while evaluating mathematical expression: ${originalMessage}`);
  }
}
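
For orientation, a minimal usage sketch of the updated helper (not part of this commit; the import path, environment variable name, and fallback values are illustrative assumptions):

import { math } from './math';

// Parse a duration expression such as "1000 * 60 * 15" from configuration,
// falling back to 15 minutes when the variable is unset or malformed.
// SESSION_EXPIRY is an illustrative name, not something this diff introduces.
const sessionExpiry = math(process.env.SESSION_EXPIRY, 1000 * 60 * 15);

// With a fallback, invalid or injected input is swallowed; without one it throws.
const safe = math('1; DROP TABLE users;', 0); // -> 0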

View file

@ -27,4 +27,4 @@ export function sanitizeTitle(rawTitle: string): string {
// Step 5: Return trimmed result or fallback if empty
return trimmed.length > 0 ? trimmed : DEFAULT_FALLBACK;
}
}

View file

@ -8,7 +8,7 @@
"target": "es2015",
"moduleResolution": "node",
"allowSyntheticDefaultImports": true,
"lib": ["es2017", "dom", "ES2021.String"],
"lib": ["es2017", "dom", "ES2021.String", "ES2021.WeakRef"],
"allowJs": true,
"skipLibCheck": true,
"esModuleInterop": true,