Merge branch 'main' into fix-image_gen_oai-with-nova-models

This commit is contained in:
Peter 2026-02-18 08:55:05 +01:00 committed by GitHub
commit cbb6b1a7d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
211 changed files with 10746 additions and 1656 deletions

View file

@ -10,6 +10,17 @@ export default {
],
coverageReporters: ['text', 'cobertura'],
testResultsProcessor: 'jest-junit',
transform: {
'\\.[jt]sx?$': [
'babel-jest',
{
presets: [
['@babel/preset-env', { targets: { node: 'current' } }],
'@babel/preset-typescript',
],
},
],
},
moduleNameMapper: {
'^@src/(.*)$': '<rootDir>/src/$1',
'~/(.*)': '<rootDir>/src/$1',

View file

@ -87,11 +87,11 @@
"@google/genai": "^1.19.0",
"@keyv/redis": "^4.3.3",
"@langchain/core": "^0.3.80",
"@librechat/agents": "^3.1.38",
"@librechat/agents": "^3.1.50",
"@librechat/data-schemas": "*",
"@modelcontextprotocol/sdk": "^1.26.0",
"@smithy/node-http-handler": "^4.4.5",
"axios": "^1.12.1",
"axios": "^1.13.5",
"connect-redis": "^8.1.0",
"eventsource": "^3.0.2",
"express": "^5.1.0",

View file

@ -0,0 +1,284 @@
import { Providers } from '@librechat/agents';
import { EModelEndpoint } from 'librechat-data-provider';
import type { Agent } from 'librechat-data-provider';
import type { ServerRequest, InitializeResultBase } from '~/types';
import type { InitializeAgentDbMethods } from '../initialize';
// Mock logger so initializeAgent's logging does not require a real winston transport.
jest.mock('winston', () => ({
  createLogger: jest.fn(() => ({
    debug: jest.fn(),
    warn: jest.fn(),
    error: jest.fn(),
  })),
  format: {
    combine: jest.fn(),
    colorize: jest.fn(),
    simple: jest.fn(),
  },
  transports: {
    Console: jest.fn(),
  },
}));
// Module-level mock functions so individual tests can configure and inspect calls.
// jest.mock factories are hoisted above imports, hence the arrow-wrapper indirection:
// the factory closes over the variable, which is initialized by the time tests run.
const mockExtractLibreChatParams = jest.fn();
const mockGetModelMaxTokens = jest.fn();
const mockOptionalChainWithEmptyCheck = jest.fn();
const mockGetThreadData = jest.fn();
jest.mock('~/utils', () => ({
  extractLibreChatParams: (...args: unknown[]) => mockExtractLibreChatParams(...args),
  getModelMaxTokens: (...args: unknown[]) => mockGetModelMaxTokens(...args),
  optionalChainWithEmptyCheck: (...args: unknown[]) => mockOptionalChainWithEmptyCheck(...args),
  getThreadData: (...args: unknown[]) => mockGetThreadData(...args),
}));
const mockGetProviderConfig = jest.fn();
jest.mock('~/endpoints', () => ({
  getProviderConfig: (...args: unknown[]) => mockGetProviderConfig(...args),
}));
// File/prompt/resource helpers are stubbed to benign no-ops: these tests only
// target maxContextTokens resolution, not attachments or artifact prompts.
jest.mock('~/files', () => ({
  filterFilesByEndpointConfig: jest.fn(() => []),
}));
jest.mock('~/prompts', () => ({
  generateArtifactsPrompt: jest.fn(() => null),
}));
jest.mock('../resources', () => ({
  primeResources: jest.fn().mockResolvedValue({
    attachments: [],
    tool_resources: undefined,
  }),
}));
import { initializeAgent } from '../initialize';
/**
 * Creates minimal mock objects for initializeAgent tests.
 *
 * @param overrides.maxContextTokens - user-configured context window reported by
 *   the mocked extractLibreChatParams (undefined means "not configured")
 * @param overrides.modelDefault - model's default context window reported by the
 *   mocked getModelMaxTokens (default 200000)
 * @param overrides.maxOutputTokens - maxTokens placed on the llmConfig returned
 *   by the mocked provider getOptions (default 4096)
 * @returns mock agent/req/res plus a loadTools stub and a DB facade
 */
function createMocks(overrides?: {
  maxContextTokens?: number;
  modelDefault?: number;
  maxOutputTokens?: number;
}) {
  const { maxContextTokens, modelDefault = 200000, maxOutputTokens = 4096 } = overrides ?? {};
  const agent = {
    id: 'agent-1',
    model: 'test-model',
    provider: Providers.OPENAI,
    tools: [],
    model_parameters: { model: 'test-model' },
  } as unknown as Agent;
  const req = {
    user: { id: 'user-1' },
    config: {},
  } as unknown as ServerRequest;
  const res = {} as unknown as import('express').Response;
  const mockGetOptions = jest.fn().mockResolvedValue({
    llmConfig: {
      model: 'test-model',
      maxTokens: maxOutputTokens,
    },
    endpointTokenConfig: undefined,
  } satisfies InitializeResultBase);
  mockGetProviderConfig.mockReturnValue({
    getOptions: mockGetOptions,
    overrideProvider: Providers.OPENAI,
  });
  // extractLibreChatParams returns maxContextTokens when provided in model_parameters
  mockExtractLibreChatParams.mockReturnValue({
    resendFiles: false,
    maxContextTokens,
    modelOptions: { model: 'test-model' },
  });
  // getModelMaxTokens returns the model's default context window
  mockGetModelMaxTokens.mockReturnValue(modelDefault);
  // Implement real optionalChainWithEmptyCheck behavior:
  // first value that is not undefined/null/'' wins; otherwise the last value.
  mockOptionalChainWithEmptyCheck.mockImplementation(
    (...values: (string | number | undefined)[]) => {
      for (const v of values) {
        if (v !== undefined && v !== null && v !== '') {
          return v;
        }
      }
      return values[values.length - 1];
    },
  );
  const loadTools = jest.fn().mockResolvedValue({
    tools: [],
    toolContextMap: {},
    userMCPAuthMap: undefined,
    toolRegistry: undefined,
    toolDefinitions: [],
    hasDeferredTools: false,
  });
  // Minimal DB facade: every method resolves to an empty/benign value.
  const db: InitializeAgentDbMethods = {
    getFiles: jest.fn().mockResolvedValue([]),
    getConvoFiles: jest.fn().mockResolvedValue([]),
    updateFilesUsage: jest.fn().mockResolvedValue([]),
    getUserKey: jest.fn().mockResolvedValue('user-1'),
    getUserKeyValues: jest.fn().mockResolvedValue([]),
    getToolFilesByIds: jest.fn().mockResolvedValue([]),
  };
  return { agent, req, res, loadTools, db };
}
// Verifies how initializeAgent resolves the effective maxContextTokens:
// a positive user-configured value is used as-is; otherwise the fallback
// formula Math.round((modelContextWindow - maxOutputTokens) * 0.9) applies.
describe('initializeAgent — maxContextTokens', () => {
  beforeEach(() => {
    jest.clearAllMocks();
  });
  it('uses user-configured maxContextTokens when provided via model_parameters', async () => {
    const userValue = 50000;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: userValue,
      modelDefault: 200000,
      maxOutputTokens: 4096,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: userValue },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    expect(result.maxContextTokens).toBe(userValue);
  });
  it('falls back to formula when maxContextTokens is NOT provided', async () => {
    const modelDefault = 200000;
    const maxOutputTokens = 4096;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: undefined,
      modelDefault,
      maxOutputTokens,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: { endpoint: EModelEndpoint.agents },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // Without a user value, the context window is derived from the model default.
    const expected = Math.round((modelDefault - maxOutputTokens) * 0.9);
    expect(result.maxContextTokens).toBe(expected);
  });
  it('falls back to formula when maxContextTokens is 0', async () => {
    const maxOutputTokens = 4096;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: 0,
      modelDefault: 200000,
      maxOutputTokens,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: 0 },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // 0 is not used as-is; the formula kicks in.
    // optionalChainWithEmptyCheck(0, 200000, 18000) returns 0 (not null/undefined),
    // then Number(0) || 18000 = 18000 (the fallback default).
    expect(result.maxContextTokens).not.toBe(0);
    const expected = Math.round((18000 - maxOutputTokens) * 0.9);
    expect(result.maxContextTokens).toBe(expected);
  });
  it('falls back to formula when maxContextTokens is negative', async () => {
    const maxOutputTokens = 4096;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: -1,
      modelDefault: 200000,
      maxOutputTokens,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: -1 },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // -1 is not used as-is; the formula kicks in
    expect(result.maxContextTokens).not.toBe(-1);
  });
  it('preserves small user-configured value (e.g. 1000 from modelSpec)', async () => {
    const userValue = 1000;
    const { agent, req, res, loadTools, db } = createMocks({
      maxContextTokens: userValue,
      modelDefault: 128000,
      maxOutputTokens: 4096,
    });
    const result = await initializeAgent(
      {
        req,
        res,
        agent,
        loadTools,
        endpointOption: {
          endpoint: EModelEndpoint.agents,
          model_parameters: { maxContextTokens: userValue },
        },
        allowedProviders: new Set([Providers.OPENAI]),
        isInitialAgent: true,
      },
      db,
    );
    // Should NOT be overridden to Math.round((128000 - 4096) * 0.9) = 111,514
    expect(result.maxContextTokens).toBe(userValue);
  });
});

View file

@ -0,0 +1,162 @@
import { logger } from '@librechat/data-schemas';
import { isAgentsEndpoint } from 'librechat-data-provider';
import { labelContentByAgent, getTokenCountForMessage } from '@librechat/agents';
import type { MessageContentComplex } from '@librechat/agents';
import type { Agent, TMessage } from 'librechat-data-provider';
import type { BaseMessage } from '@langchain/core/messages';
import type { ServerRequest } from '~/types';
import Tokenizer from '~/utils/tokenizer';
import { logAxiosError } from '~/utils';
/**
 * Option keys removed from model/client options before title generation.
 * NOTE(review): these look like streaming/"thinking" parameters that are
 * irrelevant to a short titling completion — confirm against the caller
 * that consumes this set.
 */
export const omitTitleOptions = new Set([
  'stream',
  'thinking',
  'streaming',
  'clientOptions',
  'thinkingConfig',
  'thinkingBudget',
  'includeThoughts',
  'maxOutputTokens',
  'additionalModelRequestFields',
]);
/**
 * Returns the request's endpointOption model parameters for non-agent
 * endpoints; agent endpoints yield undefined (their parameters are handled
 * elsewhere).
 */
export function payloadParser({ req, endpoint }: { req: ServerRequest; endpoint: string }) {
  if (!isAgentsEndpoint(endpoint)) {
    return req.body?.endpointOption?.model_parameters;
  }
  return undefined;
}
/**
 * Builds a per-message token counter bound to the given tokenizer encoding.
 * The returned function delegates per-message accounting to
 * getTokenCountForMessage, using Tokenizer for the raw text counts.
 */
export function createTokenCounter(encoding: Parameters<typeof Tokenizer.getTokenCount>[1]) {
  const countTokens = (text: string) => Tokenizer.getTokenCount(text, encoding);
  return (message: BaseMessage) => getTokenCountForMessage(message, countTokens);
}
/**
 * Logs a tool-execution error via logAxiosError, tagging the message with the
 * failing tool's ID. The graph argument exists for handler-signature
 * compatibility and is intentionally unused.
 */
export function logToolError(_graph: unknown, error: unknown, toolId: string) {
  const message = `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`;
  logAxiosError({ error, message });
}
const AGENT_SUFFIX_PATTERN = /____(\d+)$/;
/**
 * Finds the primary agent ID within a set of agent IDs.
 * An ID with no `____<n>` suffix is immediately primary; otherwise the ID
 * carrying the lowest numeric suffix wins. Returns null for an empty set.
 */
export function findPrimaryAgentId(agentIds: Set<string>): string | null {
  let best: string | null = null;
  let bestIndex = Infinity;
  for (const candidate of agentIds) {
    const match = AGENT_SUFFIX_PATTERN.exec(candidate);
    if (match == null) {
      // Unsuffixed IDs always take precedence.
      return candidate;
    }
    const index = parseInt(match[1], 10);
    if (index < bestIndex) {
      bestIndex = index;
      best = candidate;
    }
  }
  return best;
}
/** Element type of a TMessage's content array (a single content part). */
type ContentPart = TMessage['content'] extends (infer U)[] | undefined ? U : never;
/**
 * Creates a mapMethod for getMessagesForConversation that processes agent content.
 * - Strips agentId/groupId metadata from all content
 * - For parallel agents (addedConvo with groupId): filters each group to its primary agent
 * - For handoffs (agentId without groupId): keeps all content from all agents
 * - For multi-agent: applies agent labels to content
 *
 * The key distinction:
 * - Parallel execution (addedConvo): Parts have both agentId AND groupId
 * - Handoffs: Parts only have agentId, no groupId
 */
export function createMultiAgentMapper(primaryAgent: Agent, agentConfigs?: Map<string, Agent>) {
  // Multi-agent if the primary agent has edges or additional agent configs exist.
  const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
  // Display-name lookup used for labeling; only built in the multi-agent case.
  let agentNames: Record<string, string> | null = null;
  if (hasMultipleAgents) {
    agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
    if (agentConfigs) {
      for (const [agentId, agentConfig] of agentConfigs.entries()) {
        agentNames[agentId] = agentConfig.name || agentConfig.id;
      }
    }
  }
  return (message: TMessage): TMessage => {
    // User messages and non-array content carry no agent metadata to process.
    if (message.isCreatedByUser || !Array.isArray(message.content)) {
      return message;
    }
    const hasAgentMetadata = message.content.some(
      (part) =>
        (part as ContentPart & { agentId?: string; groupId?: number })?.agentId ||
        (part as ContentPart & { groupId?: number })?.groupId != null,
    );
    if (!hasAgentMetadata) {
      return message;
    }
    try {
      // Pass 1: collect which agents contributed to each parallel group.
      const groupAgentMap = new Map<number, Set<string>>();
      for (const part of message.content) {
        const p = part as ContentPart & { agentId?: string; groupId?: number };
        const groupId = p?.groupId;
        const agentId = p?.agentId;
        if (groupId != null && agentId) {
          if (!groupAgentMap.has(groupId)) {
            groupAgentMap.set(groupId, new Set());
          }
          groupAgentMap.get(groupId)!.add(agentId);
        }
      }
      // Pass 2: resolve each group's primary agent.
      const groupPrimaryMap = new Map<number, string>();
      for (const [groupId, agentIds] of groupAgentMap) {
        const primary = findPrimaryAgentId(agentIds);
        if (primary) {
          groupPrimaryMap.set(groupId, primary);
        }
      }
      // Pass 3: keep handoff parts and each group's primary-agent parts,
      // stripping the metadata and recording which agent produced each kept part.
      const filteredContent: ContentPart[] = [];
      const agentIdMap: Record<number, string> = {};
      for (const part of message.content) {
        const p = part as ContentPart & { agentId?: string; groupId?: number };
        const agentId = p?.agentId;
        const groupId = p?.groupId;
        const isParallelPart = groupId != null;
        const groupPrimary = isParallelPart ? groupPrimaryMap.get(groupId) : null;
        const shouldInclude = !isParallelPart || !agentId || agentId === groupPrimary;
        if (shouldInclude) {
          const newIndex = filteredContent.length;
          // Drop agentId/groupId metadata from the retained part.
          const { agentId: _a, groupId: _g, ...cleanPart } = p;
          filteredContent.push(cleanPart as ContentPart);
          if (agentId && hasMultipleAgents) {
            agentIdMap[newIndex] = agentId;
          }
        }
      }
      // Only label when we have both attributions and display names.
      const finalContent =
        Object.keys(agentIdMap).length > 0 && agentNames
          ? labelContentByAgent(filteredContent as MessageContentComplex[], agentIdMap, agentNames)
          : filteredContent;
      return { ...message, content: finalContent as TMessage['content'] };
    } catch (error) {
      // On any processing failure, fall back to the untouched message.
      logger.error('[AgentClient] Error processing multi-agent message:', error);
      return message;
    }
  };
}

View file

@ -1,5 +1,6 @@
export * from './avatars';
export * from './chain';
export * from './client';
export * from './context';
export * from './edges';
export * from './handlers';

View file

@ -413,7 +413,10 @@ export async function initializeAgent(
toolContextMap: toolContextMap ?? {},
useLegacyContent: !!options.useLegacyContent,
tools: (tools ?? []) as GenericTool[] & string[],
maxContextTokens: Math.round((agentMaxContextNum - maxOutputTokensNum) * 0.9),
maxContextTokens:
maxContextTokens != null && maxContextTokens > 0
? maxContextTokens
: Math.round((agentMaxContextNum - maxOutputTokensNum) * 0.9),
};
return initializedAgent;

View file

@ -0,0 +1,193 @@
#!/usr/bin/env bash
#
# Live integration tests for the Responses API endpoint.
# Sends curl requests to a running LibreChat server to verify
# multi-turn conversations with output_text / refusal blocks work.
#
# Usage:
# ./responses-api.live.test.sh <BASE_URL> <API_KEY> <AGENT_ID>
#
# Example:
# ./responses-api.live.test.sh http://localhost:3080 sk-abc123 agent_xyz
set -euo pipefail
BASE_URL="${1:?Usage: $0 <BASE_URL> <API_KEY> <AGENT_ID>}"
API_KEY="${2:?Usage: $0 <BASE_URL> <API_KEY> <AGENT_ID>}"
AGENT_ID="${3:?Usage: $0 <BASE_URL> <API_KEY> <AGENT_ID>}"
ENDPOINT="${BASE_URL}/v1/responses"
PASS=0
FAIL=0
# ── Helpers ───────────────────────────────────────────────────────────
post_json() {
local label="$1"
local body="$2"
local stream="${3:-false}"
echo "──────────────────────────────────────────────"
echo "TEST: ${label}"
echo "──────────────────────────────────────────────"
local http_code
local response
if [ "$stream" = "true" ]; then
# For streaming, just check we get a 200 and some SSE data
response=$(curl -s -w "\n%{http_code}" \
-X POST "${ENDPOINT}" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${API_KEY}" \
-d "${body}" \
--max-time 60)
else
response=$(curl -s -w "\n%{http_code}" \
-X POST "${ENDPOINT}" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${API_KEY}" \
-d "${body}" \
--max-time 60)
fi
http_code=$(echo "$response" | tail -1)
local body_out
body_out=$(echo "$response" | sed '$d')
if [ "$http_code" = "200" ]; then
echo " ✓ HTTP 200"
PASS=$((PASS + 1))
else
echo " ✗ HTTP ${http_code}"
echo " Response: ${body_out}"
FAIL=$((FAIL + 1))
fi
# Print truncated response for inspection
echo " Response (first 300 chars): ${body_out:0:300}"
echo ""
# Return the body for chaining
echo "$body_out"
}
extract_response_id() {
# Extract "id" field from JSON response
echo "$1" | grep -o '"id":"[^"]*"' | head -1 | cut -d'"' -f4
}
# ── Test 1: Basic single-turn request ─────────────────────────────────
RESP1=$(post_json "Basic single-turn request" "$(cat <<EOF
{
"model": "${AGENT_ID}",
"input": "Say hello in exactly 5 words.",
"stream": false
}
EOF
)")
# ── Test 2: Multi-turn with output_text assistant blocks ──────────────
post_json "Multi-turn with output_text blocks (the original bug)" "$(cat <<EOF
{
"model": "${AGENT_ID}",
"input": [
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "What is 2+2?"}]
},
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "2+2 equals 4.", "annotations": [], "logprobs": []}]
},
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "And what is 3+3?"}]
}
],
"stream": false
}
EOF
)" > /dev/null
# ── Test 3: Multi-turn with refusal blocks ────────────────────────────
post_json "Multi-turn with refusal blocks" "$(cat <<EOF
{
"model": "${AGENT_ID}",
"input": [
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "Do something bad"}]
},
{
"type": "message",
"role": "assistant",
"content": [{"type": "refusal", "refusal": "I cannot help with that."}]
},
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "OK, just say hello then."}]
}
],
"stream": false
}
EOF
)" > /dev/null
# ── Test 4: Streaming request ─────────────────────────────────────────
post_json "Streaming single-turn request" "$(cat <<EOF
{
"model": "${AGENT_ID}",
"input": "Say hi in one word.",
"stream": true
}
EOF
)" "true" > /dev/null
# ── Test 5: Back-and-forth using previous_response_id ─────────────────
RESP5=$(post_json "First turn for previous_response_id chain" "$(cat <<EOF
{
"model": "${AGENT_ID}",
"input": "Remember this number: 42. Just confirm you got it.",
"stream": false
}
EOF
)")
RESP5_ID=$(extract_response_id "$RESP5")
if [ -n "$RESP5_ID" ]; then
echo " Extracted response ID: ${RESP5_ID}"
post_json "Follow-up using previous_response_id" "$(cat <<EOF
{
"model": "${AGENT_ID}",
"input": "What number did I ask you to remember?",
"previous_response_id": "${RESP5_ID}",
"stream": false
}
EOF
)" > /dev/null
else
echo " ⚠ Could not extract response ID, skipping follow-up test"
FAIL=$((FAIL + 1))
fi
# ── Summary ───────────────────────────────────────────────────────────
echo "══════════════════════════════════════════════"
echo "RESULTS: ${PASS} passed, ${FAIL} failed"
echo "══════════════════════════════════════════════"
if [ "$FAIL" -gt 0 ]; then
exit 1
fi

View file

@ -0,0 +1,333 @@
import { convertInputToMessages } from '../service';
import type { InputItem } from '../types';
// Unit tests for convertInputToMessages: normalizing Responses API `input`
// (string shorthand, message items with typed content blocks, function-call
// items) into the internal chat-message shape. Includes regression coverage
// for assistant `output_text` / `refusal` blocks in multi-turn input.
describe('convertInputToMessages', () => {
  // ── String input shorthand ─────────────────────────────────────────
  it('converts a string input to a single user message', () => {
    const result = convertInputToMessages('Hello');
    expect(result).toEqual([{ role: 'user', content: 'Hello' }]);
  });
  // ── Empty input array ──────────────────────────────────────────────
  it('returns an empty array for empty input', () => {
    const result = convertInputToMessages([]);
    expect(result).toEqual([]);
  });
  // ── Role mapping ───────────────────────────────────────────────────
  it('maps developer role to system', () => {
    const input: InputItem[] = [
      { type: 'message', role: 'developer', content: 'You are helpful.' },
    ];
    expect(convertInputToMessages(input)).toEqual([
      { role: 'system', content: 'You are helpful.' },
    ]);
  });
  it('maps system role to system', () => {
    const input: InputItem[] = [{ type: 'message', role: 'system', content: 'System prompt.' }];
    expect(convertInputToMessages(input)).toEqual([{ role: 'system', content: 'System prompt.' }]);
  });
  it('maps user role to user', () => {
    const input: InputItem[] = [{ type: 'message', role: 'user', content: 'Hi' }];
    expect(convertInputToMessages(input)).toEqual([{ role: 'user', content: 'Hi' }]);
  });
  it('maps assistant role to assistant', () => {
    const input: InputItem[] = [{ type: 'message', role: 'assistant', content: 'Hello!' }];
    expect(convertInputToMessages(input)).toEqual([{ role: 'assistant', content: 'Hello!' }]);
  });
  it('defaults unknown roles to user', () => {
    const input = [
      { type: 'message', role: 'unknown_role', content: 'test' },
    ] as unknown as InputItem[];
    expect(convertInputToMessages(input)[0].role).toBe('user');
  });
  // ── input_text content blocks ──────────────────────────────────────
  it('converts input_text blocks to text blocks', () => {
    const input: InputItem[] = [
      {
        type: 'message',
        role: 'user',
        content: [{ type: 'input_text', text: 'Hello world' }],
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: 'Hello world' }] }]);
  });
  // ── output_text content blocks (the original bug) ──────────────────
  // Previously assistant output_text blocks were passed through as bare
  // `{ type: 'output_text' }` parts; they must become plain text blocks.
  it('converts output_text blocks to text blocks', () => {
    const input: InputItem[] = [
      {
        type: 'message',
        role: 'assistant',
        content: [{ type: 'output_text', text: 'I can help!', annotations: [], logprobs: [] }],
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      { role: 'assistant', content: [{ type: 'text', text: 'I can help!' }] },
    ]);
  });
  // ── refusal content blocks ─────────────────────────────────────────
  it('converts refusal blocks to text blocks', () => {
    const input: InputItem[] = [
      {
        type: 'message',
        role: 'assistant',
        content: [{ type: 'refusal', refusal: 'I cannot do that.' }],
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      { role: 'assistant', content: [{ type: 'text', text: 'I cannot do that.' }] },
    ]);
  });
  // ── input_image content blocks ─────────────────────────────────────
  it('converts input_image blocks to image_url blocks', () => {
    const input: InputItem[] = [
      {
        type: 'message',
        role: 'user',
        content: [
          { type: 'input_image', image_url: 'https://example.com/img.png', detail: 'high' },
        ],
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      {
        role: 'user',
        content: [
          {
            type: 'image_url',
            image_url: { url: 'https://example.com/img.png', detail: 'high' },
          },
        ],
      },
    ]);
  });
  // ── input_file content blocks ──────────────────────────────────────
  it('converts input_file blocks to text placeholders', () => {
    const input: InputItem[] = [
      {
        type: 'message',
        role: 'user',
        content: [{ type: 'input_file', filename: 'report.pdf', file_id: 'f_123' }],
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      { role: 'user', content: [{ type: 'text', text: '[File: report.pdf]' }] },
    ]);
  });
  it('uses "unknown" for input_file without filename', () => {
    const input: InputItem[] = [
      {
        type: 'message',
        role: 'user',
        content: [{ type: 'input_file', file_id: 'f_123' }],
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      { role: 'user', content: [{ type: 'text', text: '[File: unknown]' }] },
    ]);
  });
  // ── Null / undefined filtering ─────────────────────────────────────
  it('filters out null elements in content arrays', () => {
    const input = [
      {
        type: 'message',
        role: 'user',
        content: [null, { type: 'input_text', text: 'valid' }, undefined],
      },
    ] as unknown as InputItem[];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: 'valid' }] }]);
  });
  // ── Missing text field defaults to empty string ────────────────────
  it('defaults to empty string when text field is missing on input_text', () => {
    const input = [
      {
        type: 'message',
        role: 'user',
        content: [{ type: 'input_text' }],
      },
    ] as unknown as InputItem[];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: '' }] }]);
  });
  it('defaults to empty string when text field is missing on output_text', () => {
    const input = [
      {
        type: 'message',
        role: 'assistant',
        content: [{ type: 'output_text' }],
      },
    ] as unknown as InputItem[];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'assistant', content: [{ type: 'text', text: '' }] }]);
  });
  it('defaults to empty string when refusal field is missing on refusal block', () => {
    const input = [
      {
        type: 'message',
        role: 'assistant',
        content: [{ type: 'refusal' }],
      },
    ] as unknown as InputItem[];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'assistant', content: [{ type: 'text', text: '' }] }]);
  });
  // ── Unknown block types are filtered out ───────────────────────────
  it('filters out unknown content block types', () => {
    const input = [
      {
        type: 'message',
        role: 'user',
        content: [
          { type: 'input_text', text: 'keep me' },
          { type: 'some_future_type', data: 'ignore' },
        ],
      },
    ] as unknown as InputItem[];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: 'keep me' }] }]);
  });
  // ── Mixed valid/invalid content in same array ──────────────────────
  it('handles mixed valid and invalid content blocks', () => {
    const input = [
      {
        type: 'message',
        role: 'assistant',
        content: [
          { type: 'output_text', text: 'Hello', annotations: [], logprobs: [] },
          null,
          { type: 'unknown_type' },
          { type: 'refusal', refusal: 'No can do' },
        ],
      },
    ] as unknown as InputItem[];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      {
        role: 'assistant',
        content: [
          { type: 'text', text: 'Hello' },
          { type: 'text', text: 'No can do' },
        ],
      },
    ]);
  });
  // ── Non-array, non-string content defaults to empty string ─────────
  it('defaults to empty string for non-array non-string content', () => {
    const input = [{ type: 'message', role: 'user', content: 42 }] as unknown as InputItem[];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'user', content: '' }]);
  });
  // ── Function call items ────────────────────────────────────────────
  it('converts function_call items to assistant messages with tool_calls', () => {
    const input: InputItem[] = [
      {
        type: 'function_call',
        id: 'fc_1',
        call_id: 'call_abc',
        name: 'get_weather',
        arguments: '{"city":"NYC"}',
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      {
        role: 'assistant',
        content: '',
        tool_calls: [
          {
            id: 'call_abc',
            type: 'function',
            function: { name: 'get_weather', arguments: '{"city":"NYC"}' },
          },
        ],
      },
    ]);
  });
  // ── Function call output items ─────────────────────────────────────
  it('converts function_call_output items to tool messages', () => {
    const input: InputItem[] = [
      {
        type: 'function_call_output',
        call_id: 'call_abc',
        output: '{"temp":72}',
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      {
        role: 'tool',
        content: '{"temp":72}',
        tool_call_id: 'call_abc',
      },
    ]);
  });
  // ── Item references are skipped ────────────────────────────────────
  it('skips item_reference items', () => {
    const input: InputItem[] = [
      { type: 'item_reference', id: 'ref_123' },
      { type: 'message', role: 'user', content: 'Hello' },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([{ role: 'user', content: 'Hello' }]);
  });
  // ── Multi-turn conversation (the real-world scenario) ──────────────
  it('handles a full multi-turn conversation with output_text blocks', () => {
    const input: InputItem[] = [
      {
        type: 'message',
        role: 'developer',
        content: [{ type: 'input_text', text: 'You are a helpful assistant.' }],
      },
      {
        type: 'message',
        role: 'user',
        content: [{ type: 'input_text', text: 'What is 2+2?' }],
      },
      {
        type: 'message',
        role: 'assistant',
        content: [{ type: 'output_text', text: '2+2 is 4.', annotations: [], logprobs: [] }],
      },
      {
        type: 'message',
        role: 'user',
        content: [{ type: 'input_text', text: 'And 3+3?' }],
      },
    ];
    const result = convertInputToMessages(input);
    expect(result).toEqual([
      { role: 'system', content: [{ type: 'text', text: 'You are a helpful assistant.' }] },
      { role: 'user', content: [{ type: 'text', text: 'What is 2+2?' }] },
      { role: 'assistant', content: [{ type: 'text', text: '2+2 is 4.' }] },
      { role: 'user', content: [{ type: 'text', text: 'And 3+3?' }] },
    ]);
  });
});

View file

@ -6,11 +6,12 @@
*/
import type { Response as ServerResponse } from 'express';
import type {
ResponseRequest,
RequestValidationResult,
InputItem,
InputContent,
ResponseRequest,
ResponseContext,
InputContent,
ModelContent,
InputItem,
Response,
} from './types';
import {
@ -134,7 +135,7 @@ export function convertInputToMessages(input: string | InputItem[]): InternalMes
const messageItem = item as {
type: 'message';
role: string;
content: string | InputContent[];
content: string | (InputContent | ModelContent)[];
};
let content: InternalMessage['content'];
@ -142,21 +143,31 @@ export function convertInputToMessages(input: string | InputItem[]): InternalMes
if (typeof messageItem.content === 'string') {
content = messageItem.content;
} else if (Array.isArray(messageItem.content)) {
content = messageItem.content.map((part) => {
if (part.type === 'input_text') {
return { type: 'text', text: part.text };
}
if (part.type === 'input_image') {
return {
type: 'image_url',
image_url: {
url: (part as { image_url?: string }).image_url,
detail: (part as { detail?: string }).detail,
},
};
}
return { type: part.type };
});
content = messageItem.content
.filter((part): part is InputContent | ModelContent => part != null)
.map((part) => {
if (part.type === 'input_text' || part.type === 'output_text') {
return { type: 'text', text: (part as { text?: string }).text ?? '' };
}
if (part.type === 'refusal') {
return { type: 'text', text: (part as { refusal?: string }).refusal ?? '' };
}
if (part.type === 'input_image') {
return {
type: 'image_url',
image_url: {
url: (part as { image_url?: string }).image_url,
detail: (part as { detail?: string }).detail,
},
};
}
if (part.type === 'input_file') {
const filePart = part as { filename?: string };
return { type: 'text', text: `[File: ${filePart.filename ?? 'unknown'}]` };
}
return null;
})
.filter((part): part is NonNullable<typeof part> => part != null);
} else {
content = '';
}

View file

@ -1,7 +1,7 @@
import { getTransactionsConfig, getBalanceConfig } from './config';
import { getTransactionsConfig, getBalanceConfig, getCustomEndpointConfig } from './config';
import { logger } from '@librechat/data-schemas';
import { FileSources } from 'librechat-data-provider';
import type { TCustomConfig } from 'librechat-data-provider';
import { FileSources, EModelEndpoint } from 'librechat-data-provider';
import type { TCustomConfig, TEndpoint } from 'librechat-data-provider';
import type { AppConfig } from '@librechat/data-schemas';
// Helper function to create a minimal AppConfig for testing
@ -282,3 +282,75 @@ describe('getBalanceConfig', () => {
});
});
});
// Tests for getCustomEndpointConfig: looking up a custom endpoint's config by
// name via normalizeEndpointName on both the stored and queried names.
describe('getCustomEndpointConfig', () => {
  describe('when appConfig is not provided', () => {
    it('should throw an error', () => {
      expect(() => getCustomEndpointConfig({ endpoint: 'test' })).toThrow(
        'Config not found for the test custom endpoint.',
      );
    });
  });
  describe('when appConfig is provided', () => {
    it('should return undefined when no custom endpoints are configured', () => {
      const appConfig = createTestAppConfig();
      const result = getCustomEndpointConfig({ endpoint: 'test', appConfig });
      expect(result).toBeUndefined();
    });
    it('should return the matching endpoint config when found', () => {
      const appConfig = createTestAppConfig({
        endpoints: {
          [EModelEndpoint.custom]: [
            {
              name: 'TestEndpoint',
              apiKey: 'test-key',
            } as TEndpoint,
          ],
        },
      });
      const result = getCustomEndpointConfig({ endpoint: 'TestEndpoint', appConfig });
      expect(result).toEqual({
        name: 'TestEndpoint',
        apiKey: 'test-key',
      });
    });
    it('should handle case-insensitive matching for Ollama endpoint', () => {
      const appConfig = createTestAppConfig({
        endpoints: {
          [EModelEndpoint.custom]: [
            {
              name: 'Ollama',
              apiKey: 'ollama-key',
            } as TEndpoint,
          ],
        },
      });
      const result = getCustomEndpointConfig({ endpoint: 'Ollama', appConfig });
      expect(result).toEqual({
        name: 'Ollama',
        apiKey: 'ollama-key',
      });
    });
    it('should handle mixed case endpoint names', () => {
      const appConfig = createTestAppConfig({
        endpoints: {
          [EModelEndpoint.custom]: [
            {
              name: 'CustomAI',
              apiKey: 'custom-key',
            } as TEndpoint,
          ],
        },
      });
      // 'customai' vs 'CustomAI' yields no match: normalizeEndpointName does not
      // case-fold arbitrary names, so lookups remain case-sensitive here.
      const result = getCustomEndpointConfig({ endpoint: 'customai', appConfig });
      expect(result).toBeUndefined();
    });
  });
});

View file

@ -64,7 +64,7 @@ export const getCustomEndpointConfig = ({
const customEndpoints = appConfig.endpoints?.[EModelEndpoint.custom] ?? [];
return customEndpoints.find(
(endpointConfig) => normalizeEndpointName(endpointConfig.name) === endpoint,
(endpointConfig) => normalizeEndpointName(endpointConfig.name) === normalizeEndpointName(endpoint),
);
};

View file

@ -0,0 +1,113 @@
jest.mock('node:dns', () => {
const actual = jest.requireActual('node:dns');
return {
...actual,
lookup: jest.fn(),
};
});
import dns from 'node:dns';
import { createSSRFSafeAgents, createSSRFSafeUndiciConnect } from './agent';
type LookupCallback = (err: NodeJS.ErrnoException | null, address: string, family: number) => void;
const mockedDnsLookup = dns.lookup as jest.MockedFunction<typeof dns.lookup>;

/** Points the dns.lookup mock at an implementation that succeeds with `address`/`family`. */
function mockDnsResult(address: string, family: number): void {
  const succeed = (_hostname: string, _options: unknown, callback: LookupCallback): void => {
    callback(null, address, family);
  };
  mockedDnsLookup.mockImplementation(succeed as never);
}

/** Points the dns.lookup mock at an implementation that fails with `err`. */
function mockDnsError(err: NodeJS.ErrnoException): void {
  const fail = (_hostname: string, _options: unknown, callback: LookupCallback): void => {
    callback(err, '', 0);
  };
  mockedDnsLookup.mockImplementation(fail as never);
}
describe('createSSRFSafeAgents', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should return httpAgent and httpsAgent', () => {
const agents = createSSRFSafeAgents();
expect(agents.httpAgent).toBeDefined();
expect(agents.httpsAgent).toBeDefined();
});
it('should patch httpAgent createConnection to inject SSRF lookup', () => {
const agents = createSSRFSafeAgents();
const internal = agents.httpAgent as unknown as {
createConnection: (opts: Record<string, unknown>) => unknown;
};
expect(internal.createConnection).toBeInstanceOf(Function);
});
});
describe('createSSRFSafeUndiciConnect', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should return an object with a lookup function', () => {
const connect = createSSRFSafeUndiciConnect();
expect(connect).toHaveProperty('lookup');
expect(connect.lookup).toBeInstanceOf(Function);
});
it('lookup should block private IPs', async () => {
mockDnsResult('10.0.0.1', 4);
const connect = createSSRFSafeUndiciConnect();
const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => {
connect.lookup('evil.example.com', {}, (err) => {
resolve({ err });
});
});
expect(result.err).toBeTruthy();
expect(result.err!.code).toBe('ESSRF');
});
it('lookup should allow public IPs', async () => {
mockDnsResult('93.184.216.34', 4);
const connect = createSSRFSafeUndiciConnect();
const result = await new Promise<{ err: NodeJS.ErrnoException | null; address: string }>(
(resolve) => {
connect.lookup('example.com', {}, (err, address) => {
resolve({ err, address: address as string });
});
},
);
expect(result.err).toBeNull();
expect(result.address).toBe('93.184.216.34');
});
it('lookup should forward DNS errors', async () => {
const dnsError = Object.assign(new Error('ENOTFOUND'), {
code: 'ENOTFOUND',
}) as NodeJS.ErrnoException;
mockDnsError(dnsError);
const connect = createSSRFSafeUndiciConnect();
const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => {
connect.lookup('nonexistent.example.com', {}, (err) => {
resolve({ err });
});
});
expect(result.err).toBeTruthy();
expect(result.err!.code).toBe('ENOTFOUND');
});
});

View file

@ -0,0 +1,61 @@
import dns from 'node:dns';
import http from 'node:http';
import https from 'node:https';
import type { LookupFunction } from 'node:net';
import { isPrivateIP } from './domain';
/** DNS lookup wrapper that blocks resolution to private/reserved IP addresses */
const ssrfSafeLookup: LookupFunction = (hostname, options, callback) => {
  dns.lookup(hostname, options, (err, address, family) => {
    if (err) {
      callback(err, '', 0);
      return;
    }
    // `address` is a string for single-result lookups, or a LookupAddress[]
    // when callers pass `{ all: true }`. Screen every resolved address so the
    // array form cannot bypass SSRF protection.
    const resolved: string[] = Array.isArray(address)
      ? (address as Array<{ address: string }>).map((entry) => entry.address)
      : typeof address === 'string'
        ? [address]
        : [];
    const blocked = resolved.find((addr) => isPrivateIP(addr));
    if (blocked !== undefined) {
      const ssrfError = Object.assign(
        new Error(`SSRF protection: ${hostname} resolved to blocked address ${blocked}`),
        { code: 'ESSRF' },
      ) as NodeJS.ErrnoException;
      callback(ssrfError, address as string, family as number);
      return;
    }
    callback(null, address as string, family as number);
  });
};
/** Internal agent shape exposing createConnection (exists at runtime but not in TS types) */
type AgentInternal = {
  createConnection: (options: Record<string, unknown>, oncreate?: unknown) => unknown;
};

/**
 * Patches an agent instance to inject SSRF-safe DNS lookup at connect time.
 * The connect-time options object is shallow-copied rather than mutated, so
 * the injected `lookup` never leaks back to the caller's object and any
 * caller-supplied `lookup` is reliably overridden on our copy only.
 */
function withSSRFProtection<T extends http.Agent>(agent: T): T {
  const internal = agent as unknown as AgentInternal;
  const origCreate = internal.createConnection.bind(agent);
  internal.createConnection = (options: Record<string, unknown>, oncreate?: unknown) =>
    origCreate({ ...options, lookup: ssrfSafeLookup }, oncreate);
  return agent;
}
/**
 * Builds an HTTP and an HTTPS agent whose outgoing TCP connections refuse
 * private/reserved IP addresses. The resolved address is validated at connect
 * time, making the check TOCTOU-safe: a DNS-rebinding hostname that passed an
 * earlier public-IP pre-validation still cannot reach an internal address when
 * the actual connection is made.
 */
export function createSSRFSafeAgents(): { httpAgent: http.Agent; httpsAgent: https.Agent } {
  const httpAgent = withSSRFProtection(new http.Agent());
  const httpsAgent = withSSRFProtection(new https.Agent());
  return { httpAgent, httpsAgent };
}
/**
 * Builds undici-compatible `connect` options carrying the SSRF-safe DNS
 * lookup. Supply the returned object as the `connect` property when
 * constructing an undici `Agent`.
 */
export function createSSRFSafeUndiciConnect(): { lookup: LookupFunction } {
  const lookup: LookupFunction = ssrfSafeLookup;
  return { lookup };
}

View file

@ -1,12 +1,21 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
jest.mock('node:dns/promises', () => ({
lookup: jest.fn(),
}));
import { lookup } from 'node:dns/promises';
import {
extractMCPServerDomain,
isActionDomainAllowed,
isEmailDomainAllowed,
isMCPDomainAllowed,
isPrivateIP,
isSSRFTarget,
resolveHostnameSSRF,
} from './domain';
const mockedLookup = lookup as jest.MockedFunction<typeof lookup>;
describe('isEmailDomainAllowed', () => {
afterEach(() => {
jest.clearAllMocks();
@ -192,7 +201,154 @@ describe('isSSRFTarget', () => {
});
});
describe('isPrivateIP', () => {
describe('IPv4 private ranges', () => {
it('should detect loopback addresses', () => {
expect(isPrivateIP('127.0.0.1')).toBe(true);
expect(isPrivateIP('127.255.255.255')).toBe(true);
});
it('should detect 10.x.x.x private range', () => {
expect(isPrivateIP('10.0.0.1')).toBe(true);
expect(isPrivateIP('10.255.255.255')).toBe(true);
});
it('should detect 172.16-31.x.x private range', () => {
expect(isPrivateIP('172.16.0.1')).toBe(true);
expect(isPrivateIP('172.31.255.255')).toBe(true);
expect(isPrivateIP('172.15.0.1')).toBe(false);
expect(isPrivateIP('172.32.0.1')).toBe(false);
});
it('should detect 192.168.x.x private range', () => {
expect(isPrivateIP('192.168.0.1')).toBe(true);
expect(isPrivateIP('192.168.255.255')).toBe(true);
});
it('should detect 169.254.x.x link-local range', () => {
expect(isPrivateIP('169.254.169.254')).toBe(true);
expect(isPrivateIP('169.254.0.1')).toBe(true);
});
it('should detect 0.0.0.0', () => {
expect(isPrivateIP('0.0.0.0')).toBe(true);
});
it('should allow public IPs', () => {
expect(isPrivateIP('8.8.8.8')).toBe(false);
expect(isPrivateIP('1.1.1.1')).toBe(false);
expect(isPrivateIP('93.184.216.34')).toBe(false);
});
});
describe('IPv6 private ranges', () => {
it('should detect loopback', () => {
expect(isPrivateIP('::1')).toBe(true);
expect(isPrivateIP('::')).toBe(true);
expect(isPrivateIP('[::1]')).toBe(true);
});
it('should detect unique local (fc/fd) and link-local (fe80)', () => {
expect(isPrivateIP('fc00::1')).toBe(true);
expect(isPrivateIP('fd00::1')).toBe(true);
expect(isPrivateIP('fe80::1')).toBe(true);
});
});
describe('IPv4-mapped IPv6 addresses', () => {
it('should detect private IPs in IPv4-mapped IPv6 form', () => {
expect(isPrivateIP('::ffff:169.254.169.254')).toBe(true);
expect(isPrivateIP('::ffff:127.0.0.1')).toBe(true);
expect(isPrivateIP('::ffff:10.0.0.1')).toBe(true);
expect(isPrivateIP('::ffff:192.168.1.1')).toBe(true);
});
it('should allow public IPs in IPv4-mapped IPv6 form', () => {
expect(isPrivateIP('::ffff:8.8.8.8')).toBe(false);
expect(isPrivateIP('::ffff:93.184.216.34')).toBe(false);
});
});
});
describe('resolveHostnameSSRF', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should detect domains that resolve to private IPs (nip.io bypass)', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never);
expect(await resolveHostnameSSRF('169.254.169.254.nip.io')).toBe(true);
});
it('should detect domains that resolve to loopback', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '127.0.0.1', family: 4 }] as never);
expect(await resolveHostnameSSRF('loopback.example.com')).toBe(true);
});
it('should detect when any resolved address is private', async () => {
mockedLookup.mockResolvedValueOnce([
{ address: '93.184.216.34', family: 4 },
{ address: '10.0.0.1', family: 4 },
] as never);
expect(await resolveHostnameSSRF('dual.example.com')).toBe(true);
});
it('should allow domains that resolve to public IPs', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never);
expect(await resolveHostnameSSRF('example.com')).toBe(false);
});
it('should skip literal IPv4 addresses (handled by isSSRFTarget)', async () => {
expect(await resolveHostnameSSRF('169.254.169.254')).toBe(false);
expect(mockedLookup).not.toHaveBeenCalled();
});
it('should skip literal IPv6 addresses', async () => {
expect(await resolveHostnameSSRF('::1')).toBe(false);
expect(mockedLookup).not.toHaveBeenCalled();
});
it('should fail open on DNS resolution failure', async () => {
mockedLookup.mockRejectedValueOnce(new Error('ENOTFOUND'));
expect(await resolveHostnameSSRF('nonexistent.example.com')).toBe(false);
});
});
describe('isActionDomainAllowed - DNS resolution SSRF protection', () => {
afterEach(() => {
jest.clearAllMocks();
});
it('should block domains resolving to cloud metadata IP (169.254.169.254)', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never);
expect(await isActionDomainAllowed('169.254.169.254.nip.io', null)).toBe(false);
});
it('should block domains resolving to private 10.x range', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '10.0.0.5', family: 4 }] as never);
expect(await isActionDomainAllowed('internal.attacker.com', null)).toBe(false);
});
it('should block domains resolving to 172.16.x range', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '172.16.0.1', family: 4 }] as never);
expect(await isActionDomainAllowed('docker.attacker.com', null)).toBe(false);
});
it('should allow domains resolving to public IPs when no allowlist', async () => {
mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never);
expect(await isActionDomainAllowed('example.com', null)).toBe(true);
});
it('should not perform DNS check when allowedDomains is configured', async () => {
expect(await isActionDomainAllowed('example.com', ['example.com'])).toBe(true);
expect(mockedLookup).not.toHaveBeenCalled();
});
});
describe('isActionDomainAllowed', () => {
beforeEach(() => {
mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never);
});
afterEach(() => {
jest.clearAllMocks();
});
@ -541,6 +697,9 @@ describe('extractMCPServerDomain', () => {
});
describe('isMCPDomainAllowed', () => {
beforeEach(() => {
mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never);
});
afterEach(() => {
jest.clearAllMocks();
});

View file

@ -1,3 +1,5 @@
import { lookup } from 'node:dns/promises';
/**
* @param email
* @param allowedDomains
@ -22,6 +24,88 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] |
return allowedDomains.some((allowedDomain) => allowedDomain?.toLowerCase() === domain);
}
/** Checks if IPv4 octets fall within private, reserved, or link-local ranges */
function isPrivateIPv4(a: number, b: number, c: number): boolean {
  if (a === 127) {
    return true; // 127.0.0.0/8 - loopback
  }
  if (a === 10) {
    return true; // 10.0.0.0/8 - private
  }
  if (a === 172 && b >= 16 && b <= 31) {
    return true; // 172.16.0.0/12 - private
  }
  if (a === 192 && b === 168) {
    return true; // 192.168.0.0/16 - private
  }
  if (a === 169 && b === 254) {
    return true; // 169.254.0.0/16 - link-local (includes cloud metadata 169.254.169.254)
  }
  if (a === 0) {
    // 0.0.0.0/8 - "this network"; on Linux, 0.x addresses route to localhost,
    // so blocking only 0.0.0.x would leave a known SSRF bypass open.
    return true;
  }
  if (a === 100 && b >= 64 && b <= 127) {
    return true; // 100.64.0.0/10 - CGNAT shared address space (RFC 6598)
  }
  if (a >= 224) {
    return true; // 224.0.0.0/4 multicast + 240.0.0.0/4 reserved/broadcast
  }
  return false;
}
/**
 * Checks if an IP address belongs to a private, reserved, or link-local range.
 * Handles IPv4, IPv6, and IPv4-mapped IPv6 addresses (::ffff:A.B.C.D).
 */
export function isPrivateIP(ip: string): boolean {
  const normalized = ip.toLowerCase().trim();
  const mappedMatch = normalized.match(/^::ffff:(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
  if (mappedMatch) {
    const [, a, b, c] = mappedMatch.map(Number);
    return isPrivateIPv4(a, b, c);
  }
  const ipv4Match = normalized.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
  if (ipv4Match) {
    const [, a, b, c] = ipv4Match.map(Number);
    return isPrivateIPv4(a, b, c);
  }
  const ipv6 = normalized.replace(/^\[|\]$/g, ''); // strip URL-style brackets
  if (ipv6 === '::1' || ipv6 === '::') {
    return true; // loopback / unspecified
  }
  if (ipv6.startsWith('fc') || ipv6.startsWith('fd')) {
    return true; // fc00::/7 - unique local
  }
  // fe80::/10 spans fe80:: through febf:: — a bare startsWith('fe80') would
  // miss fe90/fea0/feb0 link-local addresses. Prefix heuristic can over-match
  // abbreviated first groups, which only errs toward blocking (fail closed).
  if (/^fe[89ab]/.test(ipv6)) {
    return true;
  }
  return false;
}
/**
 * Resolves a hostname via DNS and reports whether any resolved address is a
 * private/reserved IP. Catches DNS-based SSRF bypasses such as nip.io wildcard
 * DNS or attacker-controlled nameservers.
 * Fails open: resolution errors return false, since hostname-only checks still
 * apply and the actual HTTP request would fail anyway.
 */
export async function resolveHostnameSSRF(hostname: string): Promise<boolean> {
  const host = hostname.toLowerCase().trim();
  // Literal IPv4 addresses are screened elsewhere (isSSRFTarget) — skip DNS.
  if (/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/.test(host)) {
    return false;
  }
  // Literal IPv6 addresses (optionally URL-bracketed) are likewise skipped.
  const unbracketed = host.replace(/^\[|\]$/g, '');
  if (unbracketed.includes(':')) {
    return false;
  }
  try {
    const entries = await lookup(hostname, { all: true });
    for (const { address } of entries) {
      if (isPrivateIP(address)) {
        return true;
      }
    }
    return false;
  } catch {
    return false;
  }
}
/**
* SSRF Protection: Checks if a hostname/IP is a potentially dangerous internal target.
* Blocks private IPs, localhost, cloud metadata IPs, and common internal hostnames.
@ -31,7 +115,6 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] |
export function isSSRFTarget(hostname: string): boolean {
const normalizedHost = hostname.toLowerCase().trim();
// Block localhost variations
if (
normalizedHost === 'localhost' ||
normalizedHost === 'localhost.localdomain' ||
@ -40,51 +123,7 @@ export function isSSRFTarget(hostname: string): boolean {
return true;
}
// Check if it's an IP address and block private/internal ranges
const ipv4Match = normalizedHost.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
if (ipv4Match) {
const [, a, b, c] = ipv4Match.map(Number);
// 127.0.0.0/8 - Loopback
if (a === 127) {
return true;
}
// 10.0.0.0/8 - Private
if (a === 10) {
return true;
}
// 172.16.0.0/12 - Private (172.16.x.x - 172.31.x.x)
if (a === 172 && b >= 16 && b <= 31) {
return true;
}
// 192.168.0.0/16 - Private
if (a === 192 && b === 168) {
return true;
}
// 169.254.0.0/16 - Link-local (includes cloud metadata 169.254.169.254)
if (a === 169 && b === 254) {
return true;
}
// 0.0.0.0 - Special
if (a === 0 && b === 0 && c === 0) {
return true;
}
}
// IPv6 loopback and private ranges
const ipv6Normalized = normalizedHost.replace(/^\[|\]$/g, ''); // Remove brackets if present
if (
ipv6Normalized === '::1' ||
ipv6Normalized === '::' ||
ipv6Normalized.startsWith('fc') || // fc00::/7 - Unique local
ipv6Normalized.startsWith('fd') || // fd00::/8 - Unique local
ipv6Normalized.startsWith('fe80') // fe80::/10 - Link-local
) {
if (isPrivateIP(normalizedHost)) {
return true;
}
@ -257,6 +296,10 @@ async function isDomainAllowedCore(
if (isSSRFTarget(inputSpec.hostname)) {
return false;
}
/** SECURITY: Resolve hostname and block if it points to a private/reserved IP */
if (await resolveHostnameSSRF(inputSpec.hostname)) {
return false;
}
return true;
}

View file

@ -1,3 +1,4 @@
export * from './domain';
export * from './openid';
export * from './exchange';
export * from './agent';

View file

@ -215,16 +215,30 @@ describe('cacheConfig', () => {
}).rejects.toThrow('Invalid cache keys in FORCED_IN_MEMORY_CACHE_NAMESPACES: INVALID_KEY');
});
test('should handle empty string gracefully', async () => {
test('should produce empty array when set to empty string (opt out of defaults)', async () => {
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = '';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
});
test('should handle undefined env var gracefully', async () => {
test('should default to CONFIG_STORE and APP_CONFIG when env var is not set', async () => {
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual(['CONFIG_STORE', 'APP_CONFIG']);
});
test('should accept TOOL_CACHE as a valid namespace', async () => {
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = 'TOOL_CACHE';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual(['TOOL_CACHE']);
});
test('should accept CONFIG_STORE and APP_CONFIG together for blue/green deployments', async () => {
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = 'CONFIG_STORE,APP_CONFIG';
const { cacheConfig } = await import('../cacheConfig');
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual(['CONFIG_STORE', 'APP_CONFIG']);
});
});
});

View file

@ -0,0 +1,135 @@
import { CacheKeys } from 'librechat-data-provider';
const mockKeyvRedisInstance = {
namespace: '',
keyPrefixSeparator: '',
on: jest.fn(),
};
const MockKeyvRedis = jest.fn().mockReturnValue(mockKeyvRedisInstance);
jest.mock('@keyv/redis', () => ({
default: MockKeyvRedis,
}));
const mockKeyvRedisClient = { scanIterator: jest.fn() };
jest.mock('../../redisClients', () => ({
keyvRedisClient: mockKeyvRedisClient,
ioredisClient: null,
}));
jest.mock('../../redisUtils', () => ({
batchDeleteKeys: jest.fn(),
scanKeys: jest.fn(),
}));
jest.mock('@librechat/data-schemas', () => ({
logger: {
error: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
},
}));
describe('standardCache - CONFIG_STORE vs TOOL_CACHE namespace isolation', () => {
afterEach(() => {
jest.resetModules();
MockKeyvRedis.mockClear();
});
/**
* Core behavioral test for blue/green deployments:
* When CONFIG_STORE and APP_CONFIG are forced in-memory,
* TOOL_CACHE should still use Redis for cross-container sharing.
*/
it('should force CONFIG_STORE to in-memory while TOOL_CACHE uses Redis', async () => {
jest.doMock('../../cacheConfig', () => ({
cacheConfig: {
FORCED_IN_MEMORY_CACHE_NAMESPACES: [CacheKeys.CONFIG_STORE, CacheKeys.APP_CONFIG],
REDIS_KEY_PREFIX: '',
GLOBAL_PREFIX_SEPARATOR: '>>',
},
}));
const { standardCache } = await import('../../cacheFactory');
MockKeyvRedis.mockClear();
const configCache = standardCache(CacheKeys.CONFIG_STORE);
expect(MockKeyvRedis).not.toHaveBeenCalled();
expect(configCache).toBeDefined();
const appConfigCache = standardCache(CacheKeys.APP_CONFIG);
expect(MockKeyvRedis).not.toHaveBeenCalled();
expect(appConfigCache).toBeDefined();
const toolCache = standardCache(CacheKeys.TOOL_CACHE);
expect(MockKeyvRedis).toHaveBeenCalledTimes(1);
expect(MockKeyvRedis).toHaveBeenCalledWith(mockKeyvRedisClient);
expect(toolCache).toBeDefined();
});
it('CONFIG_STORE and TOOL_CACHE should be independent stores', async () => {
jest.doMock('../../cacheConfig', () => ({
cacheConfig: {
FORCED_IN_MEMORY_CACHE_NAMESPACES: [CacheKeys.CONFIG_STORE],
REDIS_KEY_PREFIX: '',
GLOBAL_PREFIX_SEPARATOR: '>>',
},
}));
const { standardCache } = await import('../../cacheFactory');
const configCache = standardCache(CacheKeys.CONFIG_STORE);
const toolCache = standardCache(CacheKeys.TOOL_CACHE);
await configCache.set('STARTUP_CONFIG', { version: 'v2-green' });
await toolCache.set('tools:global', { myTool: { type: 'function' } });
expect(await configCache.get('STARTUP_CONFIG')).toEqual({ version: 'v2-green' });
expect(await configCache.get('tools:global')).toBeUndefined();
expect(await toolCache.get('STARTUP_CONFIG')).toBeUndefined();
});
it('should use Redis for all namespaces when nothing is forced in-memory', async () => {
jest.doMock('../../cacheConfig', () => ({
cacheConfig: {
FORCED_IN_MEMORY_CACHE_NAMESPACES: [],
REDIS_KEY_PREFIX: '',
GLOBAL_PREFIX_SEPARATOR: '>>',
},
}));
const { standardCache } = await import('../../cacheFactory');
MockKeyvRedis.mockClear();
standardCache(CacheKeys.CONFIG_STORE);
standardCache(CacheKeys.TOOL_CACHE);
standardCache(CacheKeys.APP_CONFIG);
expect(MockKeyvRedis).toHaveBeenCalledTimes(3);
});
it('forcing TOOL_CACHE to in-memory should not affect CONFIG_STORE', async () => {
jest.doMock('../../cacheConfig', () => ({
cacheConfig: {
FORCED_IN_MEMORY_CACHE_NAMESPACES: [CacheKeys.TOOL_CACHE],
REDIS_KEY_PREFIX: '',
GLOBAL_PREFIX_SEPARATOR: '>>',
},
}));
const { standardCache } = await import('../../cacheFactory');
MockKeyvRedis.mockClear();
standardCache(CacheKeys.TOOL_CACHE);
expect(MockKeyvRedis).not.toHaveBeenCalled();
standardCache(CacheKeys.CONFIG_STORE);
expect(MockKeyvRedis).toHaveBeenCalledTimes(1);
});
});

View file

@ -20,6 +20,24 @@ interface ViolationData {
};
}
/** Waits for both Redis clients (ioredis + keyv/node-redis) to be ready */
async function waitForRedisClients() {
  const redisClients = await import('../../redisClients');
  const { ioredisClient, keyvRedisClientReady } = redisClients;
  // ioredis signals readiness via its 'status' field and a one-shot 'ready' event.
  if (ioredisClient != null && ioredisClient.status !== 'ready') {
    await new Promise<void>((resolve) => ioredisClient.once('ready', resolve));
  }
  // The keyv/node-redis client exposes a readiness promise instead of an event.
  if (keyvRedisClientReady) {
    await keyvRedisClientReady;
  }
  return redisClients;
}
describe('violationCache', () => {
let originalEnv: NodeJS.ProcessEnv;
@ -45,17 +63,9 @@ describe('violationCache', () => {
test('should create violation cache with Redis when USE_REDIS is true', async () => {
const cacheFactory = await import('../../cacheFactory');
const redisClients = await import('../../redisClients');
const { ioredisClient } = redisClients;
await waitForRedisClients();
const cache = cacheFactory.violationCache('test-violations', 60000); // 60 second TTL
// Wait for Redis connection to be ready
if (ioredisClient && ioredisClient.status !== 'ready') {
await new Promise<void>((resolve) => {
ioredisClient.once('ready', resolve);
});
}
// Verify it returns a Keyv instance
expect(cache).toBeDefined();
expect(cache.constructor.name).toBe('Keyv');
@ -112,18 +122,10 @@ describe('violationCache', () => {
test('should respect namespace prefixing', async () => {
const cacheFactory = await import('../../cacheFactory');
const redisClients = await import('../../redisClients');
const { ioredisClient } = redisClients;
await waitForRedisClients();
const cache1 = cacheFactory.violationCache('namespace1');
const cache2 = cacheFactory.violationCache('namespace2');
// Wait for Redis connection to be ready
if (ioredisClient && ioredisClient.status !== 'ready') {
await new Promise<void>((resolve) => {
ioredisClient.once('ready', resolve);
});
}
const testKey = 'shared-key';
const value1: ViolationData = { namespace: 1 };
const value2: ViolationData = { namespace: 2 };
@ -146,18 +148,10 @@ describe('violationCache', () => {
test('should respect TTL settings', async () => {
const cacheFactory = await import('../../cacheFactory');
const redisClients = await import('../../redisClients');
const { ioredisClient } = redisClients;
await waitForRedisClients();
const ttl = 1000; // 1 second TTL
const cache = cacheFactory.violationCache('ttl-test', ttl);
// Wait for Redis connection to be ready
if (ioredisClient && ioredisClient.status !== 'ready') {
await new Promise<void>((resolve) => {
ioredisClient.once('ready', resolve);
});
}
const testKey = 'ttl-key';
const testValue: ViolationData = { data: 'expires soon' };
@ -178,17 +172,9 @@ describe('violationCache', () => {
test('should handle complex violation data structures', async () => {
const cacheFactory = await import('../../cacheFactory');
const redisClients = await import('../../redisClients');
const { ioredisClient } = redisClients;
await waitForRedisClients();
const cache = cacheFactory.violationCache('complex-violations');
// Wait for Redis connection to be ready
if (ioredisClient && ioredisClient.status !== 'ready') {
await new Promise<void>((resolve) => {
ioredisClient.once('ready', resolve);
});
}
const complexData: ViolationData = {
userId: 'user123',
violations: [

View file

@ -27,9 +27,14 @@ const USE_REDIS_STREAMS =
// Comma-separated list of cache namespaces that should be forced to use in-memory storage
// even when Redis is enabled. This allows selective performance optimization for specific caches.
const FORCED_IN_MEMORY_CACHE_NAMESPACES = process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES
? process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES.split(',').map((key) => key.trim())
: [];
// Defaults to CONFIG_STORE,APP_CONFIG so YAML-derived config stays per-container.
// Set to empty string to force all namespaces through Redis.
const FORCED_IN_MEMORY_CACHE_NAMESPACES =
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES !== undefined
? process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES.split(',')
.map((key) => key.trim())
.filter(Boolean)
: [CacheKeys.CONFIG_STORE, CacheKeys.APP_CONFIG];
// Validate against CacheKeys enum
if (FORCED_IN_MEMORY_CACHE_NAMESPACES.length > 0) {

View file

@ -120,7 +120,9 @@ export const limiterCache = (prefix: string): RedisStore | undefined => {
if (!cacheConfig.USE_REDIS) {
return undefined;
}
// TODO: The prefix is not actually applied. Also needs to account for global prefix.
// Note: The `prefix` is applied by RedisStore internally to its key operations.
// The global REDIS_KEY_PREFIX is applied by ioredisClient's keyPrefix setting.
// Combined key format: `{REDIS_KEY_PREFIX}::{prefix}{identifier}`
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
try {

View file

@ -29,7 +29,9 @@ if (cacheConfig.USE_REDIS) {
);
return null;
}
const delay = Math.min(times * 50, cacheConfig.REDIS_RETRY_MAX_DELAY);
const base = Math.min(Math.pow(2, times) * 50, cacheConfig.REDIS_RETRY_MAX_DELAY);
const jitter = Math.floor(Math.random() * Math.min(base, 1000));
const delay = Math.min(base + jitter, cacheConfig.REDIS_RETRY_MAX_DELAY);
logger.info(`ioredis reconnecting... attempt ${times}, delay ${delay}ms`);
return delay;
},
@ -71,7 +73,9 @@ if (cacheConfig.USE_REDIS) {
);
return null;
}
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
const base = Math.min(Math.pow(2, times) * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
const jitter = Math.floor(Math.random() * Math.min(base, 1000));
const delay = Math.min(base + jitter, cacheConfig.REDIS_RETRY_MAX_DELAY);
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
return delay;
},
@ -149,7 +153,9 @@ if (cacheConfig.USE_REDIS) {
);
return new Error('Max reconnection attempts reached');
}
const delay = Math.min(retries * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
const base = Math.min(Math.pow(2, retries) * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
const jitter = Math.floor(Math.random() * Math.min(base, 1000));
const delay = Math.min(base + jitter, cacheConfig.REDIS_RETRY_MAX_DELAY);
logger.info(`@keyv/redis reconnecting... attempt ${retries}, delay ${delay}ms`);
return delay;
},

View file

@ -65,7 +65,7 @@ function getClaudeHeaders(
/**
* Configures reasoning-related options for Claude models.
* Models supporting adaptive thinking (Opus 4.6+, Sonnet 5+) use effort control instead of manual budget_tokens.
* Models supporting adaptive thinking (Opus 4.6+, Sonnet 4.6+) use effort control instead of manual budget_tokens.
*/
function configureReasoning(
anthropicInput: AnthropicClientOptions & { max_tokens?: number },

View file

@ -121,6 +121,39 @@ describe('getLLMConfig', () => {
});
});
it('should add "context-1m" beta header for claude-sonnet-4-6 model', () => {
const modelOptions = {
model: 'claude-sonnet-4-6',
promptCache: true,
};
const result = getLLMConfig('test-key', { modelOptions });
const clientOptions = result.llmConfig.clientOptions;
expect(clientOptions?.defaultHeaders).toBeDefined();
expect(clientOptions?.defaultHeaders).toHaveProperty('anthropic-beta');
const defaultHeaders = clientOptions?.defaultHeaders as Record<string, string>;
expect(defaultHeaders['anthropic-beta']).toBe('context-1m-2025-08-07');
expect(result.llmConfig.promptCache).toBe(true);
});
it('should add "context-1m" beta header for claude-sonnet-4-6 model formats', () => {
const modelVariations = [
'claude-sonnet-4-6',
'claude-sonnet-4-6-20260101',
'anthropic/claude-sonnet-4-6',
];
modelVariations.forEach((model) => {
const modelOptions = { model, promptCache: true };
const result = getLLMConfig('test-key', { modelOptions });
const clientOptions = result.llmConfig.clientOptions;
expect(clientOptions?.defaultHeaders).toBeDefined();
expect(clientOptions?.defaultHeaders).toHaveProperty('anthropic-beta');
const defaultHeaders = clientOptions?.defaultHeaders as Record<string, string>;
expect(defaultHeaders['anthropic-beta']).toBe('context-1m-2025-08-07');
expect(result.llmConfig.promptCache).toBe(true);
});
});
it('should pass promptCache boolean for claude-opus-4-5 model (no beta header needed)', () => {
const modelOptions = {
model: 'claude-opus-4-5',
@ -963,6 +996,51 @@ describe('getLLMConfig', () => {
});
});
it('should use adaptive thinking for Sonnet 4.6 instead of enabled + budget_tokens', () => {
const result = getLLMConfig('test-key', {
modelOptions: {
model: 'claude-sonnet-4-6',
thinking: true,
thinkingBudget: 10000,
},
});
expect((result.llmConfig.thinking as unknown as { type: string }).type).toBe('adaptive');
expect(result.llmConfig.thinking).not.toHaveProperty('budget_tokens');
expect(result.llmConfig.maxTokens).toBe(64000);
});
it('should set effort via output_config for Sonnet 4.6', () => {
const result = getLLMConfig('test-key', {
modelOptions: {
model: 'claude-sonnet-4-6',
thinking: true,
effort: AnthropicEffort.high,
},
});
expect((result.llmConfig.thinking as unknown as { type: string }).type).toBe('adaptive');
expect(result.llmConfig.invocationKwargs).toHaveProperty('output_config');
expect(result.llmConfig.invocationKwargs?.output_config).toEqual({
effort: AnthropicEffort.high,
});
});
it('should exclude topP/topK for Sonnet 4.6 with adaptive thinking', () => {
const result = getLLMConfig('test-key', {
modelOptions: {
model: 'claude-sonnet-4-6',
thinking: true,
topP: 0.9,
topK: 40,
},
});
expect((result.llmConfig.thinking as unknown as { type: string }).type).toBe('adaptive');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).not.toHaveProperty('topK');
});
it('should NOT set adaptive thinking or effort for non-adaptive models', () => {
const nonAdaptiveModels = [
'claude-opus-4-5',

View file

@ -5,5 +5,6 @@ export * from './filter';
export * from './mistral/crud';
export * from './ocr';
export * from './parse';
export * from './rag';
export * from './validation';
export * from './text';

View file

@ -0,0 +1,150 @@
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
},
}));
jest.mock('~/crypto/jwt', () => ({
generateShortLivedToken: jest.fn().mockReturnValue('mock-jwt-token'),
}));
jest.mock('axios', () => ({
delete: jest.fn(),
interceptors: {
request: { use: jest.fn(), eject: jest.fn() },
response: { use: jest.fn(), eject: jest.fn() },
},
}));
import axios from 'axios';
import { deleteRagFile } from './rag';
import { logger } from '@librechat/data-schemas';
import { generateShortLivedToken } from '~/crypto/jwt';
const mockedAxios = axios as jest.Mocked<typeof axios>;
const mockedLogger = logger as jest.Mocked<typeof logger>;
const mockedGenerateShortLivedToken = generateShortLivedToken as jest.MockedFunction<
typeof generateShortLivedToken
>;
describe('deleteRagFile', () => {
  const originalEnv = process.env;

  /** Builds the minimal file shape accepted by deleteRagFile. */
  const makeFile = (file_id: string, embedded?: boolean) =>
    embedded === undefined ? { file_id } : { file_id, embedded };

  beforeEach(() => {
    jest.clearAllMocks();
    // Work on a copy of the env so per-test mutations don't leak.
    process.env = { ...originalEnv };
    process.env.RAG_API_URL = 'http://localhost:8000';
  });

  afterEach(() => {
    process.env = originalEnv;
  });

  describe('when file is embedded and RAG_API_URL is configured', () => {
    it('should delete the document from RAG API successfully', async () => {
      mockedAxios.delete.mockResolvedValueOnce({ status: 200 });

      const outcome = await deleteRagFile({ userId: 'user123', file: makeFile('file-123', true) });

      expect(outcome).toBe(true);
      expect(mockedGenerateShortLivedToken).toHaveBeenCalledWith('user123');
      expect(mockedAxios.delete).toHaveBeenCalledWith('http://localhost:8000/documents', {
        headers: {
          Authorization: 'Bearer mock-jwt-token',
          'Content-Type': 'application/json',
          accept: 'application/json',
        },
        data: ['file-123'],
      });
      expect(mockedLogger.debug).toHaveBeenCalledWith(
        '[deleteRagFile] Successfully deleted document file-123 from RAG API',
      );
    });

    it('should return true and log warning when document is not found (404)', async () => {
      const notFound = Object.assign(new Error('Not Found'), { response: { status: 404 } });
      mockedAxios.delete.mockRejectedValueOnce(notFound);

      const outcome = await deleteRagFile({
        userId: 'user123',
        file: makeFile('file-not-found', true),
      });

      // 404 is treated as "already gone" — still a success.
      expect(outcome).toBe(true);
      expect(mockedLogger.warn).toHaveBeenCalledWith(
        '[deleteRagFile] Document file-not-found not found in RAG API, may have been deleted already',
      );
    });

    it('should return false and log error on other errors', async () => {
      const serverError = Object.assign(new Error('Server Error'), { response: { status: 500 } });
      mockedAxios.delete.mockRejectedValueOnce(serverError);

      const outcome = await deleteRagFile({ userId: 'user123', file: makeFile('file-error', true) });

      expect(outcome).toBe(false);
      expect(mockedLogger.error).toHaveBeenCalledWith(
        '[deleteRagFile] Error deleting document from RAG API:',
        'Server Error',
      );
    });
  });

  describe('when file is not embedded', () => {
    it('should skip RAG deletion and return true', async () => {
      const outcome = await deleteRagFile({ userId: 'user123', file: makeFile('file-123', false) });

      expect(outcome).toBe(true);
      expect(mockedAxios.delete).not.toHaveBeenCalled();
      expect(mockedGenerateShortLivedToken).not.toHaveBeenCalled();
    });

    it('should skip RAG deletion when embedded is undefined', async () => {
      const outcome = await deleteRagFile({ userId: 'user123', file: makeFile('file-123') });

      expect(outcome).toBe(true);
      expect(mockedAxios.delete).not.toHaveBeenCalled();
    });
  });

  describe('when RAG_API_URL is not configured', () => {
    it('should skip RAG deletion and return true', async () => {
      delete process.env.RAG_API_URL;

      const outcome = await deleteRagFile({ userId: 'user123', file: makeFile('file-123', true) });

      expect(outcome).toBe(true);
      expect(mockedAxios.delete).not.toHaveBeenCalled();
    });
  });

  describe('userId handling', () => {
    it('should return false when no userId is provided', async () => {
      const outcome = await deleteRagFile({ userId: '', file: makeFile('file-123', true) });

      expect(outcome).toBe(false);
      expect(mockedLogger.error).toHaveBeenCalledWith('[deleteRagFile] No user ID provided');
      expect(mockedAxios.delete).not.toHaveBeenCalled();
    });

    it('should return false when userId is undefined', async () => {
      const outcome = await deleteRagFile({
        userId: undefined as unknown as string,
        file: makeFile('file-123', true),
      });

      expect(outcome).toBe(false);
      expect(mockedLogger.error).toHaveBeenCalledWith('[deleteRagFile] No user ID provided');
    });
  });
});

View file

@ -0,0 +1,60 @@
import axios from 'axios';
import { logger } from '@librechat/data-schemas';
import { generateShortLivedToken } from '~/crypto/jwt';
interface DeleteRagFileParams {
  /** The user ID. Required for authentication. If not provided, the function returns false and logs an error. */
  userId: string;
  /** The file object. Must have `embedded` and `file_id` properties. */
  file: {
    /** RAG-side document identifier; sent as the DELETE request payload. */
    file_id: string;
    /** Whether embeddings exist for this file; deletion is skipped when falsy. */
    embedded?: boolean;
  };
}
/**
 * Deletes embedded document(s) from the RAG API.
 * This is a shared utility function used by all file storage strategies
 * (S3, Azure, Firebase, Local) to delete RAG embeddings when a file is deleted.
 *
 * @param params - The parameters object.
 * @param params.userId - The user ID for authentication.
 * @param params.file - The file object. Must have `embedded` and `file_id` properties.
 * @returns Returns true if deletion was successful or skipped, false if there was an error.
 */
export async function deleteRagFile({ userId, file }: DeleteRagFileParams): Promise<boolean> {
  const ragApiUrl = process.env.RAG_API_URL;

  // Nothing to do when the file has no embeddings or no RAG API is configured.
  if (!file.embedded || !ragApiUrl) {
    return true;
  }
  if (!userId) {
    logger.error('[deleteRagFile] No user ID provided');
    return false;
  }

  // Token generation stays outside the try block so an auth failure
  // propagates instead of being reported as a RAG API error.
  const jwtToken = generateShortLivedToken(userId);
  try {
    await axios.delete(`${ragApiUrl}/documents`, {
      headers: {
        Authorization: `Bearer ${jwtToken}`,
        'Content-Type': 'application/json',
        accept: 'application/json',
      },
      data: [file.file_id],
    });
    logger.debug(`[deleteRagFile] Successfully deleted document ${file.file_id} from RAG API`);
    return true;
  } catch (error) {
    const { response, message } = error as { response?: { status?: number }; message?: string };
    // A 404 means the document is already gone — treat as success.
    if (response?.status === 404) {
      logger.warn(
        `[deleteRagFile] Document ${file.file_id} not found in RAG API, may have been deleted already`,
      );
      return true;
    }
    logger.error('[deleteRagFile] Error deleting document from RAG API:', message);
    return false;
  }
}

View file

@ -4,6 +4,8 @@ import { MCPConnection } from './connection';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
import type * as t from './types';
const CONNECT_CONCURRENCY = 3;
/**
* Manages MCP connections with lazy loading and reconnection.
* Maintains a pool of connections and handles connection lifecycle management.
@ -73,6 +75,7 @@ export class ConnectionsRepository {
{
serverName,
serverConfig,
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
},
this.oauthOpts,
);
@ -83,9 +86,17 @@ export class ConnectionsRepository {
/** Gets or creates connections for multiple servers concurrently */
async getMany(serverNames: string[]): Promise<Map<string, MCPConnection>> {
const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]);
const connections = await Promise.all(connectionPromises);
return new Map((connections as [string, MCPConnection][]).filter((v) => !!v[1]));
const results: [string, MCPConnection | null][] = [];
for (let i = 0; i < serverNames.length; i += CONNECT_CONCURRENCY) {
const batch = serverNames.slice(i, i + CONNECT_CONCURRENCY);
const batchResults = await Promise.all(
batch.map(
async (name): Promise<[string, MCPConnection | null]> => [name, await this.get(name)],
),
);
results.push(...batchResults);
}
return new Map(results.filter((v): v is [string, MCPConnection] => v[1] != null));
}
/** Returns all currently loaded connections without creating new ones */

View file

@ -29,6 +29,7 @@ export class MCPConnectionFactory {
protected readonly serverConfig: t.MCPOptions;
protected readonly logPrefix: string;
protected readonly useOAuth: boolean;
protected readonly useSSRFProtection: boolean;
// OAuth-related properties (only set when useOAuth is true)
protected readonly userId?: string;
@ -72,6 +73,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
useSSRFProtection: this.useSSRFProtection,
});
const oauthHandler = async () => {
@ -146,6 +148,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens: null,
useSSRFProtection: this.useSSRFProtection,
});
unauthConnection.on('oauthRequired', () => {
@ -189,6 +192,7 @@ export class MCPConnectionFactory {
});
this.serverName = basic.serverName;
this.useOAuth = !!oauth?.useOAuth;
this.useSSRFProtection = basic.useSSRFProtection === true;
this.connectionTimeout = oauth?.connectionTimeout;
this.logPrefix = oauth?.user
? `[MCP][${basic.serverName}][${oauth.user.id}]`
@ -213,6 +217,7 @@ export class MCPConnectionFactory {
serverConfig: this.serverConfig,
userId: this.userId,
oauthTokens,
useSSRFProtection: this.useSSRFProtection,
});
let cleanupOAuthHandlers: (() => void) | null = null;
@ -293,38 +298,45 @@ export class MCPConnectionFactory {
const oauthHandler = async (data: { serverUrl?: string }) => {
logger.info(`${this.logPrefix} oauthRequired event received`);
// If we just want to initiate OAuth and return, handle it differently
if (this.returnOnOAuth) {
try {
const config = this.serverConfig;
const { authorizationUrl, flowId, flowMetadata } =
await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
data.serverUrl || '',
this.userId!,
config?.oauth_headers ?? {},
config?.oauth,
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
if (existingFlow?.status === 'PENDING') {
logger.debug(
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
);
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
}
// Delete any existing flow state to ensure we start fresh
// This prevents stale codeVerifier issues when re-authenticating
await this.flowManager!.deleteFlow(flowId, 'mcp_oauth');
const {
authorizationUrl,
flowId: newFlowId,
flowMetadata,
} = await MCPOAuthHandler.initiateOAuthFlow(
this.serverName,
data.serverUrl || '',
this.userId!,
config?.oauth_headers ?? {},
config?.oauth,
);
// Create the flow state so the OAuth callback can find it
// We spawn this in the background without waiting for it
// Pass signal so the flow can be aborted if the request is cancelled
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
// The OAuth callback will resolve this flow, so we expect it to timeout here
// or it will be aborted if the request is cancelled - both are fine
});
if (existingFlow) {
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
}
this.flowManager!.createFlow(newFlowId, 'mcp_oauth', flowMetadata, this.signal).catch(
() => {},
);
if (this.oauthStart) {
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
await this.oauthStart(authorizationUrl);
}
// Emit oauthFailed to signal that connection should not proceed
// but OAuth was successfully initiated
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
return;
} catch (error) {
@ -386,11 +398,9 @@ export class MCPConnectionFactory {
logger.error(`${this.logPrefix} Failed to establish connection.`);
}
// Handles connection attempts with retry logic and OAuth error handling
private async connectTo(connection: MCPConnection): Promise<void> {
const maxAttempts = 3;
let attempts = 0;
let oauthHandled = false;
while (attempts < maxAttempts) {
try {
@ -403,22 +413,6 @@ export class MCPConnectionFactory {
attempts++;
if (this.useOAuth && this.isOAuthError(error)) {
// For returnOnOAuth mode, let the event handler (handleOAuthEvents) deal with OAuth
// We just need to stop retrying and let the error propagate
if (this.returnOnOAuth) {
logger.info(
`${this.logPrefix} OAuth required (return on OAuth mode), stopping retries`,
);
throw error;
}
// Normal flow - wait for OAuth to complete
if (this.oauthStart && !oauthHandled) {
oauthHandled = true;
logger.info(`${this.logPrefix} Handling OAuth`);
await this.handleOAuthRequired();
}
// Don't retry on OAuth errors - just throw
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
throw error;
}
@ -494,26 +488,15 @@ export class MCPConnectionFactory {
/** Check if there's already an ongoing OAuth flow for this flowId */
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
// If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
if (existingFlow) {
logger.debug(
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
);
try {
if (existingFlow.status === 'PENDING') {
await this.flowManager.failFlow(
flowId,
'mcp_oauth',
new Error('Cancelled for new OAuth request'),
);
} else {
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
}
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
} catch (error) {
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
}
// Continue to start a new flow below
}
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);

View file

@ -102,7 +102,8 @@ export class MCPManager extends UserConnectionManager {
serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata,
);
const basic: t.BasicConnectionOptions = { serverName, serverConfig };
const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection();
const basic: t.BasicConnectionOptions = { serverName, serverConfig, useSSRFProtection };
if (!useOAuth) {
const result = await MCPConnectionFactory.discoverTools(basic);

View file

@ -117,6 +117,7 @@ export abstract class UserConnectionManager {
{
serverName: serverName,
serverConfig: config,
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
},
{
useOAuth: true,

View file

@ -24,6 +24,7 @@ jest.mock('../connection');
const mockRegistryInstance = {
getServerConfig: jest.fn(),
getAllServerConfigs: jest.fn(),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
};
jest.mock('../registry/MCPServersRegistry', () => ({
@ -108,6 +109,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
useSSRFProtection: false,
},
undefined,
);
@ -129,6 +131,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: mockServerConfigs.server1,
useSSRFProtection: false,
},
undefined,
);
@ -167,6 +170,7 @@ describe('ConnectionsRepository', () => {
{
serverName: 'server1',
serverConfig: configWithCachedAt,
useSSRFProtection: false,
},
undefined,
);

View file

@ -84,6 +84,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: undefined,
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockConnectionInstance.connect).toHaveBeenCalled();
});
@ -125,6 +126,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: mockTokens,
useSSRFProtection: false,
});
});
});
@ -184,6 +186,7 @@ describe('MCPConnectionFactory', () => {
serverConfig: mockServerConfig,
userId: 'user123',
oauthTokens: null,
useSSRFProtection: false,
});
expect(mockLogger.debug).toHaveBeenCalledWith(
expect.stringContaining('No existing tokens found or error loading tokens'),
@ -267,7 +270,54 @@ describe('MCPConnectionFactory', () => {
);
});
it('should delete existing flow before creating new OAuth flow to prevent stale codeVerifier', async () => {
it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
user: mockUser,
};
const oauthOptions: t.OAuthConnectionOptions = {
user: mockUser,
useOAuth: true,
returnOnOAuth: true,
oauthStart: jest.fn(),
flowManager: mockFlowManager,
};
mockFlowManager.getFlowState.mockResolvedValue({
status: 'PENDING',
type: 'mcp_oauth',
metadata: { codeVerifier: 'existing-verifier' },
createdAt: Date.now(),
});
mockConnectionInstance.isConnected.mockResolvedValue(false);
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
mockConnectionInstance.on.mockImplementation((event, handler) => {
if (event === 'oauthRequired') {
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
}
return mockConnectionInstance;
});
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled();
expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled();
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
'oauthFailed',
expect.objectContaining({ message: 'OAuth flow initiated - return early' }),
);
});
it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => {
const basicOptions = {
serverName: 'test-server',
serverConfig: mockServerConfig,
@ -300,6 +350,12 @@ describe('MCPConnectionFactory', () => {
},
};
mockFlowManager.getFlowState.mockResolvedValue({
status: 'COMPLETED',
type: 'mcp_oauth',
metadata: { codeVerifier: 'old-verifier' },
createdAt: Date.now() - 60000,
});
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
mockFlowManager.deleteFlow.mockResolvedValue(true);
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
@ -316,21 +372,17 @@ describe('MCPConnectionFactory', () => {
try {
await MCPConnectionFactory.create(basicOptions, oauthOptions);
} catch {
// Expected to fail due to connection not established
// Expected to fail
}
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
// Verify deleteFlow was called with correct parameters
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth');
// Verify deleteFlow was called before createFlow
const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0];
const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
expect(deleteCallOrder).toBeLessThan(createCallOrder);
// Verify createFlow was called with fresh metadata
// 4th arg is the abort signal (undefined in this test since no signal was provided)
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
'user123:test-server',
'mcp_oauth',

View file

@ -33,6 +33,7 @@ const mockRegistryInstance = {
getServerConfig: jest.fn(),
getAllServerConfigs: jest.fn(),
getOAuthServers: jest.fn(),
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
};
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({

View file

@ -2,7 +2,12 @@
// zod.spec.ts
import { z } from 'zod';
import type { JsonSchemaType } from '@librechat/data-schemas';
import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from '../zod';
import {
convertWithResolvedRefs,
convertJsonSchemaToZod,
resolveJsonSchemaRefs,
normalizeJsonSchema,
} from '../zod';
describe('convertJsonSchemaToZod', () => {
describe('integer type handling', () => {
@ -206,7 +211,7 @@ describe('convertJsonSchemaToZod', () => {
type: 'number' as const,
enum: [1, 2, 3, 5, 8, 13],
};
const zodSchema = convertWithResolvedRefs(schema as JsonSchemaType);
const zodSchema = convertWithResolvedRefs(schema as unknown as JsonSchemaType);
expect(zodSchema?.parse(1)).toBe(1);
expect(zodSchema?.parse(13)).toBe(13);
@ -1599,6 +1604,34 @@ describe('convertJsonSchemaToZod', () => {
expect(() => zodSchema?.parse(testData)).not.toThrow();
});
it('should strip $defs from the resolved output', () => {
  const input = {
    type: 'object' as const,
    properties: {
      item: { $ref: '#/$defs/Item' },
    },
    $defs: {
      Item: {
        type: 'object' as const,
        properties: {
          name: { type: 'string' as const },
        },
      },
    },
  };

  const output = resolveJsonSchemaRefs(input);

  // Resolution inlines the referenced definition and drops the $defs
  // container, which only existed to support resolution.
  expect(output).not.toHaveProperty('$defs');
  expect(output.properties?.item).toEqual({
    type: 'object',
    properties: {
      name: { type: 'string' },
    },
  });
});
it('should handle various edge cases safely', () => {
// Test with null/undefined
expect(resolveJsonSchemaRefs(null as any)).toBeNull();
@ -2002,3 +2035,329 @@ describe('convertJsonSchemaToZod', () => {
});
});
});
/**
 * Contract tests for normalizeJsonSchema: converting `const` to `enum`,
 * recursing through nested schema positions (properties, items,
 * additionalProperties, oneOf/anyOf/allOf), stripping vendor `x-*`
 * extension keys, and removing `$defs`/`definitions` containers.
 * The exact expected objects below are load-bearing — they pin the
 * normalized output shape consumed downstream.
 */
describe('normalizeJsonSchema', () => {
  // --- const → enum conversion -------------------------------------------
  it('should convert const to enum', () => {
    const schema = { type: 'string', const: 'hello' } as any;
    const result = normalizeJsonSchema(schema);
    expect(result).toEqual({ type: 'string', enum: ['hello'] });
    expect(result).not.toHaveProperty('const');
  });
  it('should preserve existing enum when const is also present', () => {
    const schema = { type: 'string', const: 'hello', enum: ['hello', 'world'] } as any;
    const result = normalizeJsonSchema(schema);
    expect(result).toEqual({ type: 'string', enum: ['hello', 'world'] });
    expect(result).not.toHaveProperty('const');
  });
  it('should handle non-string const values (number, boolean, null)', () => {
    expect(normalizeJsonSchema({ type: 'number', const: 42 } as any)).toEqual({
      type: 'number',
      enum: [42],
    });
    expect(normalizeJsonSchema({ type: 'boolean', const: true } as any)).toEqual({
      type: 'boolean',
      enum: [true],
    });
    expect(normalizeJsonSchema({ type: 'string', const: null } as any)).toEqual({
      type: 'string',
      enum: [null],
    });
  });
  // --- recursion into nested schema positions ----------------------------
  it('should recursively normalize nested object properties', () => {
    const schema = {
      type: 'object',
      properties: {
        mode: { type: 'string', const: 'advanced' },
        count: { type: 'number', const: 5 },
        name: { type: 'string', description: 'A name' },
      },
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result.properties.mode).toEqual({ type: 'string', enum: ['advanced'] });
    expect(result.properties.count).toEqual({ type: 'number', enum: [5] });
    expect(result.properties.name).toEqual({ type: 'string', description: 'A name' });
  });
  it('should normalize inside oneOf/anyOf/allOf arrays', () => {
    const schema = {
      type: 'object',
      oneOf: [
        { type: 'object', properties: { kind: { type: 'string', const: 'A' } } },
        { type: 'object', properties: { kind: { type: 'string', const: 'B' } } },
      ],
      anyOf: [{ type: 'string', const: 'x' }],
      allOf: [{ type: 'number', const: 1 }],
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result.oneOf[0].properties.kind).toEqual({ type: 'string', enum: ['A'] });
    expect(result.oneOf[1].properties.kind).toEqual({ type: 'string', enum: ['B'] });
    expect(result.anyOf[0]).toEqual({ type: 'string', enum: ['x'] });
    expect(result.allOf[0]).toEqual({ type: 'number', enum: [1] });
  });
  it('should normalize array items with const', () => {
    const schema = {
      type: 'array',
      items: { type: 'string', const: 'fixed' },
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result.items).toEqual({ type: 'string', enum: ['fixed'] });
  });
  it('should normalize additionalProperties with const', () => {
    const schema = {
      type: 'object',
      additionalProperties: { type: 'string', const: 'val' },
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result.additionalProperties).toEqual({ type: 'string', enum: ['val'] });
  });
  // --- degenerate inputs and no-op behavior ------------------------------
  it('should handle null, undefined, and primitive inputs safely', () => {
    expect(normalizeJsonSchema(null as any)).toBeNull();
    expect(normalizeJsonSchema(undefined as any)).toBeUndefined();
    expect(normalizeJsonSchema('string' as any)).toBe('string');
    expect(normalizeJsonSchema(42 as any)).toBe(42);
    expect(normalizeJsonSchema(true as any)).toBe(true);
  });
  it('should be a no-op when no const is present', () => {
    const schema = {
      type: 'object',
      properties: {
        name: { type: 'string', description: 'Name' },
        age: { type: 'number' },
        tags: { type: 'array', items: { type: 'string' } },
      },
      required: ['name'],
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result).toEqual(schema);
  });
  // --- realistic tool-schema fixtures ------------------------------------
  it('should handle a Tavily-like schema pattern with const', () => {
    const schema = {
      type: 'object',
      properties: {
        query: {
          type: 'string',
          description: 'The search query',
        },
        search_depth: {
          type: 'string',
          const: 'advanced',
          description: 'The depth of the search',
        },
        topic: {
          type: 'string',
          enum: ['general', 'news'],
          description: 'The search topic',
        },
        include_answer: {
          type: 'boolean',
          const: true,
        },
        max_results: {
          type: 'number',
          const: 5,
        },
      },
      required: ['query'],
    } as any;
    const result = normalizeJsonSchema(schema);
    // const fields should be converted to enum
    expect(result.properties.search_depth).toEqual({
      type: 'string',
      enum: ['advanced'],
      description: 'The depth of the search',
    });
    expect(result.properties.include_answer).toEqual({
      type: 'boolean',
      enum: [true],
    });
    expect(result.properties.max_results).toEqual({
      type: 'number',
      enum: [5],
    });
    // Existing enum should be preserved
    expect(result.properties.topic).toEqual({
      type: 'string',
      enum: ['general', 'news'],
      description: 'The search topic',
    });
    // Non-const fields should be unchanged
    expect(result.properties.query).toEqual({
      type: 'string',
      description: 'The search query',
    });
    // Top-level fields preserved
    expect(result.required).toEqual(['query']);
    expect(result.type).toBe('object');
  });
  it('should handle arrays at the top level', () => {
    const schemas = [
      { type: 'string', const: 'a' },
      { type: 'number', const: 1 },
    ] as any;
    const result = normalizeJsonSchema(schemas);
    expect(result).toEqual([
      { type: 'string', enum: ['a'] },
      { type: 'number', enum: [1] },
    ]);
  });
  // --- vendor extension (x-*) stripping ----------------------------------
  it('should strip vendor extension fields (x-* prefixed keys)', () => {
    const schema = {
      type: 'object',
      properties: {
        travelMode: {
          type: 'string',
          enum: ['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK'],
          'x-google-enum-descriptions': ['By car', 'By bicycle', 'By public transit', 'By walking'],
          description: 'Mode of travel',
        },
      },
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result.properties.travelMode).toEqual({
      type: 'string',
      enum: ['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK'],
      description: 'Mode of travel',
    });
    expect(result.properties.travelMode).not.toHaveProperty('x-google-enum-descriptions');
  });
  it('should strip x-* fields at all nesting levels', () => {
    const schema = {
      type: 'object',
      'x-custom-root': true,
      properties: {
        outer: {
          type: 'object',
          'x-custom-outer': 'value',
          properties: {
            inner: {
              type: 'string',
              'x-custom-inner': 42,
            },
          },
        },
        arr: {
          type: 'array',
          items: {
            type: 'string',
            'x-item-meta': 'something',
          },
        },
      },
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result).not.toHaveProperty('x-custom-root');
    expect(result.properties.outer).not.toHaveProperty('x-custom-outer');
    expect(result.properties.outer.properties.inner).not.toHaveProperty('x-custom-inner');
    expect(result.properties.arr.items).not.toHaveProperty('x-item-meta');
    // Standard fields should be preserved
    expect(result.type).toBe('object');
    expect(result.properties.outer.type).toBe('object');
    expect(result.properties.outer.properties.inner.type).toBe('string');
    expect(result.properties.arr.items.type).toBe('string');
  });
  it('should strip $defs and definitions as a safety net', () => {
    const schema = {
      type: 'object',
      properties: {
        name: { type: 'string' },
      },
      $defs: {
        SomeType: { type: 'string' },
      },
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result).not.toHaveProperty('$defs');
    expect(result.type).toBe('object');
    expect(result.properties.name).toEqual({ type: 'string' });
  });
  it('should strip x-* fields inside oneOf/anyOf/allOf', () => {
    const schema = {
      type: 'object',
      oneOf: [
        { type: 'string', 'x-meta': 'a' },
        { type: 'number', 'x-meta': 'b' },
      ],
    } as any;
    const result = normalizeJsonSchema(schema);
    expect(result.oneOf[0]).toEqual({ type: 'string' });
    expect(result.oneOf[1]).toEqual({ type: 'number' });
  });
  // --- end-to-end: resolveJsonSchemaRefs + normalizeJsonSchema -----------
  it('should handle a Google Maps MCP-like schema with $defs and x-google-enum-descriptions', () => {
    const schema = {
      type: 'object',
      properties: {
        origin: { type: 'string', description: 'Starting address' },
        destination: { type: 'string', description: 'Ending address' },
        travelMode: {
          type: 'string',
          enum: ['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK'],
          'x-google-enum-descriptions': ['By car', 'By bicycle', 'By public transit', 'By walking'],
        },
        waypoints: {
          type: 'array',
          items: { $ref: '#/$defs/Waypoint' },
        },
      },
      required: ['origin', 'destination'],
      $defs: {
        Waypoint: {
          type: 'object',
          properties: {
            location: { type: 'string' },
            stopover: { type: 'boolean' },
          },
        },
      },
    } as any;
    // First resolve refs, then normalize
    const resolved = resolveJsonSchemaRefs(schema);
    const result = normalizeJsonSchema(resolved);
    // $defs should be stripped (by both resolveJsonSchemaRefs and normalizeJsonSchema)
    expect(result).not.toHaveProperty('$defs');
    // x-google-enum-descriptions should be stripped
    expect(result.properties.travelMode).not.toHaveProperty('x-google-enum-descriptions');
    // $ref should be resolved inline
    expect(result.properties.waypoints.items).not.toHaveProperty('$ref');
    expect(result.properties.waypoints.items).toEqual({
      type: 'object',
      properties: {
        location: { type: 'string' },
        stopover: { type: 'boolean' },
      },
    });
    // Standard fields preserved
    expect(result.properties.travelMode.enum).toEqual(['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK']);
    expect(result.properties.origin).toEqual({ type: 'string', description: 'Starting address' });
  });
});

View file

@ -20,6 +20,7 @@ import type {
import type { MCPOAuthTokens } from './oauth/types';
import { withTimeout } from '~/utils/promise';
import type * as t from './types';
import { createSSRFSafeUndiciConnect, resolveHostnameSSRF } from '~/auth';
import { sanitizeUrlForLogging } from './utils';
import { mcpConfig } from './mcpConfig';
@ -213,6 +214,7 @@ interface MCPConnectionParams {
serverConfig: t.MCPOptions;
userId?: string;
oauthTokens?: MCPOAuthTokens | null;
useSSRFProtection?: boolean;
}
export class MCPConnection extends EventEmitter {
@ -233,6 +235,7 @@ export class MCPConnection extends EventEmitter {
private oauthTokens?: MCPOAuthTokens | null;
private requestHeaders?: Record<string, string> | null;
private oauthRequired = false;
private readonly useSSRFProtection: boolean;
iconPath?: string;
timeout?: number;
url?: string;
@ -263,6 +266,7 @@ export class MCPConnection extends EventEmitter {
this.options = params.serverConfig;
this.serverName = params.serverName;
this.userId = params.userId;
this.useSSRFProtection = params.useSSRFProtection === true;
this.iconPath = params.serverConfig.iconPath;
this.timeout = params.serverConfig.timeout;
this.lastPingTime = Date.now();
@ -301,6 +305,7 @@ export class MCPConnection extends EventEmitter {
getHeaders: () => Record<string, string> | null | undefined,
timeout?: number,
): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
return function customFetch(
input: UndiciRequestInfo,
init?: UndiciRequestInit,
@ -310,6 +315,7 @@ export class MCPConnection extends EventEmitter {
const agent = new Agent({
bodyTimeout: effectiveTimeout,
headersTimeout: effectiveTimeout,
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
});
if (!requestHeaders) {
return undiciFetch(input, { ...init, dispatcher: agent });
@ -342,7 +348,7 @@ export class MCPConnection extends EventEmitter {
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
}
private constructTransport(options: t.MCPOptions): Transport {
private async constructTransport(options: t.MCPOptions): Promise<Transport> {
try {
let type: t.MCPOptions['type'];
if (isStdioOptions(options)) {
@ -378,6 +384,15 @@ export class MCPConnection extends EventEmitter {
throw new Error('Invalid options for websocket transport.');
}
this.url = options.url;
if (this.useSSRFProtection) {
const wsHostname = new URL(options.url).hostname;
const isSSRF = await resolveHostnameSSRF(wsHostname);
if (isSSRF) {
throw new Error(
`SSRF protection: WebSocket host "${wsHostname}" resolved to a private/reserved IP address`,
);
}
}
return new WebSocketClientTransport(new URL(options.url));
case 'sse': {
@ -402,6 +417,7 @@ export class MCPConnection extends EventEmitter {
* The connect timeout is extended because proxies may delay initial response.
*/
const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT;
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
const transport = new SSEClientTransport(url, {
requestInit: {
/** User/OAuth headers override SSE defaults */
@ -420,6 +436,7 @@ export class MCPConnection extends EventEmitter {
/** Extended keep-alive for long-lived SSE connections */
keepAliveTimeout: sseTimeout,
keepAliveMaxTimeout: sseTimeout * 2,
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
});
return undiciFetch(url, {
...init,
@ -542,7 +559,11 @@ export class MCPConnection extends EventEmitter {
}
this.isReconnecting = true;
const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);
const backoffDelay = (attempt: number) => {
const base = Math.min(1000 * Math.pow(2, attempt), 30000);
const jitter = Math.floor(Math.random() * 1000); // up to 1s of random jitter
return base + jitter;
};
try {
while (
@ -629,7 +650,7 @@ export class MCPConnection extends EventEmitter {
}
}
this.transport = this.constructTransport(this.options);
this.transport = await this.constructTransport(this.options);
this.setupTransportDebugHandlers();
const connectTimeout = this.options.initTimeout ?? 120000;

View file

@ -336,6 +336,69 @@ describe('OAuthReconnectionManager', () => {
});
});
describe('reconnection staggering', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
jest.useFakeTimers();
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
afterEach(() => {
jest.useRealTimers();
});
it('should stagger reconnection attempts for multiple servers', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1', 'server2', 'server3']);
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
// All servers have valid tokens and are not connected
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
return {
userId,
identifier,
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens;
});
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
{} as unknown as MCPOptions,
);
await reconnectionManager.reconnectServers(userId);
// Only the first server should have been attempted immediately
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(1);
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server1' }),
);
// After advancing all timers, all servers should have been attempted
await jest.runAllTimersAsync();
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(3);
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server2' }),
);
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
expect.objectContaining({ serverName: 'server3' }),
);
});
});
describe('reconnection timeout behavior', () => {
let reconnectionTracker: OAuthReconnectionTracker;

View file

@ -7,6 +7,7 @@ import { MCPManager } from '~/mcp/MCPManager';
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
const RECONNECT_STAGGER_MS = 500; // ms between each server reconnection
export class OAuthReconnectionManager {
private static instance: OAuthReconnectionManager | null = null;
@ -84,9 +85,14 @@ export class OAuthReconnectionManager {
this.reconnectionsTracker.setActive(userId, serverName);
}
// 3. attempt to reconnect the servers
for (const serverName of serversToReconnect) {
void this.tryReconnect(userId, serverName);
// 3. attempt to reconnect the servers with staggered delays to avoid connection storms
for (let i = 0; i < serversToReconnect.length; i++) {
const serverName = serversToReconnect[i];
if (i === 0) {
void this.tryReconnect(userId, serverName);
} else {
setTimeout(() => void this.tryReconnect(userId, serverName), i * RECONNECT_STAGGER_MS);
}
}
}

View file

@ -18,6 +18,7 @@ export class MCPServerInspector {
private readonly serverName: string,
private readonly config: t.ParsedServerConfig,
private connection: MCPConnection | undefined,
private readonly useSSRFProtection: boolean = false,
) {}
/**
@ -42,8 +43,9 @@ export class MCPServerInspector {
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
}
const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0;
const start = Date.now();
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection);
await inspector.inspectServer();
inspector.config.initDuration = Date.now() - start;
return inspector.config;
@ -59,6 +61,7 @@ export class MCPServerInspector {
this.connection = await MCPConnectionFactory.create({
serverName: this.serverName,
serverConfig: this.config,
useSSRFProtection: this.useSSRFProtection,
});
}

View file

@ -77,6 +77,15 @@ export class MCPServersRegistry {
return MCPServersRegistry.instance;
}
/** Returns the configured MCP domain allowlist, or null/undefined when none is configured. */
public getAllowedDomains(): string[] | null | undefined {
  return this.allowedDomains;
}
/**
 * Returns true when no explicit allowedDomains allowlist is configured, enabling SSRF TOCTOU protection.
 * Rationale: with no allowlist, any hostname may be dialed, so connect-time
 * IP re-validation is required; with an allowlist, domains were already vetted.
 */
public shouldEnableSSRFProtection(): boolean {
  return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0;
}
public async getServerConfig(
serverName: string,
userId?: string,

View file

@ -276,6 +276,7 @@ describe('MCPServerInspector', () => {
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
serverName: 'test_server',
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
useSSRFProtection: true,
});
// Verify temporary connection was disconnected

View file

@ -166,6 +166,7 @@ export type AddServerResult = {
export interface BasicConnectionOptions {
serverName: string;
serverConfig: MCPOptions;
useSSRFProtection?: boolean;
}
export interface OAuthConnectionOptions {

View file

@ -203,9 +203,9 @@ export function resolveJsonSchemaRefs<T extends Record<string, unknown>>(
const result: Record<string, unknown> = {};
for (const [key, value] of Object.entries(schema)) {
// Skip $defs/definitions at root level to avoid infinite recursion
if ((key === '$defs' || key === 'definitions') && !visited.size) {
result[key] = value;
// Skip $defs/definitions — they are only used for resolving $ref and
// should not appear in the resolved output (e.g. Google/Gemini API rejects them).
if (key === '$defs' || key === 'definitions') {
continue;
}
@ -248,6 +248,80 @@ export function resolveJsonSchemaRefs<T extends Record<string, unknown>>(
return result as T;
}
/**
 * Recursively normalizes a JSON schema for LLM API compatibility.
 *
 * Transformations applied:
 * - Converts `const` values to `enum` arrays (Gemini/Vertex AI rejects `const`)
 * - Strips vendor extension fields (`x-*` prefixed keys, e.g. `x-google-enum-descriptions`)
 * - Strips leftover `$defs`/`definitions` blocks that may survive ref resolution
 *
 * @param schema - The JSON schema to normalize
 * @returns The normalized schema
 */
export function normalizeJsonSchema<T extends Record<string, unknown>>(schema: T): T {
  // Primitives / null pass through untouched.
  if (!schema || typeof schema !== 'object') {
    return schema;
  }
  // An array at this level (e.g. tuple-form `items`) is normalized element-wise.
  if (Array.isArray(schema)) {
    return schema.map((entry) =>
      entry && typeof entry === 'object' ? normalizeJsonSchema(entry) : entry,
    ) as unknown as T;
  }
  const normalized: Record<string, unknown> = {};
  for (const [key, value] of Object.entries(schema)) {
    // Drop vendor extensions (x-*) and definition blocks — both are valid
    // JSON Schema but rejected by Google/Gemini APIs.
    if (key.startsWith('x-') || key === '$defs' || key === 'definitions') {
      continue;
    }
    if (key === 'const') {
      // Rewrite `const: v` as `enum: [v]`; when an `enum` is already present
      // it wins and the `const` is simply dropped.
      if (!('enum' in schema)) {
        normalized['enum'] = [value];
      }
      continue;
    }
    const isObjectValue = value != null && typeof value === 'object';
    if (key === 'properties' && isObjectValue && !Array.isArray(value)) {
      // Normalize each property schema individually, preserving key order.
      const props: Record<string, unknown> = {};
      for (const [propName, propSchema] of Object.entries(value as Record<string, unknown>)) {
        props[propName] =
          propSchema && typeof propSchema === 'object'
            ? normalizeJsonSchema(propSchema as Record<string, unknown>)
            : propSchema;
      }
      normalized[key] = props;
      continue;
    }
    if ((key === 'items' || key === 'additionalProperties') && isObjectValue) {
      // Recurse into nested schemas (array handling occurs at the top of the call).
      normalized[key] = normalizeJsonSchema(value as Record<string, unknown>);
      continue;
    }
    if ((key === 'oneOf' || key === 'anyOf' || key === 'allOf') && Array.isArray(value)) {
      normalized[key] = value.map((entry) =>
        entry && typeof entry === 'object' ? normalizeJsonSchema(entry) : entry,
      );
      continue;
    }
    // Everything else (type, required, description, enum, ...) copies through.
    normalized[key] = value;
  }
  return normalized as T;
}
/**
* Converts a JSON Schema to a Zod schema.
*

View file

@ -0,0 +1,258 @@
import type { Redis, Cluster } from 'ioredis';
/**
* Integration tests for concurrency middleware atomic Lua scripts.
*
* Tests that the Lua-based check-and-increment / decrement operations
* are truly atomic and eliminate the INCR+check+DECR race window.
*
* Run with: USE_REDIS=true npx jest --config packages/api/jest.config.js concurrency.cache_integration
*/
describe('Concurrency Middleware Integration Tests', () => {
let originalEnv: NodeJS.ProcessEnv;
let ioredisClient: Redis | Cluster | null = null;
let checkAndIncrementPendingRequest: (
userId: string,
) => Promise<{ allowed: boolean; pendingRequests: number; limit: number }>;
let decrementPendingRequest: (userId: string) => Promise<void>;
const testPrefix = 'Concurrency-Integration-Test';
beforeAll(async () => {
originalEnv = { ...process.env };
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'false';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
process.env.REDIS_PING_INTERVAL = '0';
process.env.REDIS_RETRY_MAX_ATTEMPTS = '5';
process.env.LIMIT_CONCURRENT_MESSAGES = 'true';
process.env.CONCURRENT_MESSAGE_MAX = '2';
jest.resetModules();
const { ioredisClient: client } = await import('../../cache/redisClients');
ioredisClient = client;
if (!ioredisClient) {
console.warn('Redis not available, skipping integration tests');
return;
}
// Import concurrency module after Redis client is available
const concurrency = await import('../concurrency');
checkAndIncrementPendingRequest = concurrency.checkAndIncrementPendingRequest;
decrementPendingRequest = concurrency.decrementPendingRequest;
});
afterEach(async () => {
if (!ioredisClient) {
return;
}
try {
const keys = await ioredisClient.keys(`${testPrefix}*`);
if (keys.length > 0) {
await Promise.all(keys.map((key) => ioredisClient!.del(key)));
}
} catch (error) {
console.warn('Error cleaning up test keys:', error);
}
});
afterAll(async () => {
if (ioredisClient) {
try {
await ioredisClient.quit();
} catch {
try {
ioredisClient.disconnect();
} catch {
// Ignore
}
}
}
process.env = originalEnv;
});
describe('Atomic Check and Increment', () => {
test('should allow requests within the concurrency limit', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-allow-${Date.now()}`;
// First request - should be allowed (count = 1, limit = 2)
const result1 = await checkAndIncrementPendingRequest(userId);
expect(result1.allowed).toBe(true);
expect(result1.pendingRequests).toBe(1);
expect(result1.limit).toBe(2);
// Second request - should be allowed (count = 2, limit = 2)
const result2 = await checkAndIncrementPendingRequest(userId);
expect(result2.allowed).toBe(true);
expect(result2.pendingRequests).toBe(2);
});
test('should reject requests over the concurrency limit', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-reject-${Date.now()}`;
// Fill up to the limit
await checkAndIncrementPendingRequest(userId);
await checkAndIncrementPendingRequest(userId);
// Third request - should be rejected (count would be 3, limit = 2)
const result = await checkAndIncrementPendingRequest(userId);
expect(result.allowed).toBe(false);
expect(result.pendingRequests).toBe(3); // Reports the count that was over-limit
});
test('should not leave stale counter after rejection (atomic rollback)', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-rollback-${Date.now()}`;
// Fill up to the limit
await checkAndIncrementPendingRequest(userId);
await checkAndIncrementPendingRequest(userId);
// Attempt over-limit (should be rejected and atomically rolled back)
const rejected = await checkAndIncrementPendingRequest(userId);
expect(rejected.allowed).toBe(false);
// The key value should still be 2, not 3 — verify the Lua script decremented back
const key = `PENDING_REQ:${userId}`;
const rawValue = await ioredisClient.get(key);
expect(rawValue).toBe('2');
});
test('should handle concurrent requests atomically (no over-admission)', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-concurrent-${Date.now()}`;
// Fire 20 concurrent requests for the same user (limit = 2)
const results = await Promise.all(
Array.from({ length: 20 }, () => checkAndIncrementPendingRequest(userId)),
);
const allowed = results.filter((r) => r.allowed);
const rejected = results.filter((r) => !r.allowed);
// Exactly 2 should be allowed (the concurrency limit)
expect(allowed.length).toBe(2);
expect(rejected.length).toBe(18);
// The key value should be exactly 2 after all atomic operations
const key = `PENDING_REQ:${userId}`;
const rawValue = await ioredisClient.get(key);
expect(rawValue).toBe('2');
// Clean up
await decrementPendingRequest(userId);
await decrementPendingRequest(userId);
});
});
describe('Atomic Decrement', () => {
test('should decrement pending requests', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-decrement-${Date.now()}`;
await checkAndIncrementPendingRequest(userId);
await checkAndIncrementPendingRequest(userId);
// Decrement once
await decrementPendingRequest(userId);
const key = `PENDING_REQ:${userId}`;
const rawValue = await ioredisClient.get(key);
expect(rawValue).toBe('1');
});
test('should clean up key when count reaches zero', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-cleanup-${Date.now()}`;
await checkAndIncrementPendingRequest(userId);
await decrementPendingRequest(userId);
// Key should be deleted (not left as "0")
const key = `PENDING_REQ:${userId}`;
const exists = await ioredisClient.exists(key);
expect(exists).toBe(0);
});
test('should clean up key on double-decrement (negative protection)', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-double-decr-${Date.now()}`;
await checkAndIncrementPendingRequest(userId);
await decrementPendingRequest(userId);
await decrementPendingRequest(userId); // Double-decrement
// Key should be deleted, not negative
const key = `PENDING_REQ:${userId}`;
const exists = await ioredisClient.exists(key);
expect(exists).toBe(0);
});
test('should allow new requests after decrement frees a slot', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-free-slot-${Date.now()}`;
// Fill to limit
await checkAndIncrementPendingRequest(userId);
await checkAndIncrementPendingRequest(userId);
// Verify at limit
const atLimit = await checkAndIncrementPendingRequest(userId);
expect(atLimit.allowed).toBe(false);
// Free a slot
await decrementPendingRequest(userId);
// Should now be allowed again
const allowed = await checkAndIncrementPendingRequest(userId);
expect(allowed.allowed).toBe(true);
expect(allowed.pendingRequests).toBe(2);
});
});
describe('TTL Behavior', () => {
test('should set TTL on the concurrency key', async () => {
if (!ioredisClient) {
return;
}
const userId = `user-ttl-${Date.now()}`;
await checkAndIncrementPendingRequest(userId);
const key = `PENDING_REQ:${userId}`;
const ttl = await ioredisClient.ttl(key);
expect(ttl).toBeGreaterThan(0);
expect(ttl).toBeLessThanOrEqual(60);
});
});
});

View file

@ -9,6 +9,40 @@ const LIMIT_CONCURRENT_MESSAGES = process.env.LIMIT_CONCURRENT_MESSAGES;
const CONCURRENT_MESSAGE_MAX = math(process.env.CONCURRENT_MESSAGE_MAX, 2);
const CONCURRENT_VIOLATION_SCORE = math(process.env.CONCURRENT_VIOLATION_SCORE, 1);
/**
 * Lua script for atomic check-and-increment of a user's pending-request counter.
 * Increments the key, refreshes its TTL, and — when the new count exceeds the
 * limit — decrements it back within the same atomic evaluation.
 * Returns the positive count when allowed, or the negated over-limit count when rejected.
 * A single EVAL round-trip is fully atomic, eliminating the INCR/check/DECR race window.
 */
const CHECK_AND_INCREMENT_SCRIPT = `
local key = KEYS[1]
local limit = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
local current = redis.call('INCR', key)
redis.call('EXPIRE', key, ttl)
if current > limit then
redis.call('DECR', key)
return -current
end
return current
`;
/**
 * Lua script for atomic decrement-and-cleanup.
 * Decrements the key and deletes it when the count reaches zero or below,
 * so zero-value or negative counters never linger in Redis.
 * Eliminates the DECR-then-DEL race window.
 */
const DECREMENT_SCRIPT = `
local key = KEYS[1]
local current = redis.call('DECR', key)
if current <= 0 then
redis.call('DEL', key)
return 0
end
return current
`;
/** Lazily initialized cache for pending requests (used only for in-memory fallback) */
let pendingReqCache: ReturnType<typeof standardCache> | null = null;
@ -80,36 +114,28 @@ export async function checkAndIncrementPendingRequest(
return { allowed: true, pendingRequests: 0, limit };
}
// Use atomic Redis INCR when available to prevent race conditions
// Use atomic Lua script when Redis is available to prevent race conditions.
// A single EVAL round-trip atomically increments, checks, and decrements if over-limit.
if (USE_REDIS && ioredisClient) {
const key = buildKey(userId);
try {
// Pipeline ensures INCR and EXPIRE execute atomically in one round-trip
// This prevents edge cases where crash between operations leaves key without TTL
const pipeline = ioredisClient.pipeline();
pipeline.incr(key);
pipeline.expire(key, 60);
const results = await pipeline.exec();
const result = (await ioredisClient.eval(
CHECK_AND_INCREMENT_SCRIPT,
1,
key,
limit,
60,
)) as number;
if (!results || results[0][0]) {
throw new Error('Pipeline execution failed');
if (result < 0) {
// Negative return means over-limit (absolute value is the count before decrement)
const count = -result;
logger.debug(`[concurrency] User ${userId} exceeded concurrent limit: ${count}/${limit}`);
return { allowed: false, pendingRequests: count, limit };
}
const newCount = results[0][1] as number;
if (newCount > limit) {
// Over limit - decrement back and reject
await ioredisClient.decr(key);
logger.debug(
`[concurrency] User ${userId} exceeded concurrent limit: ${newCount}/${limit}`,
);
return { allowed: false, pendingRequests: newCount, limit };
}
logger.debug(
`[concurrency] User ${userId} incremented pending requests: ${newCount}/${limit}`,
);
return { allowed: true, pendingRequests: newCount, limit };
logger.debug(`[concurrency] User ${userId} incremented pending requests: ${result}/${limit}`);
return { allowed: true, pendingRequests: result, limit };
} catch (error) {
logger.error('[concurrency] Redis atomic increment failed:', error);
// On Redis error, allow the request to proceed (fail-open)
@ -164,18 +190,12 @@ export async function decrementPendingRequest(userId: string): Promise<void> {
return;
}
// Use atomic Redis DECR when available
// Use atomic Lua script to decrement and clean up zero/negative keys in one round-trip
if (USE_REDIS && ioredisClient) {
const key = buildKey(userId);
try {
const newCount = await ioredisClient.decr(key);
if (newCount < 0) {
// Counter went negative - reset to 0 and delete
await ioredisClient.del(key);
logger.debug(`[concurrency] User ${userId} pending requests cleared (was negative)`);
} else if (newCount === 0) {
// Clean up zero-value keys
await ioredisClient.del(key);
const newCount = (await ioredisClient.eval(DECREMENT_SCRIPT, 1, key)) as number;
if (newCount === 0) {
logger.debug(`[concurrency] User ${userId} pending requests cleared`);
} else {
logger.debug(`[concurrency] User ${userId} decremented pending requests: ${newCount}`);

View file

@ -0,0 +1,99 @@
import { shouldUseSecureCookie } from './csrf';
// Tests for shouldUseSecureCookie: NODE_ENV gating, localhost detection in
// production, and the conservative secure-by-default behavior when
// DOMAIN_SERVER is missing or empty.
describe('shouldUseSecureCookie', () => {
  const originalEnv = process.env;
  beforeEach(() => {
    // Fresh copy per test so env mutations do not leak between cases.
    process.env = { ...originalEnv };
  });
  afterAll(() => {
    process.env = originalEnv;
  });
  it('should return true in production with a non-localhost domain', () => {
    process.env.NODE_ENV = 'production';
    process.env.DOMAIN_SERVER = 'https://myapp.example.com';
    expect(shouldUseSecureCookie()).toBe(true);
  });
  it('should return false in development regardless of domain', () => {
    process.env.NODE_ENV = 'development';
    process.env.DOMAIN_SERVER = 'https://myapp.example.com';
    expect(shouldUseSecureCookie()).toBe(false);
  });
  it('should return false when NODE_ENV is not set', () => {
    delete process.env.NODE_ENV;
    process.env.DOMAIN_SERVER = 'https://myapp.example.com';
    expect(shouldUseSecureCookie()).toBe(false);
  });
  // All cases below run with NODE_ENV=production so only hostname detection varies.
  describe('localhost detection in production', () => {
    beforeEach(() => {
      process.env.NODE_ENV = 'production';
    });
    it('should return false for http://localhost:3080', () => {
      process.env.DOMAIN_SERVER = 'http://localhost:3080';
      expect(shouldUseSecureCookie()).toBe(false);
    });
    it('should return false for https://localhost:3080', () => {
      process.env.DOMAIN_SERVER = 'https://localhost:3080';
      expect(shouldUseSecureCookie()).toBe(false);
    });
    it('should return false for http://localhost (no port)', () => {
      process.env.DOMAIN_SERVER = 'http://localhost';
      expect(shouldUseSecureCookie()).toBe(false);
    });
    it('should return false for http://127.0.0.1:3080', () => {
      process.env.DOMAIN_SERVER = 'http://127.0.0.1:3080';
      expect(shouldUseSecureCookie()).toBe(false);
    });
    it('should return true for http://[::1]:3080 (IPv6 loopback — not detected due to URL bracket parsing)', () => {
      // Known limitation: new URL('http://[::1]:3080').hostname returns '[::1]' (with brackets)
      // but the check compares against '::1' (without brackets). IPv6 localhost is rare in practice.
      process.env.DOMAIN_SERVER = 'http://[::1]:3080';
      expect(shouldUseSecureCookie()).toBe(true);
    });
    it('should return false for subdomain of localhost', () => {
      process.env.DOMAIN_SERVER = 'http://app.localhost:3080';
      expect(shouldUseSecureCookie()).toBe(false);
    });
    it('should return true for a domain containing "localhost" as a substring but not as hostname', () => {
      process.env.DOMAIN_SERVER = 'https://notlocalhost.example.com';
      expect(shouldUseSecureCookie()).toBe(true);
    });
    it('should return true for a regular production domain', () => {
      process.env.DOMAIN_SERVER = 'https://chat.example.com';
      expect(shouldUseSecureCookie()).toBe(true);
    });
    it('should return true when DOMAIN_SERVER is empty (conservative default)', () => {
      process.env.DOMAIN_SERVER = '';
      expect(shouldUseSecureCookie()).toBe(true);
    });
    it('should return true when DOMAIN_SERVER is not set (conservative default)', () => {
      delete process.env.DOMAIN_SERVER;
      expect(shouldUseSecureCookie()).toBe(true);
    });
    it('should handle DOMAIN_SERVER without protocol prefix', () => {
      process.env.DOMAIN_SERVER = 'localhost:3080';
      expect(shouldUseSecureCookie()).toBe(false);
    });
    it('should handle case-insensitive hostnames', () => {
      process.env.DOMAIN_SERVER = 'http://LOCALHOST:3080';
      expect(shouldUseSecureCookie()).toBe(false);
    });
  });
});

View file

@ -0,0 +1,119 @@
import crypto from 'crypto';
import type { Request, Response, NextFunction } from 'express';
export const OAUTH_CSRF_COOKIE = 'oauth_csrf';
export const OAUTH_CSRF_MAX_AGE = 10 * 60 * 1000;
export const OAUTH_SESSION_COOKIE = 'oauth_session';
export const OAUTH_SESSION_MAX_AGE = 24 * 60 * 60 * 1000;
export const OAUTH_SESSION_COOKIE_PATH = '/api';
/**
 * Determines if secure cookies should be used.
 * Returns `true` in production unless the server is running on localhost (HTTP).
 * This allows cookies to work on `http://localhost` during local development
 * even when `NODE_ENV=production` (common in Docker Compose setups).
 */
export function shouldUseSecureCookie(): boolean {
  // Outside production, cookies are never marked secure.
  if (process.env.NODE_ENV !== 'production') {
    return false;
  }
  const domainServer = process.env.DOMAIN_SERVER || '';
  if (!domainServer) {
    // No configured domain in production: stay secure (conservative default).
    return true;
  }
  let hostname: string;
  try {
    // Tolerate values without a scheme (e.g. "localhost:3080").
    const withScheme = /^https?:\/\//i.test(domainServer) ? domainServer : `http://${domainServer}`;
    hostname = (new URL(withScheme).hostname || '').toLowerCase();
  } catch {
    // Unparseable value: fall back to comparing the raw string.
    hostname = domainServer.toLowerCase();
  }
  // NOTE: IPv6 loopback is not detected — URL.hostname yields '[::1]' (with
  // brackets), which never equals the bare '::1' below. Documented limitation.
  const loopbackHosts = ['localhost', '127.0.0.1', '::1'];
  const isLocalhost = loopbackHosts.includes(hostname) || hostname.endsWith('.localhost');
  return !isLocalhost;
}
/**
 * Generates an HMAC-SHA256 token bound to a flow identifier for OAuth CSRF protection.
 * @param flowId - The OAuth flow identifier the token is bound to
 * @param secret - Optional HMAC key; falls back to `JWT_SECRET` when falsy
 * @returns The first 32 hex characters of the HMAC digest
 * @throws {Error} When neither `secret` nor `JWT_SECRET` is available
 */
export function generateOAuthCsrfToken(flowId: string, secret?: string): string {
  // `||` (not `??`) so an empty-string secret also falls back to the env key.
  const hmacKey = secret || process.env.JWT_SECRET;
  if (!hmacKey) {
    throw new Error('JWT_SECRET is required for OAuth CSRF token generation');
  }
  const digest = crypto.createHmac('sha256', hmacKey).update(flowId).digest('hex');
  return digest.slice(0, 32);
}
/**
 * Sets a SameSite=Lax CSRF cookie bound to a specific OAuth flow.
 * The value is an HMAC of the flowId, so it cannot be forged without JWT_SECRET.
 */
export function setOAuthCsrfCookie(res: Response, flowId: string, cookiePath: string): void {
  const token = generateOAuthCsrfToken(flowId);
  res.cookie(OAUTH_CSRF_COOKIE, token, {
    httpOnly: true,
    sameSite: 'lax',
    secure: shouldUseSecureCookie(),
    path: cookiePath,
    maxAge: OAUTH_CSRF_MAX_AGE,
  });
}
/**
 * Validates the per-flow CSRF cookie against the expected HMAC.
 * Uses timing-safe comparison and always clears the cookie to prevent replay.
 * @returns true only when a cookie is present and matches the flow's HMAC
 */
export function validateOAuthCsrf(
  req: Request,
  res: Response,
  flowId: string,
  cookiePath: string,
): boolean {
  const cookies = req.cookies as Record<string, string> | undefined;
  const presented = cookies?.[OAUTH_CSRF_COOKIE];
  // Single-use token: clear the cookie regardless of the validation outcome.
  res.clearCookie(OAUTH_CSRF_COOKIE, { path: cookiePath });
  if (!presented) {
    return false;
  }
  const expected = generateOAuthCsrfToken(flowId);
  // timingSafeEqual throws on unequal-length buffers, so guard the length first.
  if (presented.length !== expected.length) {
    return false;
  }
  return crypto.timingSafeEqual(Buffer.from(presented), Buffer.from(expected));
}
/**
 * Express middleware that sets the OAuth session cookie after JWT authentication.
 * Chain after requireJwtAuth on routes that precede an OAuth redirect (e.g., reinitialize, bind).
 */
export function setOAuthSession(req: Request, res: Response, next: NextFunction): void {
  const { user } = req as Request & { user?: { id?: string } };
  const existing = (req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE];
  // Only issue the cookie for authenticated users, and never overwrite one already present.
  if (user?.id && !existing) {
    setOAuthSessionCookie(res, user.id);
  }
  next();
}
/**
 * Sets a SameSite=Lax session cookie that binds the browser to the authenticated userId.
 * The value is an HMAC of the userId, so it cannot be forged without JWT_SECRET.
 */
export function setOAuthSessionCookie(res: Response, userId: string): void {
  const sessionToken = generateOAuthCsrfToken(userId);
  res.cookie(OAUTH_SESSION_COOKIE, sessionToken, {
    httpOnly: true,
    sameSite: 'lax',
    secure: shouldUseSecureCookie(),
    path: OAUTH_SESSION_COOKIE_PATH,
    maxAge: OAUTH_SESSION_MAX_AGE,
  });
}
/**
 * Validates the session cookie against the expected userId using a timing-safe comparison.
 * @returns true only when the cookie is present and matches the userId's HMAC
 */
export function validateOAuthSession(req: Request, userId: string): boolean {
  const cookies = req.cookies as Record<string, string> | undefined;
  const presented = cookies?.[OAUTH_SESSION_COOKIE];
  if (!presented) {
    return false;
  }
  const expected = generateOAuthCsrfToken(userId);
  // timingSafeEqual throws on unequal-length buffers, so guard the length first.
  if (presented.length !== expected.length) {
    return false;
  }
  return crypto.timingSafeEqual(Buffer.from(presented), Buffer.from(expected));
}

View file

@ -1 +1,2 @@
export * from './csrf';
export * from './tokens';

View file

@ -745,7 +745,6 @@ class GenerationJobManagerClass {
const subscription = this.eventTransport.subscribe(streamId, {
onChunk: (event) => {
const e = event as t.ServerSentEvent;
// Filter out internal events
if (!(e as Record<string, unknown>)._internal) {
onChunk(e);
}
@ -754,14 +753,15 @@ class GenerationJobManagerClass {
onError,
});
// Check if this is the first subscriber
if (subscription.ready) {
await subscription.ready;
}
const isFirst = this.eventTransport.isFirstSubscriber(streamId);
// First subscriber: replay buffered events and mark as connected
if (!runtime.hasSubscriber) {
runtime.hasSubscriber = true;
// Replay any events that were emitted before subscriber connected
if (runtime.earlyEventBuffer.length > 0) {
logger.debug(
`[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`,
@ -771,6 +771,8 @@ class GenerationJobManagerClass {
}
runtime.earlyEventBuffer = [];
}
this.eventTransport.syncReorderBuffer?.(streamId);
}
if (isFirst) {
@ -823,12 +825,13 @@ class GenerationJobManagerClass {
}
}
// Buffer early events if no subscriber yet (replay when first subscriber connects)
if (!runtime.hasSubscriber) {
runtime.earlyEventBuffer.push(event);
if (!this._isRedis) {
return;
}
}
// Await the transport emit - critical for Redis mode to maintain event order
await this.eventTransport.emitChunk(streamId, event);
}

View file

@ -19,8 +19,11 @@ describe('RedisEventTransport Integration Tests', () => {
originalEnv = { ...process.env };
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'false';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
process.env.REDIS_PING_INTERVAL = '0';
process.env.REDIS_RETRY_MAX_ATTEMPTS = '5';
jest.resetModules();
@ -890,4 +893,121 @@ describe('RedisEventTransport Integration Tests', () => {
subscriber.disconnect();
});
});
describe('Publish Error Propagation', () => {
test('should swallow emitChunk publish errors (callers fire-and-forget)', async () => {
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const mockPublisher = {
publish: jest.fn().mockRejectedValue(new Error('Redis connection lost')),
};
const mockSubscriber = {
on: jest.fn(),
subscribe: jest.fn().mockResolvedValue(undefined),
unsubscribe: jest.fn().mockResolvedValue(undefined),
};
const transport = new RedisEventTransport(
mockPublisher as unknown as Redis,
mockSubscriber as unknown as Redis,
);
const streamId = `error-prop-chunk-${Date.now()}`;
// emitChunk swallows errors because callers often fire-and-forget (no await).
// Throwing would cause unhandled promise rejections.
await expect(transport.emitChunk(streamId, { data: 'test' })).resolves.toBeUndefined();
transport.destroy();
});
test('should throw when emitDone publish fails', async () => {
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
const mockPublisher = {
publish: jest.fn().mockRejectedValue(new Error('Redis connection lost')),
};
const mockSubscriber = {
on: jest.fn(),
subscribe: jest.fn().mockResolvedValue(undefined),
unsubscribe: jest.fn().mockResolvedValue(undefined),
};
const transport = new RedisEventTransport(
mockPublisher as unknown as Redis,
mockSubscriber as unknown as Redis,
);
const streamId = `error-prop-done-${Date.now()}`;
await expect(transport.emitDone(streamId, { finished: true })).rejects.toThrow(
'Redis connection lost',
);
transport.destroy();
});
test('should throw when emitError publish fails', async () => {
  const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
  const failingPublisher = {
    publish: jest.fn().mockRejectedValue(new Error('Redis connection lost')),
  };
  const idleSubscriber = {
    on: jest.fn(),
    subscribe: jest.fn().mockResolvedValue(undefined),
    unsubscribe: jest.fn().mockResolvedValue(undefined),
  };
  const transport = new RedisEventTransport(
    failingPublisher as unknown as Redis,
    idleSubscriber as unknown as Redis,
  );
  // Error events are terminal as well: publish failures must propagate.
  const streamId = `error-prop-error-${Date.now()}`;
  await expect(transport.emitError(streamId, 'some error')).rejects.toThrow(
    'Redis connection lost',
  );
  transport.destroy();
});
test('should still deliver events successfully when publish succeeds', async () => {
  if (!ioredisClient) {
    console.warn('Redis not available, skipping test');
    return;
  }
  const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
  const subscriber = (ioredisClient as Redis).duplicate();
  const transport = new RedisEventTransport(ioredisClient, subscriber);
  const streamId = `error-prop-success-${Date.now()}`;
  const received: unknown[] = [];
  let finalEvent: unknown = null;
  transport.subscribe(streamId, {
    onChunk: (event) => {
      received.push(event);
    },
    onDone: (event) => {
      finalEvent = event;
    },
  });
  // Give the Redis SUBSCRIBE time to become active before publishing.
  await new Promise((resolve) => setTimeout(resolve, 200));
  // Neither of these should throw when Redis is healthy.
  await transport.emitChunk(streamId, { text: 'hello' });
  await transport.emitDone(streamId, { finished: true });
  await new Promise((resolve) => setTimeout(resolve, 200));
  expect(received).toHaveLength(1);
  expect(finalEvent).toEqual({ finished: true });
  transport.destroy();
  subscriber.disconnect();
});
});
});

View file

@ -24,8 +24,11 @@ describe('RedisJobStore Integration Tests', () => {
// Set up test environment
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'false';
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
process.env.REDIS_KEY_PREFIX = testPrefix;
process.env.REDIS_PING_INTERVAL = '0';
process.env.REDIS_RETRY_MAX_ATTEMPTS = '5';
jest.resetModules();
@ -880,6 +883,67 @@ describe('RedisJobStore Integration Tests', () => {
});
});
describe('Race Condition: updateJob after deleteJob', () => {
  /**
   * Regression guard: a late updateJob must not resurrect a job hash that
   * deleteJob already removed. updateJob performs a conditional write
   * (EXISTS-then-HSET via a Lua script) so the race is resolved atomically
   * server-side.
   */
  test('should not re-create job hash when updateJob runs after deleteJob', async () => {
    if (!ioredisClient) {
      return;
    }
    const { RedisJobStore } = await import('../implementations/RedisJobStore');
    const store = new RedisJobStore(ioredisClient);
    await store.initialize();
    const streamId = `race-condition-${Date.now()}`;
    await store.createJob(streamId, 'user-1', streamId);
    // Job keys embed `{streamId}` as a hash tag (cluster slot affinity).
    const jobKey = `stream:{${streamId}}:job`;
    const ttlBefore = await ioredisClient.ttl(jobKey);
    expect(ttlBefore).toBeGreaterThan(0);
    await store.deleteJob(streamId);
    const afterDelete = await ioredisClient.exists(jobKey);
    expect(afterDelete).toBe(0);
    // Update arriving after the delete must be a no-op — no key re-created.
    await store.updateJob(streamId, { finalEvent: JSON.stringify({ final: true }) });
    const afterUpdate = await ioredisClient.exists(jobKey);
    expect(afterUpdate).toBe(0);
    await store.destroy();
  });
  test('should not leave orphan keys from concurrent emitDone and deleteJob', async () => {
    if (!ioredisClient) {
      return;
    }
    const { RedisJobStore } = await import('../implementations/RedisJobStore');
    const store = new RedisJobStore(ioredisClient);
    await store.initialize();
    const streamId = `concurrent-race-${Date.now()}`;
    await store.createJob(streamId, 'user-1', streamId);
    const jobKey = `stream:{${streamId}}:job`;
    // Fire update and delete concurrently — either interleaving must be safe.
    await Promise.all([
      store.updateJob(streamId, { finalEvent: JSON.stringify({ final: true }) }),
      store.deleteJob(streamId),
    ]);
    await new Promise((resolve) => setTimeout(resolve, 100));
    const exists = await ioredisClient.exists(jobKey);
    const ttl = exists ? await ioredisClient.ttl(jobKey) : -2;
    // Acceptable outcomes: key gone (TTL -2) or key present WITH a TTL.
    // A persistent key (TTL -1) would be an orphan that never expires.
    expect(ttl === -2 || ttl > 0).toBe(true);
    expect(ttl).not.toBe(-1);
    await store.destroy();
  });
});
describe('Local Graph Cache Optimization', () => {
test('should use local cache when available', async () => {
if (!ioredisClient) {
@ -972,4 +1036,196 @@ describe('RedisJobStore Integration Tests', () => {
await instance2.destroy();
});
});
describe('Batched Cleanup', () => {
  /**
   * cleanup() walks the `stream:running` set in parallel batches of 50
   * instead of one sequential round-trip per job. These tests seed stale
   * jobs directly into Redis and verify both throughput and selectivity.
   *
   * Note: jobs are seeded with HSET (not the deprecated HMSET) to match the
   * field-map form used by RedisJobStore itself; HMSET is deprecated since
   * Redis 4.0 and ioredis HSET accepts the same object argument.
   */
  test('should clean up many stale jobs in parallel batches', async () => {
    if (!ioredisClient) {
      return;
    }
    const { RedisJobStore } = await import('../implementations/RedisJobStore');
    // Very short TTL so jobs are immediately stale
    const store = new RedisJobStore(ioredisClient, { runningTtl: 1 });
    await store.initialize();
    const jobCount = 75; // More than one batch of 50
    const veryOldTimestamp = Date.now() - 10000; // 10 seconds ago
    // Create many stale jobs directly in Redis (bypassing createJob)
    for (let i = 0; i < jobCount; i++) {
      const streamId = `batch-cleanup-${Date.now()}-${i}`;
      const jobKey = `stream:{${streamId}}:job`;
      await ioredisClient.hset(jobKey, {
        streamId,
        userId: 'batch-user',
        status: 'running',
        createdAt: veryOldTimestamp.toString(),
        syncSent: '0',
      });
      await ioredisClient.sadd('stream:running', streamId);
    }
    // Verify jobs are in the running set
    const runningBefore = await ioredisClient.scard('stream:running');
    expect(runningBefore).toBeGreaterThanOrEqual(jobCount);
    // Run cleanup - should process in batches of 50
    const cleaned = await store.cleanup();
    expect(cleaned).toBeGreaterThanOrEqual(jobCount);
    await store.destroy();
  });
  test('should not clean up valid running jobs during batch cleanup', async () => {
    if (!ioredisClient) {
      return;
    }
    const { RedisJobStore } = await import('../implementations/RedisJobStore');
    const store = new RedisJobStore(ioredisClient, { runningTtl: 1200 });
    await store.initialize();
    // Create a mix of valid and stale jobs
    const validStreamId = `valid-job-${Date.now()}`;
    await store.createJob(validStreamId, 'user-1', validStreamId);
    const staleStreamId = `stale-job-${Date.now()}`;
    const jobKey = `stream:{${staleStreamId}}:job`;
    await ioredisClient.hset(jobKey, {
      streamId: staleStreamId,
      userId: 'user-1',
      status: 'running',
      createdAt: (Date.now() - 2000000).toString(), // Very old
      syncSent: '0',
    });
    await ioredisClient.sadd('stream:running', staleStreamId);
    const cleaned = await store.cleanup();
    expect(cleaned).toBeGreaterThanOrEqual(1);
    // Valid job should still exist
    const validJob = await store.getJob(validStreamId);
    expect(validJob).not.toBeNull();
    expect(validJob?.status).toBe('running');
    await store.destroy();
  });
});
describe('appendChunk TTL Refresh', () => {
  test('should set TTL on the chunk stream', async () => {
    if (!ioredisClient) {
      return;
    }
    const { RedisJobStore } = await import('../implementations/RedisJobStore');
    const store = new RedisJobStore(ioredisClient, { runningTtl: 120 });
    await store.initialize();
    const streamId = `append-ttl-${Date.now()}`;
    await store.createJob(streamId, 'user-1', streamId);
    await store.appendChunk(streamId, {
      event: 'on_message_delta',
      data: { id: 'step-1', type: 'text', text: 'first' },
    });
    // Chunk stream lives under the same `{streamId}` hash tag as the job key.
    const chunkKey = `stream:{${streamId}}:chunks`;
    // TTL must be set and bounded by the configured running TTL (120s).
    const ttl = await ioredisClient.ttl(chunkKey);
    expect(ttl).toBeGreaterThan(0);
    expect(ttl).toBeLessThanOrEqual(120);
    await store.destroy();
  });
  /**
   * Guards against an expire-only-on-first-chunk bug: every append must
   * push the chunk stream's TTL back to the full running TTL, otherwise a
   * long-running generation could lose its chunks mid-stream.
   */
  test('should refresh TTL on subsequent chunks (not just first)', async () => {
    if (!ioredisClient) {
      return;
    }
    const { RedisJobStore } = await import('../implementations/RedisJobStore');
    const store = new RedisJobStore(ioredisClient, { runningTtl: 120 });
    await store.initialize();
    const streamId = `append-refresh-${Date.now()}`;
    await store.createJob(streamId, 'user-1', streamId);
    // Append first chunk
    await store.appendChunk(streamId, {
      event: 'on_message_delta',
      data: { id: 'step-1', type: 'text', text: 'first' },
    });
    const chunkKey = `stream:{${streamId}}:chunks`;
    const ttl1 = await ioredisClient.ttl(chunkKey);
    expect(ttl1).toBeGreaterThan(0);
    // Manually reduce TTL to simulate time passing
    await ioredisClient.expire(chunkKey, 30);
    const reducedTtl = await ioredisClient.ttl(chunkKey);
    expect(reducedTtl).toBeLessThanOrEqual(30);
    // Append another chunk - TTL should be refreshed back to running TTL
    await store.appendChunk(streamId, {
      event: 'on_message_delta',
      data: { id: 'step-1', type: 'text', text: 'second' },
    });
    const ttl2 = await ioredisClient.ttl(chunkKey);
    // Should be refreshed to ~120, not still ~30
    expect(ttl2).toBeGreaterThan(30);
    expect(ttl2).toBeLessThanOrEqual(120);
    await store.destroy();
  });
  test('should store chunks correctly via pipeline', async () => {
    if (!ioredisClient) {
      return;
    }
    const { RedisJobStore } = await import('../implementations/RedisJobStore');
    const store = new RedisJobStore(ioredisClient);
    await store.initialize();
    const streamId = `append-pipeline-${Date.now()}`;
    await store.createJob(streamId, 'user-1', streamId);
    // A run-step event followed by two text deltas on the same step.
    const chunks = [
      {
        event: 'on_run_step',
        data: {
          id: 'step-1',
          runId: 'run-1',
          index: 0,
          stepDetails: { type: 'message_creation' },
        },
      },
      {
        event: 'on_message_delta',
        data: { id: 'step-1', delta: { content: { type: 'text', text: 'Hello ' } } },
      },
      {
        event: 'on_message_delta',
        data: { id: 'step-1', delta: { content: { type: 'text', text: 'world!' } } },
      },
    ];
    for (const chunk of chunks) {
      await store.appendChunk(streamId, chunk);
    }
    // Verify all chunks were stored (chunks are kept in a Redis stream; XLEN)
    const chunkKey = `stream:{${streamId}}:chunks`;
    const len = await ioredisClient.xlen(chunkKey);
    expect(len).toBe(3);
    // Verify content can be reconstructed from the stored deltas
    const content = await store.getContentParts(streamId);
    expect(content).not.toBeNull();
    expect(content!.content.length).toBeGreaterThan(0);
    await store.destroy();
  });
});
});

View file

@ -146,7 +146,7 @@ describe('CollectedUsage - GenerationJobManager', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `manager-test-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
@ -179,7 +179,7 @@ describe('CollectedUsage - GenerationJobManager', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `no-usage-test-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
@ -202,7 +202,7 @@ describe('CollectedUsage - GenerationJobManager', () => {
isRedis: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const collectedUsage: UsageMetadata[] = [
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
@ -235,7 +235,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `text-extract-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
@ -267,7 +267,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `empty-text-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
@ -291,7 +291,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `full-abort-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
@ -328,7 +328,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
isRedis: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const abortResult = await GenerationJobManager.abortJob('non-existent-job');
@ -365,7 +365,7 @@ describe('Real-world Scenarios', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `parallel-abort-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
@ -419,7 +419,7 @@ describe('Real-world Scenarios', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `cache-abort-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');
@ -459,7 +459,7 @@ describe('Real-world Scenarios', () => {
cleanupOnComplete: false,
});
await GenerationJobManager.initialize();
GenerationJobManager.initialize();
const streamId = `sequential-abort-${Date.now()}`;
await GenerationJobManager.createJob(streamId, 'user-1');

View file

@ -0,0 +1,450 @@
import type { Redis, Cluster } from 'ioredis';
import { RedisEventTransport } from '~/stream/implementations/RedisEventTransport';
import { GenerationJobManagerClass } from '~/stream/GenerationJobManager';
import { createStreamServices } from '~/stream/createStreamServices';
import {
ioredisClient as staticRedisClient,
keyvRedisClient as staticKeyvClient,
keyvRedisClientReady,
} from '~/cache/redisClients';
/**
* Regression tests for the reconnect reorder buffer desync bug.
*
* Bug: When a user disconnects and reconnects to a stream multiple times,
* the second+ reconnect lost chunks because the transport deleted stream state
* on last unsubscribe, destroying the allSubscribersLeftCallbacks registered
* by createJob(). This prevented hasSubscriber from being reset, which in turn
* prevented syncReorderBuffer from being called on reconnect.
*
* Fix: Preserve stream state (callbacks, abort handlers) across reconnect cycles
* instead of deleting it. The state is fully cleaned up by cleanup() when the
* job completes.
*
* Run with: USE_REDIS=true npx jest reconnect-reorder-desync
*/
describe('Reconnect Reorder Buffer Desync (Regression)', () => {
describe('Callback preservation across reconnect cycles (Unit)', () => {
  /**
   * Pre-fix, the transport deleted all stream state on last unsubscribe,
   * so the callback registered by createJob fired only for the FIRST
   * disconnect. It must fire on every disconnect.
   */
  test('allSubscribersLeft callback fires on every disconnect, not just the first', () => {
    const mockPublisher = {
      publish: jest.fn().mockResolvedValue(1),
    };
    const mockSubscriber = {
      on: jest.fn(),
      subscribe: jest.fn().mockResolvedValue(undefined),
      unsubscribe: jest.fn().mockResolvedValue(undefined),
    };
    const transport = new RedisEventTransport(
      mockPublisher as unknown as Redis,
      mockSubscriber as unknown as Redis,
    );
    const streamId = 'callback-persist-test';
    let callbackFireCount = 0;
    // Register callback (simulates what createJob does)
    transport.onAllSubscribersLeft(streamId, () => {
      callbackFireCount++;
    });
    // First subscribe/unsubscribe cycle
    const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
    sub1.unsubscribe();
    expect(callbackFireCount).toBe(1);
    // Second subscribe/unsubscribe cycle — callback must still fire
    // (pre-fix, the count stayed at 1 here)
    const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
    sub2.unsubscribe();
    expect(callbackFireCount).toBe(2);
    // Third cycle — continues to work
    const sub3 = transport.subscribe(streamId, { onChunk: () => {} });
    sub3.unsubscribe();
    expect(callbackFireCount).toBe(3);
    transport.destroy();
  });
  test('abort callback survives across reconnect cycles', () => {
    const mockPublisher = {
      publish: jest.fn().mockResolvedValue(1),
    };
    const mockSubscriber = {
      on: jest.fn(),
      subscribe: jest.fn().mockResolvedValue(undefined),
      unsubscribe: jest.fn().mockResolvedValue(undefined),
    };
    const transport = new RedisEventTransport(
      mockPublisher as unknown as Redis,
      mockSubscriber as unknown as Redis,
    );
    const streamId = 'abort-callback-persist-test';
    let abortCallbackFired = false;
    // Register abort callback (simulates what createJob does)
    transport.onAbort(streamId, () => {
      abortCallbackFired = true;
    });
    // Subscribe/unsubscribe cycle
    const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
    sub1.unsubscribe();
    // Re-subscribe and receive an abort signal
    const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
    // Pull the pub/sub 'message' handler the transport registered on the
    // mocked subscriber so we can inject a raw channel message directly.
    const messageHandler = mockSubscriber.on.mock.calls.find(
      (call) => call[0] === 'message',
    )?.[1] as (channel: string, message: string) => void;
    const channel = `stream:{${streamId}}:events`;
    messageHandler(channel, JSON.stringify({ type: 'abort' }));
    // Abort callback should fire — it was preserved across the reconnect
    expect(abortCallbackFired).toBe(true);
    sub2.unsubscribe();
    transport.destroy();
  });
});
describe('Reorder buffer sync on reconnect (Unit)', () => {
  /**
   * After the fix, the allSubscribersLeft callback fires on every disconnect,
   * which resets hasSubscriber. GenerationJobManager.subscribe() then enters
   * the if (!runtime.hasSubscriber) block and calls syncReorderBuffer.
   *
   * This test verifies at the transport level that when syncReorderBuffer IS
   * called (as it now will be on every reconnect), messages are delivered
   * immediately regardless of how many reconnect cycles have occurred.
   */
  test('syncReorderBuffer works correctly on third+ reconnect', async () => {
    const mockPublisher = {
      publish: jest.fn().mockResolvedValue(1),
    };
    const mockSubscriber = {
      on: jest.fn(),
      subscribe: jest.fn().mockResolvedValue(undefined),
      unsubscribe: jest.fn().mockResolvedValue(undefined),
    };
    const transport = new RedisEventTransport(
      mockPublisher as unknown as Redis,
      mockSubscriber as unknown as Redis,
    );
    const streamId = 'reorder-multi-reconnect-test';
    transport.onAllSubscribersLeft(streamId, () => {
      // Simulates the callback from createJob
    });
    // Capture the raw pub/sub 'message' handler so messages can be injected
    // without a real Redis connection.
    const messageHandler = mockSubscriber.on.mock.calls.find(
      (call) => call[0] === 'message',
    )?.[1] as (channel: string, message: string) => void;
    const channel = `stream:{${streamId}}:events`;
    // Run 3 full subscribe/emit/unsubscribe cycles
    for (let cycle = 0; cycle < 3; cycle++) {
      const chunks: unknown[] = [];
      const sub = transport.subscribe(streamId, {
        onChunk: (event) => chunks.push(event),
      });
      // Sync reorder buffer (as GenerationJobManager.subscribe does)
      transport.syncReorderBuffer(streamId);
      const baseSeq = cycle * 10;
      // Emit 10 chunks (advances publisher sequence)
      for (let i = 0; i < 10; i++) {
        await transport.emitChunk(streamId, { index: baseSeq + i });
      }
      // Deliver messages via pub/sub handler
      for (let i = 0; i < 10; i++) {
        messageHandler(
          channel,
          JSON.stringify({ type: 'chunk', seq: baseSeq + i, data: { index: baseSeq + i } }),
        );
      }
      // Messages should be delivered immediately on every cycle
      expect(chunks.length).toBe(10);
      expect(chunks.map((c) => (c as { index: number }).index)).toEqual(
        Array.from({ length: 10 }, (_, i) => baseSeq + i),
      );
      sub.unsubscribe();
    }
    transport.destroy();
  });
  test('reorder buffer works correctly when syncReorderBuffer IS called', async () => {
    const mockPublisher = {
      publish: jest.fn().mockResolvedValue(1),
    };
    const mockSubscriber = {
      on: jest.fn(),
      subscribe: jest.fn().mockResolvedValue(undefined),
      unsubscribe: jest.fn().mockResolvedValue(undefined),
    };
    const transport = new RedisEventTransport(
      mockPublisher as unknown as Redis,
      mockSubscriber as unknown as Redis,
    );
    const streamId = 'reorder-sync-test';
    // Emit 20 chunks to advance publisher sequence
    for (let i = 0; i < 20; i++) {
      await transport.emitChunk(streamId, { index: i });
    }
    // Subscribe and sync the reorder buffer
    const chunks: unknown[] = [];
    const sub = transport.subscribe(streamId, {
      onChunk: (event) => chunks.push(event),
    });
    // This is the critical call - sync nextSeq to match publisher
    transport.syncReorderBuffer(streamId);
    // Deliver messages starting at seq 20
    const messageHandler = mockSubscriber.on.mock.calls.find(
      (call) => call[0] === 'message',
    )?.[1] as (channel: string, message: string) => void;
    const channel = `stream:{${streamId}}:events`;
    for (let i = 20; i < 25; i++) {
      messageHandler(channel, JSON.stringify({ type: 'chunk', seq: i, data: { index: i } }));
    }
    // Messages should be delivered IMMEDIATELY (no 500ms wait)
    // because nextSeq was synced to 20
    expect(chunks.length).toBe(5);
    expect(chunks.map((c) => (c as { index: number }).index)).toEqual([20, 21, 22, 23, 24]);
    sub.unsubscribe();
    transport.destroy();
  });
});
describe('End-to-end reconnect with GenerationJobManager (Integration)', () => {
  let originalEnv: NodeJS.ProcessEnv;
  let ioredisClient: Redis | Cluster | null = null;
  let dynamicKeyvClient: unknown = null;
  let dynamicKeyvReady: Promise<unknown> | null = null;
  const testPrefix = 'ReconnectDesync-Test';
  beforeAll(async () => {
    originalEnv = { ...process.env };
    process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
    process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
    process.env.REDIS_KEY_PREFIX = testPrefix;
    // Re-import redis clients so they pick up the env vars set above.
    jest.resetModules();
    const redisModule = await import('~/cache/redisClients');
    ioredisClient = redisModule.ioredisClient;
    dynamicKeyvClient = redisModule.keyvRedisClient;
    dynamicKeyvReady = redisModule.keyvRedisClientReady;
  });
  afterEach(async () => {
    jest.resetModules();
    if (ioredisClient) {
      try {
        // Remove both prefixed keys and raw stream:* keys created by tests.
        const keys = await ioredisClient.keys(`${testPrefix}*`);
        const streamKeys = await ioredisClient.keys('stream:*');
        const allKeys = [...keys, ...streamKeys];
        await Promise.all(allKeys.map((key) => ioredisClient!.del(key)));
      } catch {
        // Ignore cleanup errors
      }
    }
  });
  afterAll(async () => {
    // Wait for keyv clients to finish connecting before disconnecting,
    // otherwise disconnect can race with an in-flight connection attempt.
    for (const ready of [keyvRedisClientReady, dynamicKeyvReady]) {
      if (ready) {
        await ready.catch(() => {});
      }
    }
    const clients = [ioredisClient, staticRedisClient, staticKeyvClient, dynamicKeyvClient];
    for (const client of clients) {
      if (!client) {
        continue;
      }
      try {
        await (client as { disconnect: () => void | Promise<void> }).disconnect();
      } catch {
        /* ignore */
      }
    }
    process.env = originalEnv;
  });
  /**
   * Verifies that all reconnect cycles deliver chunks immediately
   * not just the first reconnect.
   */
  test('chunks are delivered immediately on every reconnect cycle', async () => {
    if (!ioredisClient) {
      console.warn('Redis not available, skipping test');
      return;
    }
    const manager = new GenerationJobManagerClass();
    const services = createStreamServices({
      useRedis: true,
      redisClient: ioredisClient,
    });
    manager.configure(services);
    manager.initialize();
    const streamId = `reconnect-fixed-${Date.now()}`;
    await manager.createJob(streamId, 'user-1');
    // Run 3 subscribe/emit/unsubscribe cycles
    for (let cycle = 0; cycle < 3; cycle++) {
      const chunks: unknown[] = [];
      const sub = await manager.subscribe(streamId, (event) => chunks.push(event));
      await new Promise((resolve) => setTimeout(resolve, 100));
      // Emit 10 chunks
      for (let i = 0; i < 10; i++) {
        await manager.emitChunk(streamId, {
          event: 'on_message_delta',
          data: {
            delta: { content: { type: 'text', text: `c${cycle}-${i}` } },
            index: cycle * 10 + i,
          },
        });
      }
      // Chunks should arrive within 200ms (well under the 500ms force-flush timeout)
      await new Promise((resolve) => setTimeout(resolve, 200));
      expect(chunks.length).toBe(10);
      sub!.unsubscribe();
      await new Promise((resolve) => setTimeout(resolve, 100));
    }
    await manager.destroy();
  });
  /**
   * Verifies that syncSent is correctly reset on every disconnect,
   * proving the onAllSubscribersLeft callback survives reconnect cycles.
   */
  test('onAllSubscribersLeft callback resets state on every disconnect', async () => {
    if (!ioredisClient) {
      console.warn('Redis not available, skipping test');
      return;
    }
    const manager = new GenerationJobManagerClass();
    const services = createStreamServices({
      useRedis: true,
      redisClient: ioredisClient,
    });
    manager.configure(services);
    manager.initialize();
    const streamId = `callback-persist-integ-${Date.now()}`;
    await manager.createJob(streamId, 'user-1');
    for (let cycle = 0; cycle < 3; cycle++) {
      const sub = await manager.subscribe(streamId, () => {});
      await new Promise((resolve) => setTimeout(resolve, 50));
      // Mark sync as sent
      manager.markSyncSent(streamId);
      await new Promise((resolve) => setTimeout(resolve, 50));
      let syncSent = await manager.wasSyncSent(streamId);
      expect(syncSent).toBe(true);
      // Disconnect
      sub!.unsubscribe();
      await new Promise((resolve) => setTimeout(resolve, 100));
      // Callback should reset syncSent on every disconnect
      syncSent = await manager.wasSyncSent(streamId);
      expect(syncSent).toBe(false);
    }
    await manager.destroy();
  });
  /**
   * Verifies all reconnect cycles deliver chunks immediately with no
   * increasing gap pattern.
   */
  test('no increasing gap pattern across reconnect cycles', async () => {
    if (!ioredisClient) {
      console.warn('Redis not available, skipping test');
      return;
    }
    const manager = new GenerationJobManagerClass();
    const services = createStreamServices({
      useRedis: true,
      redisClient: ioredisClient,
    });
    manager.configure(services);
    manager.initialize();
    const streamId = `no-gaps-${Date.now()}`;
    await manager.createJob(streamId, 'user-1');
    const chunksPerCycle = 15;
    for (let cycle = 0; cycle < 4; cycle++) {
      const chunks: unknown[] = [];
      const sub = await manager.subscribe(streamId, (event) => chunks.push(event));
      await new Promise((resolve) => setTimeout(resolve, 100));
      // Emit chunks
      for (let i = 0; i < chunksPerCycle; i++) {
        await manager.emitChunk(streamId, {
          event: 'on_message_delta',
          data: {
            delta: { content: { type: 'text', text: `c${cycle}-${i}` } },
            index: cycle * chunksPerCycle + i,
          },
        });
      }
      // All chunks should arrive within 200ms on every cycle
      await new Promise((resolve) => setTimeout(resolve, 200));
      expect(chunks.length).toBe(chunksPerCycle);
      sub!.unsubscribe();
      await new Promise((resolve) => setTimeout(resolve, 100));
    }
    await manager.destroy();
  });
});
});

View file

@ -32,7 +32,7 @@ export class InMemoryEventTransport implements IEventTransport {
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void } {
): { unsubscribe: () => void; ready?: Promise<void> } {
const state = this.getOrCreateStream(streamId);
const chunkHandler = (event: unknown) => handlers.onChunk(event);
@ -58,9 +58,11 @@ export class InMemoryEventTransport implements IEventTransport {
// Check if all subscribers left - cleanup and notify
if (currentState.emitter.listenerCount('chunk') === 0) {
currentState.allSubscribersLeftCallback?.();
// Auto-cleanup the stream entry when no subscribers remain
/* Remove all EventEmitter listeners but preserve stream state
* (including allSubscribersLeftCallback) for reconnection.
* State is fully cleaned up by cleanup() when the job completes.
*/
currentState.emitter.removeAllListeners();
this.streams.delete(streamId);
}
}
},

View file

@ -92,8 +92,8 @@ export class RedisEventTransport implements IEventTransport {
private subscriber: Redis | Cluster;
/** Track subscribers per stream */
private streams = new Map<string, StreamSubscribers>();
/** Track which channels we're subscribed to */
private subscribedChannels = new Set<string>();
/** Track channel subscription state: resolved promise = active, pending = in-flight */
private channelSubscriptions = new Map<string, Promise<void>>();
/** Counter for generating unique subscriber IDs */
private subscriberIdCounter = 0;
/** Sequence counters per stream for publishing (ensures ordered delivery in cluster mode) */
@ -122,9 +122,32 @@ export class RedisEventTransport implements IEventTransport {
return current;
}
/** Reset sequence counter for a stream */
private resetSequence(streamId: string): void {
/** Reset publish sequence counter and subscriber reorder state for a stream (full cleanup only) */
resetSequence(streamId: string): void {
this.sequenceCounters.delete(streamId);
const state = this.streams.get(streamId);
if (state) {
if (state.reorderBuffer.flushTimeout) {
clearTimeout(state.reorderBuffer.flushTimeout);
state.reorderBuffer.flushTimeout = null;
}
state.reorderBuffer.nextSeq = 0;
state.reorderBuffer.pending.clear();
}
}
/** Advance subscriber reorder buffer to current publisher sequence without resetting publisher (cross-replica safe) */
syncReorderBuffer(streamId: string): void {
const currentSeq = this.sequenceCounters.get(streamId) ?? 0;
const state = this.streams.get(streamId);
if (state) {
if (state.reorderBuffer.flushTimeout) {
clearTimeout(state.reorderBuffer.flushTimeout);
state.reorderBuffer.flushTimeout = null;
}
state.reorderBuffer.nextSeq = currentSeq;
state.reorderBuffer.pending.clear();
}
}
/**
@ -331,7 +354,7 @@ export class RedisEventTransport implements IEventTransport {
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void } {
): { unsubscribe: () => void; ready?: Promise<void> } {
const channel = CHANNELS.events(streamId);
const subscriberId = `sub_${++this.subscriberIdCounter}`;
@ -354,16 +377,23 @@ export class RedisEventTransport implements IEventTransport {
streamState.count++;
streamState.handlers.set(subscriberId, handlers);
// Subscribe to Redis channel if this is first subscriber
if (!this.subscribedChannels.has(channel)) {
this.subscribedChannels.add(channel);
this.subscriber.subscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
});
let readyPromise = this.channelSubscriptions.get(channel);
if (!readyPromise) {
readyPromise = this.subscriber
.subscribe(channel)
.then(() => {
logger.debug(`[RedisEventTransport] Subscription active for channel ${channel}`);
})
.catch((err) => {
this.channelSubscriptions.delete(channel);
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
});
this.channelSubscriptions.set(channel, readyPromise);
}
// Return unsubscribe function
return {
ready: readyPromise,
unsubscribe: () => {
const state = this.streams.get(streamId);
if (!state) {
@ -385,7 +415,7 @@ export class RedisEventTransport implements IEventTransport {
this.subscriber.unsubscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to unsubscribe from ${channel}:`, err);
});
this.subscribedChannels.delete(channel);
this.channelSubscriptions.delete(channel);
// Call all-subscribers-left callbacks
for (const callback of state.allSubscribersLeftCallbacks) {
@ -395,8 +425,15 @@ export class RedisEventTransport implements IEventTransport {
logger.error(`[RedisEventTransport] Error in allSubscribersLeft callback:`, err);
}
}
this.streams.delete(streamId);
/**
* Preserve stream state (callbacks, abort handlers) for reconnection.
* Previously this deleted the entire state, which lost the
* allSubscribersLeftCallbacks and abortCallbacks registered by
* GenerationJobManager.createJob(). On the next subscribe() call,
* fresh state was created without those callbacks, causing
* hasSubscriber to never reset and syncReorderBuffer to be skipped.
* State is fully cleaned up by cleanup() when the job completes.
*/
}
},
};
@ -431,6 +468,7 @@ export class RedisEventTransport implements IEventTransport {
await this.publisher.publish(channel, JSON.stringify(message));
} catch (err) {
logger.error(`[RedisEventTransport] Failed to publish done:`, err);
throw err;
}
}
@ -447,6 +485,7 @@ export class RedisEventTransport implements IEventTransport {
await this.publisher.publish(channel, JSON.stringify(message));
} catch (err) {
logger.error(`[RedisEventTransport] Failed to publish error:`, err);
throw err;
}
}
@ -532,12 +571,15 @@ export class RedisEventTransport implements IEventTransport {
state.abortCallbacks.push(callback);
// Subscribe to Redis channel if not already subscribed
if (!this.subscribedChannels.has(channel)) {
this.subscribedChannels.add(channel);
this.subscriber.subscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
});
if (!this.channelSubscriptions.has(channel)) {
const ready = this.subscriber
.subscribe(channel)
.then(() => {})
.catch((err) => {
this.channelSubscriptions.delete(channel);
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
});
this.channelSubscriptions.set(channel, ready);
}
}
@ -571,12 +613,11 @@ export class RedisEventTransport implements IEventTransport {
// Reset sequence counter for this stream
this.resetSequence(streamId);
// Unsubscribe from Redis channel
if (this.subscribedChannels.has(channel)) {
if (this.channelSubscriptions.has(channel)) {
this.subscriber.unsubscribe(channel).catch((err) => {
logger.error(`[RedisEventTransport] Failed to cleanup ${channel}:`, err);
});
this.subscribedChannels.delete(channel);
this.channelSubscriptions.delete(channel);
}
this.streams.delete(streamId);
@ -595,18 +636,20 @@ export class RedisEventTransport implements IEventTransport {
state.reorderBuffer.pending.clear();
}
// Unsubscribe from all channels
for (const channel of this.subscribedChannels) {
this.subscriber.unsubscribe(channel).catch(() => {
// Ignore errors during shutdown
});
for (const channel of this.channelSubscriptions.keys()) {
this.subscriber.unsubscribe(channel).catch(() => {});
}
this.subscribedChannels.clear();
this.channelSubscriptions.clear();
this.streams.clear();
this.sequenceCounters.clear();
// Note: Don't close Redis connections - they may be shared
try {
this.subscriber.disconnect();
} catch {
/* ignore */
}
logger.info('[RedisEventTransport] Destroyed');
}
}

View file

@ -156,13 +156,13 @@ export class RedisJobStore implements IJobStore {
// For cluster mode, we can't pipeline keys on different slots
// The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots
if (this.isCluster) {
await this.redis.hmset(key, this.serializeJob(job));
await this.redis.hset(key, this.serializeJob(job));
await this.redis.expire(key, this.ttl.running);
await this.redis.sadd(KEYS.runningJobs, streamId);
await this.redis.sadd(userJobsKey, streamId);
} else {
const pipeline = this.redis.pipeline();
pipeline.hmset(key, this.serializeJob(job));
pipeline.hset(key, this.serializeJob(job));
pipeline.expire(key, this.ttl.running);
pipeline.sadd(KEYS.runningJobs, streamId);
pipeline.sadd(userJobsKey, streamId);
@ -183,17 +183,23 @@ export class RedisJobStore implements IJobStore {
async updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void> {
const key = KEYS.job(streamId);
const exists = await this.redis.exists(key);
if (!exists) {
return;
}
const serialized = this.serializeJob(updates as SerializableJobData);
if (Object.keys(serialized).length === 0) {
return;
}
await this.redis.hmset(key, serialized);
const fields = Object.entries(serialized).flat();
const updated = await this.redis.eval(
'if redis.call("EXISTS", KEYS[1]) == 1 then redis.call("HSET", KEYS[1], unpack(ARGV)) return 1 else return 0 end',
1,
key,
...fields,
);
if (updated === 0) {
return;
}
// If status changed to complete/error/aborted, update TTL and remove from running set
// Note: userJobs cleanup is handled lazily via self-healing in getActiveJobIdsByUser
@ -296,32 +302,46 @@ export class RedisJobStore implements IJobStore {
}
}
for (const streamId of streamIds) {
const job = await this.getJob(streamId);
// Process in batches of 50 to avoid sequential per-job round-trips
const BATCH_SIZE = 50;
for (let i = 0; i < streamIds.length; i += BATCH_SIZE) {
const batch = streamIds.slice(i, i + BATCH_SIZE);
const results = await Promise.allSettled(
batch.map(async (streamId) => {
const job = await this.getJob(streamId);
// Job no longer exists (TTL expired) - remove from set
if (!job) {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
this.localCollectedUsageCache.delete(streamId);
cleaned++;
continue;
}
// Job no longer exists (TTL expired) - remove from set
if (!job) {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
this.localCollectedUsageCache.delete(streamId);
return 1;
}
// Job completed but still in running set (shouldn't happen, but handle it)
if (job.status !== 'running') {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
this.localCollectedUsageCache.delete(streamId);
cleaned++;
continue;
}
// Job completed but still in running set (shouldn't happen, but handle it)
if (job.status !== 'running') {
await this.redis.srem(KEYS.runningJobs, streamId);
this.localGraphCache.delete(streamId);
this.localCollectedUsageCache.delete(streamId);
return 1;
}
// Stale running job (failsafe - running for > configured TTL)
if (now - job.createdAt > this.ttl.running * 1000) {
logger.warn(`[RedisJobStore] Cleaning up stale job: ${streamId}`);
await this.deleteJob(streamId);
cleaned++;
// Stale running job (failsafe - running for > configured TTL)
if (now - job.createdAt > this.ttl.running * 1000) {
logger.warn(`[RedisJobStore] Cleaning up stale job: ${streamId}`);
await this.deleteJob(streamId);
return 1;
}
return 0;
}),
);
for (const result of results) {
if (result.status === 'fulfilled') {
cleaned += result.value;
} else {
logger.warn(`[RedisJobStore] Cleanup failed for a job:`, result.reason);
}
}
}
@ -586,16 +606,14 @@ export class RedisJobStore implements IJobStore {
*/
async appendChunk(streamId: string, event: unknown): Promise<void> {
const key = KEYS.chunks(streamId);
const added = await this.redis.xadd(key, '*', 'event', JSON.stringify(event));
// Set TTL on first chunk (when stream is created)
// Subsequent chunks inherit the stream's TTL
if (added) {
const len = await this.redis.xlen(key);
if (len === 1) {
await this.redis.expire(key, this.ttl.running);
}
}
// Pipeline XADD + EXPIRE in a single round-trip.
// EXPIRE is O(1) and idempotent — refreshing TTL on every chunk is better than
// only setting it once, since the original approach could let the TTL expire
// during long-running streams.
const pipeline = this.redis.pipeline();
pipeline.xadd(key, '*', 'event', JSON.stringify(event));
pipeline.expire(key, this.ttl.running);
await pipeline.exec();
}
/**

View file

@ -286,7 +286,7 @@ export interface IJobStore {
* Implementations can use EventEmitter, Redis Pub/Sub, etc.
*/
export interface IEventTransport {
/** Subscribe to events for a stream */
/** Subscribe to events for a stream. `ready` resolves once the transport can receive messages. */
subscribe(
streamId: string,
handlers: {
@ -294,7 +294,7 @@ export interface IEventTransport {
onDone?: (event: unknown) => void;
onError?: (error: string) => void;
},
): { unsubscribe: () => void };
): { unsubscribe: () => void; ready?: Promise<void> };
/** Publish a chunk event - returns Promise in Redis mode for ordered delivery */
emitChunk(streamId: string, event: unknown): void | Promise<void>;
@ -329,6 +329,12 @@ export interface IEventTransport {
/** Listen for all subscribers leaving */
onAllSubscribersLeft(streamId: string, callback: () => void): void;
/** Reset publish sequence counter for a stream (used during full stream cleanup) */
resetSequence?(streamId: string): void;
/** Advance subscriber reorder buffer to match publisher sequence (cross-replica safe: doesn't reset publisher counter) */
syncReorderBuffer?(streamId: string): void;
/** Cleanup transport resources for a specific stream */
cleanup(streamId: string): void;

View file

@ -8,9 +8,10 @@
import { Constants, actionDelimiter } from 'librechat-data-provider';
import type { AgentToolOptions } from 'librechat-data-provider';
import type { LCToolRegistry, JsonSchemaType, LCTool, GenericTool } from '@librechat/agents';
import { buildToolClassification, type ToolDefinition } from './classification';
import type { ToolDefinition } from './classification';
import { resolveJsonSchemaRefs, normalizeJsonSchema } from '~/mcp/zod';
import { buildToolClassification } from './classification';
import { getToolDefinition } from './registry/definitions';
import { resolveJsonSchemaRefs } from '~/mcp/zod';
export interface MCPServerTool {
function?: {
@ -138,7 +139,7 @@ export async function loadToolDefinitions(
name: actualToolName,
description: toolDef.function.description,
parameters: toolDef.function.parameters
? resolveJsonSchemaRefs(toolDef.function.parameters)
? normalizeJsonSchema(resolveJsonSchemaRefs(toolDef.function.parameters))
: undefined,
serverName,
});
@ -153,7 +154,7 @@ export async function loadToolDefinitions(
name: toolName,
description: toolDef.function.description,
parameters: toolDef.function.parameters
? resolveJsonSchemaRefs(toolDef.function.parameters)
? normalizeJsonSchema(resolveJsonSchemaRefs(toolDef.function.parameters))
: undefined,
serverName,
});

View file

@ -1,5 +1,6 @@
import type { Request } from 'express';
import type { IUser, AppConfig } from '@librechat/data-schemas';
import type { TEndpointOption } from 'librechat-data-provider';
import type { Request } from 'express';
/**
* LibreChat-specific request body type that extends Express Request body
@ -11,8 +12,10 @@ export type RequestBody = {
conversationId?: string;
parentMessageId?: string;
endpoint?: string;
endpointType?: string;
model?: string;
key?: string;
endpointOption?: Partial<TEndpointOption>;
};
export type ServerRequest = Request<unknown, unknown, RequestBody> & {

View file

@ -427,6 +427,35 @@ describe('OpenID Token Utilities', () => {
expect(result).toContain('User:');
});
it('should resolve LIBRECHAT_OPENID_ID_TOKEN and LIBRECHAT_OPENID_ACCESS_TOKEN to different values', () => {
const user: Partial<TUser> = {
id: 'user-123',
provider: 'openid',
openidId: 'oidc-sub-456',
email: 'test@example.com',
name: 'Test User',
federatedTokens: {
access_token: 'my-access-token',
id_token: 'my-id-token',
refresh_token: 'my-refresh-token',
expires_at: Math.floor(Date.now() / 1000) + 3600,
},
};
const tokenInfo = extractOpenIDTokenInfo(user);
expect(tokenInfo).not.toBeNull();
expect(tokenInfo!.accessToken).toBe('my-access-token');
expect(tokenInfo!.idToken).toBe('my-id-token');
expect(tokenInfo!.accessToken).not.toBe(tokenInfo!.idToken);
const input = 'ACCESS={{LIBRECHAT_OPENID_ACCESS_TOKEN}}, ID={{LIBRECHAT_OPENID_ID_TOKEN}}';
const result = processOpenIDPlaceholders(input, tokenInfo!);
expect(result).toBe('ACCESS=my-access-token, ID=my-id-token');
// Verify they are not the same value (the reported bug)
expect(result).not.toBe('ACCESS=my-access-token, ID=my-access-token');
});
it('should handle expired tokens correctly', () => {
const user: Partial<TUser> = {
id: 'user-123',

View file

@ -148,6 +148,7 @@ const anthropicModels = {
'claude-3.5-sonnet-latest': 200000,
'claude-haiku-4-5': 200000,
'claude-sonnet-4': 1000000,
'claude-sonnet-4-6': 1000000,
'claude-4': 200000,
'claude-opus-4': 200000,
'claude-opus-4-5': 200000,
@ -197,6 +198,8 @@ const moonshotModels = {
'moonshot.kimi-k2.5': 262144,
'moonshot.kimi-k2-thinking': 262144,
'moonshot.kimi-k2-0711': 131072,
'moonshotai.kimi': 262144,
'moonshotai.kimi-k2.5': 262144,
};
const metaModels = {
@ -308,6 +311,11 @@ const amazonModels = {
'nova-premier': 995000, // -5000 from max
};
const openAIBedrockModels = {
'openai.gpt-oss-20b': 128000,
'openai.gpt-oss-120b': 128000,
};
const bedrockModels = {
...anthropicModels,
...mistralModels,
@ -317,6 +325,7 @@ const bedrockModels = {
...metaModels,
...ai21Models,
...amazonModels,
...openAIBedrockModels,
};
const xAIModels = {
@ -393,6 +402,7 @@ const anthropicMaxOutputs = {
'claude-3-opus': 4096,
'claude-haiku-4-5': 64000,
'claude-sonnet-4': 64000,
'claude-sonnet-4-6': 64000,
'claude-opus-4': 32000,
'claude-opus-4-5': 64000,
'claude-opus-4-6': 128000,

View file

@ -1,4 +1,4 @@
import { forwardRef, ReactNode, Ref } from 'react';
import { forwardRef, isValidElement, ReactNode, Ref } from 'react';
import {
OGDialogTitle,
OGDialogClose,
@ -19,13 +19,39 @@ type SelectionProps = {
isLoading?: boolean;
};
/**
* Type guard to check if selection is a legacy SelectionProps object
*/
function isSelectionProps(selection: unknown): selection is SelectionProps {
return (
typeof selection === 'object' &&
selection !== null &&
!isValidElement(selection) &&
('selectHandler' in selection ||
'selectClasses' in selection ||
'selectText' in selection ||
'isLoading' in selection)
);
}
type DialogTemplateProps = {
title: string;
description?: string;
main?: ReactNode;
buttons?: ReactNode;
leftButtons?: ReactNode;
selection?: SelectionProps;
/**
* Selection button configuration. Can be either:
* - An object with selectHandler, selectClasses, selectText, isLoading (legacy)
* - A ReactNode for custom selection component
* @example
* // Legacy usage
* selection={{ selectHandler: () => {}, selectText: 'Confirm' }}
* @example
* // Custom component
* selection={<Button onClick={handleConfirm}>Confirm</Button>}
*/
selection?: SelectionProps | ReactNode;
className?: string;
overlayClassName?: string;
headerClassName?: string;
@ -49,14 +75,39 @@ const OGDialogTemplate = forwardRef((props: DialogTemplateProps, ref: Ref<HTMLDi
mainClassName,
headerClassName,
footerClassName,
showCloseButton,
showCloseButton = false,
overlayClassName,
showCancelButton = true,
} = props;
const { selectHandler, selectClasses, selectText, isLoading } = selection || {};
const isLegacySelection = isSelectionProps(selection);
const { selectHandler, selectClasses, selectText, isLoading } = isLegacySelection
? selection
: {};
const defaultSelect =
'bg-gray-800 text-white transition-colors hover:bg-gray-700 disabled:cursor-not-allowed disabled:opacity-50 dark:bg-gray-200 dark:text-gray-800 dark:hover:bg-gray-200';
let selectionContent = null;
if (isLegacySelection) {
selectionContent = (
<OGDialogClose
onClick={selectHandler}
disabled={isLoading}
className={`${
selectClasses ?? defaultSelect
} flex h-10 items-center justify-center rounded-lg border-none px-4 py-2 text-sm disabled:opacity-80 max-sm:order-first max-sm:w-full sm:order-none`}
>
{isLoading === true ? (
<Spinner className="size-4 text-text-primary" />
) : (
(selectText as React.JSX.Element)
)}
</OGDialogClose>
);
} else if (selection) {
selectionContent = selection;
}
return (
<OGDialogContent
overlayClassName={overlayClassName}
@ -75,38 +126,18 @@ const OGDialogTemplate = forwardRef((props: DialogTemplateProps, ref: Ref<HTMLDi
</OGDialogHeader>
<div className={cn('px-0 py-2', mainClassName)}>{main != null ? main : null}</div>
<OGDialogFooter className={footerClassName}>
<div>
{leftButtons != null ? (
<div className="mt-3 flex h-auto gap-3 max-sm:w-full max-sm:flex-col sm:mt-0 sm:flex-row">
{leftButtons}
</div>
) : null}
</div>
<div className="flex h-auto gap-3 max-sm:w-full max-sm:flex-col sm:flex-row">
{showCancelButton && (
<OGDialogClose asChild>
<Button variant="outline" aria-label={localize('com_ui_cancel')}>
{localize('com_ui_cancel')}
</Button>
</OGDialogClose>
)}
{buttons != null ? buttons : null}
{selection ? (
<OGDialogClose
onClick={selectHandler}
disabled={isLoading}
className={`${
selectClasses ?? defaultSelect
} flex h-10 items-center justify-center rounded-lg border-none px-4 py-2 text-sm disabled:opacity-80 max-sm:order-first max-sm:w-full sm:order-none`}
>
{isLoading === true ? (
<Spinner className="size-4 text-white" />
) : (
(selectText as React.JSX.Element)
)}
</OGDialogClose>
) : null}
</div>
{leftButtons != null ? (
<div className="mr-auto flex flex-row gap-2">{leftButtons}</div>
) : null}
{showCancelButton && (
<OGDialogClose asChild>
<Button variant="outline" aria-label={localize('com_ui_cancel')}>
{localize('com_ui_cancel')}
</Button>
</OGDialogClose>
)}
{buttons != null ? buttons : null}
{selectionContent}
</OGDialogFooter>
</OGDialogContent>
);

View file

@ -14,6 +14,7 @@ interface RadioProps {
disabled?: boolean;
className?: string;
fullWidth?: boolean;
'aria-labelledby'?: string;
}
const Radio = memo(function Radio({
@ -23,6 +24,7 @@ const Radio = memo(function Radio({
disabled = false,
className = '',
fullWidth = false,
'aria-labelledby': ariaLabelledBy,
}: RadioProps) {
const localize = useLocalize();
const buttonRefs = useRef<(HTMLButtonElement | null)[]>([]);
@ -79,6 +81,7 @@ const Radio = memo(function Radio({
<div
className="relative inline-flex items-center rounded-lg bg-muted p-1 opacity-50"
role="radiogroup"
aria-labelledby={ariaLabelledBy}
>
<span className="px-4 py-2 text-xs text-muted-foreground">
{localize('com_ui_no_options')}
@ -93,6 +96,7 @@ const Radio = memo(function Radio({
<div
className={`relative ${fullWidth ? 'flex' : 'inline-flex'} items-center rounded-lg bg-muted p-1 ${className}`}
role="radiogroup"
aria-labelledby={ariaLabelledBy}
>
{selectedIndex >= 0 && isMounted && (
<div

View file

@ -16,7 +16,7 @@ const Theme = ({ theme, onChange }: { theme: string; onChange: (value: string) =
const themeIcons: Record<ThemeType, JSX.Element> = {
system: <Monitor aria-hidden="true" />,
dark: <Moon color="white" aria-hidden="true" />,
dark: <Moon aria-hidden="true" />,
light: <Sun aria-hidden="true" />,
};
@ -35,7 +35,7 @@ const Theme = ({ theme, onChange }: { theme: string; onChange: (value: string) =
return (
<button
className="flex items-center gap-2 rounded-lg p-2 transition-colors hover:bg-surface-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-blue-600 focus-visible:ring-offset-2 dark:focus-visible:ring-0"
className="flex items-center gap-2 rounded-lg p-2 text-text-primary transition-colors hover:bg-surface-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-blue-600 focus-visible:ring-offset-2 dark:focus-visible:ring-0"
aria-label={localize('com_ui_toggle_theme')}
aria-keyshortcuts="Ctrl+Shift+T"
onClick={(e) => {
@ -77,13 +77,6 @@ const ThemeSelector = ({ returnThemeOnly }: { returnThemeOnly?: boolean }) => {
[setTheme, localize],
);
useEffect(() => {
if (theme === 'system') {
const prefersDarkScheme = window.matchMedia('(prefers-color-scheme: dark)').matches;
setTheme(prefersDarkScheme ? 'dark' : 'light');
}
}, [theme, setTheme]);
useEffect(() => {
if (announcement) {
const timeout = setTimeout(() => setAnnouncement(''), 1000);

View file

@ -1,17 +1,18 @@
// This file is kept for backward compatibility but is no longer used internally.
// Theme state is now managed via React useState + localStorage in ThemeProvider.
import { atomWithStorage } from 'jotai/utils';
import { IThemeRGB } from '../types';
/**
* Atom for storing the theme mode (light/dark/system) in localStorage
* Key: 'color-theme'
* @deprecated Use ThemeContext instead. This atom is no longer used internally.
*/
export const themeModeAtom = atomWithStorage<string>('color-theme', 'system', undefined, {
getOnInit: true,
});
/**
* Atom for storing custom theme colors in localStorage
* Key: 'theme-colors'
* @deprecated Use ThemeContext instead. This atom is no longer used internally.
*/
export const themeColorsAtom = atomWithStorage<IThemeRGB | undefined>(
'theme-colors',
@ -23,8 +24,7 @@ export const themeColorsAtom = atomWithStorage<IThemeRGB | undefined>(
);
/**
* Atom for storing the theme name in localStorage
* Key: 'theme-name'
* @deprecated Use ThemeContext instead. This atom is no longer used internally.
*/
export const themeNameAtom = atomWithStorage<string | undefined>(
'theme-name',

View file

@ -1,8 +1,10 @@
import React, { createContext, useContext, useEffect, useMemo, useCallback, useRef } from 'react';
import { useAtom } from 'jotai';
import React, { createContext, useContext, useEffect, useMemo, useCallback, useState, useRef } from 'react';
import { IThemeRGB } from '../types';
import applyTheme from '../utils/applyTheme';
import { themeModeAtom, themeColorsAtom, themeNameAtom } from '../atoms/themeAtoms';
const THEME_KEY = 'color-theme';
const THEME_COLORS_KEY = 'theme-colors';
const THEME_NAME_KEY = 'theme-name';
type ThemeContextType = {
theme: string; // 'light' | 'dark' | 'system'
@ -40,6 +42,70 @@ export const isDark = (theme: string): boolean => {
return theme === 'dark';
};
/**
* Validate that a parsed value looks like an IThemeRGB object
*/
const isValidThemeColors = (value: unknown): value is IThemeRGB => {
if (typeof value !== 'object' || value === null || Array.isArray(value)) {
return false;
}
for (const key of Object.keys(value)) {
const val = (value as Record<string, unknown>)[key];
if (val !== undefined && typeof val !== 'string') {
return false;
}
}
return true;
};
/**
* Get initial theme from localStorage or default to 'system'
*/
const getInitialTheme = (): string => {
if (typeof window === 'undefined') return 'system';
try {
const stored = localStorage.getItem(THEME_KEY);
if (stored && ['light', 'dark', 'system'].includes(stored)) {
return stored;
}
} catch {
// localStorage not available
}
return 'system';
};
/**
* Get initial theme colors from localStorage
*/
const getInitialThemeColors = (): IThemeRGB | undefined => {
if (typeof window === 'undefined') return undefined;
try {
const stored = localStorage.getItem(THEME_COLORS_KEY);
if (stored) {
const parsed = JSON.parse(stored);
if (isValidThemeColors(parsed)) {
return parsed;
}
}
} catch {
// localStorage not available or invalid JSON
}
return undefined;
};
/**
* Get initial theme name from localStorage
*/
const getInitialThemeName = (): string | undefined => {
if (typeof window === 'undefined') return undefined;
try {
return localStorage.getItem(THEME_NAME_KEY) || undefined;
} catch {
// localStorage not available
}
return undefined;
};
/**
* ThemeProvider component that handles both dark/light mode switching
* and dynamic color themes via CSS variables with localStorage persistence
@ -50,102 +116,128 @@ export function ThemeProvider({
themeName: propThemeName,
initialTheme,
}: ThemeProviderProps) {
// Use jotai atoms for persistent state
const [theme, setTheme] = useAtom(themeModeAtom);
const [storedThemeRGB, setStoredThemeRGB] = useAtom(themeColorsAtom);
const [storedThemeName, setStoredThemeName] = useAtom(themeNameAtom);
const [theme, setThemeState] = useState<string>(getInitialTheme);
const [themeRGB, setThemeRGBState] = useState<IThemeRGB | undefined>(getInitialThemeColors);
const [themeName, setThemeNameState] = useState<string | undefined>(getInitialThemeName);
// Track if props have been initialized
const propsInitialized = useRef(false);
const initialized = useRef(false);
const setTheme = useCallback((newTheme: string) => {
setThemeState(newTheme);
if (typeof window === 'undefined') return;
try {
localStorage.setItem(THEME_KEY, newTheme);
} catch {
// localStorage not available
}
}, []);
const setThemeRGB = useCallback((colors?: IThemeRGB) => {
setThemeRGBState(colors);
if (typeof window === 'undefined') return;
try {
if (colors) {
localStorage.setItem(THEME_COLORS_KEY, JSON.stringify(colors));
} else {
localStorage.removeItem(THEME_COLORS_KEY);
}
} catch {
// localStorage not available
}
}, []);
const setThemeName = useCallback((name?: string) => {
setThemeNameState(name);
if (typeof window === 'undefined') return;
try {
if (name) {
localStorage.setItem(THEME_NAME_KEY, name);
} else {
localStorage.removeItem(THEME_NAME_KEY);
}
} catch {
// localStorage not available
}
}, []);
// Initialize from props only once on mount
useEffect(() => {
if (!propsInitialized.current) {
propsInitialized.current = true;
if (initialized.current) return;
initialized.current = true;
// Set initial theme if provided
if (initialTheme) {
setTheme(initialTheme);
}
// Set initial theme colors if provided
if (propThemeRGB) {
setStoredThemeRGB(propThemeRGB);
}
// Set initial theme name if provided
if (propThemeName) {
setStoredThemeName(propThemeName);
}
// Set initial theme if provided
if (initialTheme) {
setTheme(initialTheme);
}
}, [initialTheme, propThemeRGB, propThemeName, setTheme, setStoredThemeRGB, setStoredThemeName]);
// Set initial theme colors if provided
if (propThemeRGB) {
setThemeRGB(propThemeRGB);
}
// Set initial theme name if provided
if (propThemeName) {
setThemeName(propThemeName);
}
}, [initialTheme, propThemeRGB, propThemeName, setTheme, setThemeRGB, setThemeName]);
// Apply class-based dark mode
const applyThemeMode = useCallback((rawTheme: string) => {
const applyThemeMode = useCallback((currentTheme: string) => {
const root = window.document.documentElement;
const darkMode = isDark(rawTheme);
const darkMode = isDark(currentTheme);
root.classList.remove(darkMode ? 'light' : 'dark');
root.classList.add(darkMode ? 'dark' : 'light');
}, []);
// Handle system theme changes
useEffect(() => {
const mediaQuery = window.matchMedia('(prefers-color-scheme: dark)');
const changeThemeOnSystemChange = () => {
if (theme === 'system') {
applyThemeMode('system');
}
};
mediaQuery.addEventListener('change', changeThemeOnSystemChange);
return () => {
mediaQuery.removeEventListener('change', changeThemeOnSystemChange);
};
}, [theme, applyThemeMode]);
// Apply dark/light mode class
// Apply theme mode whenever theme changes
useEffect(() => {
applyThemeMode(theme);
}, [theme, applyThemeMode]);
// Listen for system theme changes when theme is 'system'
useEffect(() => {
if (theme !== 'system') return;
const mediaQuery = window.matchMedia('(prefers-color-scheme: dark)');
const handleChange = () => {
applyThemeMode('system');
};
mediaQuery.addEventListener('change', handleChange);
return () => mediaQuery.removeEventListener('change', handleChange);
}, [theme, applyThemeMode]);
// Apply dynamic color theme
useEffect(() => {
if (storedThemeRGB) {
applyTheme(storedThemeRGB);
if (themeRGB) {
applyTheme(themeRGB);
}
}, [storedThemeRGB]);
}, [themeRGB]);
// Reset theme function
const resetTheme = useCallback(() => {
setTheme('system');
setStoredThemeRGB(undefined);
setStoredThemeName(undefined);
setThemeRGB(undefined);
setThemeName(undefined);
// Remove any custom CSS variables
const root = document.documentElement;
const customProps = Array.from(root.style).filter((prop) => prop.startsWith('--'));
customProps.forEach((prop) => root.style.removeProperty(prop));
}, [setTheme, setStoredThemeRGB, setStoredThemeName]);
}, [setTheme, setThemeRGB, setThemeName]);
const value = useMemo(
() => ({
theme,
setTheme,
themeRGB: storedThemeRGB,
setThemeRGB: setStoredThemeRGB,
themeName: storedThemeName,
setThemeName: setStoredThemeName,
themeRGB,
setThemeRGB,
themeName,
setThemeName,
resetTheme,
}),
[
theme,
setTheme,
storedThemeRGB,
setStoredThemeRGB,
storedThemeName,
setStoredThemeName,
resetTheme,
],
[theme, setTheme, themeRGB, setThemeRGB, themeName, setThemeName, resetTheme],
);
return <ThemeContext.Provider value={value}>{children}</ThemeContext.Provider>;

View file

@ -39,7 +39,7 @@
},
"homepage": "https://librechat.ai",
"dependencies": {
"axios": "^1.12.1",
"axios": "^1.13.5",
"dayjs": "^1.11.13",
"js-yaml": "^4.1.1",
"zod": "^3.22.4"

View file

@ -6,7 +6,7 @@
"": {
"name": "librechat-data-provider/react-query",
"dependencies": {
"axios": "^1.12.1"
"axios": "^1.13.5"
}
},
"node_modules/asynckit": {
@ -16,13 +16,13 @@
"license": "MIT"
},
"node_modules/axios": {
"version": "1.12.1",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.12.1.tgz",
"integrity": "sha512-Kn4kbSXpkFHCGE6rBFNwIv0GQs4AvDT80jlveJDKFxjbTYMUeB4QtsdPCv6H8Cm19Je7IU6VFtRl2zWZI0rudQ==",
"version": "1.13.5",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.13.5.tgz",
"integrity": "sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==",
"license": "MIT",
"dependencies": {
"follow-redirects": "^1.15.6",
"form-data": "^4.0.4",
"follow-redirects": "^1.15.11",
"form-data": "^4.0.5",
"proxy-from-env": "^1.1.0"
}
},
@ -140,9 +140,9 @@
}
},
"node_modules/form-data": {
"version": "4.0.4",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz",
"integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==",
"version": "4.0.5",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz",
"integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==",
"license": "MIT",
"dependencies": {
"asynckit": "^0.4.0",

View file

@ -5,6 +5,6 @@
"module": "./index.es.js",
"types": "../dist/types/react-query/index.d.ts",
"dependencies": {
"axios": "^1.12.1"
"axios": "^1.13.5"
}
}

View file

@ -459,6 +459,82 @@ describe('ActionRequest', () => {
await expect(actionRequest.execute()).rejects.toThrow('Unsupported HTTP method: invalid');
});
describe('SSRF-safe agent passthrough', () => {
beforeEach(() => {
mockedAxios.get.mockResolvedValue({ data: { success: true } });
mockedAxios.post.mockResolvedValue({ data: { success: true } });
});
it('should pass httpAgent and httpsAgent to axios.create when provided', async () => {
const mockHttpAgent = { keepAlive: true };
const mockHttpsAgent = { keepAlive: true };
const actionRequest = new ActionRequest(
'https://example.com',
'/test',
'GET',
'testOp',
false,
'application/json',
);
const executor = actionRequest.createExecutor();
executor.setParams({ key: 'value' });
await executor.execute({ httpAgent: mockHttpAgent, httpsAgent: mockHttpsAgent });
expect(mockedAxios.create).toHaveBeenCalledWith(
expect.objectContaining({
httpAgent: mockHttpAgent,
httpsAgent: mockHttpsAgent,
maxRedirects: 0,
}),
);
});
it('should not include agent keys when no options are provided', async () => {
const actionRequest = new ActionRequest(
'https://example.com',
'/test',
'GET',
'testOp',
false,
'application/json',
);
const executor = actionRequest.createExecutor();
executor.setParams({ key: 'value' });
await executor.execute();
const createArg = mockedAxios.create.mock.calls[
mockedAxios.create.mock.calls.length - 1
][0] as Record<string, unknown>;
expect(createArg).not.toHaveProperty('httpAgent');
expect(createArg).not.toHaveProperty('httpsAgent');
});
it('should pass agents through for POST requests', async () => {
const mockAgent = { ssrf: true };
const actionRequest = new ActionRequest(
'https://example.com',
'/test',
'POST',
'testOp',
false,
'application/json',
);
const executor = actionRequest.createExecutor();
executor.setParams({ body: 'data' });
await executor.execute({ httpAgent: mockAgent, httpsAgent: mockAgent });
expect(mockedAxios.create).toHaveBeenCalledWith(
expect.objectContaining({
httpAgent: mockAgent,
httpsAgent: mockAgent,
}),
);
expect(mockedAxios.post).toHaveBeenCalled();
});
});
describe('ActionRequest Concurrent Execution', () => {
beforeEach(() => {
jest.clearAllMocks();

View file

@ -46,6 +46,30 @@ describe('supportsAdaptiveThinking', () => {
expect(supportsAdaptiveThinking('claude-opus-4-0')).toBe(false);
});
test('should return true for claude-sonnet-4-6', () => {
expect(supportsAdaptiveThinking('claude-sonnet-4-6')).toBe(true);
});
test('should return true for claude-sonnet-4.6', () => {
expect(supportsAdaptiveThinking('claude-sonnet-4.6')).toBe(true);
});
test('should return true for claude-sonnet-4-7 (future)', () => {
expect(supportsAdaptiveThinking('claude-sonnet-4-7')).toBe(true);
});
test('should return true for anthropic.claude-sonnet-4-6 (Bedrock)', () => {
expect(supportsAdaptiveThinking('anthropic.claude-sonnet-4-6')).toBe(true);
});
test('should return true for us.anthropic.claude-sonnet-4-6 (cross-region Bedrock)', () => {
expect(supportsAdaptiveThinking('us.anthropic.claude-sonnet-4-6')).toBe(true);
});
test('should return true for claude-4-6-sonnet (alternate naming)', () => {
expect(supportsAdaptiveThinking('claude-4-6-sonnet')).toBe(true);
});
test('should return false for claude-sonnet-4-5', () => {
expect(supportsAdaptiveThinking('claude-sonnet-4-5')).toBe(false);
});
@ -104,6 +128,14 @@ describe('supportsContext1m', () => {
expect(supportsContext1m('claude-sonnet-4-5')).toBe(true);
});
test('should return true for claude-sonnet-4-6', () => {
expect(supportsContext1m('claude-sonnet-4-6')).toBe(true);
});
test('should return true for anthropic.claude-sonnet-4-6 (Bedrock)', () => {
expect(supportsContext1m('anthropic.claude-sonnet-4-6')).toBe(true);
});
test('should return true for claude-sonnet-5 (future)', () => {
expect(supportsContext1m('claude-sonnet-5')).toBe(true);
});
@ -237,14 +269,42 @@ describe('bedrockInputParser', () => {
]);
});
test('should match anthropic.claude-4-7-sonnet model with 1M context header', () => {
test('should match anthropic.claude-sonnet-4-6 with adaptive thinking and 1M context header', () => {
const input = {
model: 'anthropic.claude-sonnet-4-6',
};
const result = bedrockInputParser.parse(input) as Record<string, unknown>;
const additionalFields = result.additionalModelRequestFields as Record<string, unknown>;
expect(additionalFields.thinking).toEqual({ type: 'adaptive' });
expect(additionalFields.thinkingBudget).toBeUndefined();
expect(additionalFields.anthropic_beta).toEqual([
'output-128k-2025-02-19',
'context-1m-2025-08-07',
]);
});
test('should match us.anthropic.claude-sonnet-4-6 with adaptive thinking and 1M context header', () => {
const input = {
model: 'us.anthropic.claude-sonnet-4-6',
};
const result = bedrockInputParser.parse(input) as Record<string, unknown>;
const additionalFields = result.additionalModelRequestFields as Record<string, unknown>;
expect(additionalFields.thinking).toEqual({ type: 'adaptive' });
expect(additionalFields.thinkingBudget).toBeUndefined();
expect(additionalFields.anthropic_beta).toEqual([
'output-128k-2025-02-19',
'context-1m-2025-08-07',
]);
});
test('should match anthropic.claude-4-7-sonnet model with adaptive thinking and 1M context header', () => {
const input = {
model: 'anthropic.claude-4-7-sonnet',
};
const result = bedrockInputParser.parse(input) as Record<string, unknown>;
const additionalFields = result.additionalModelRequestFields as Record<string, unknown>;
expect(additionalFields.thinking).toBe(true);
expect(additionalFields.thinkingBudget).toBe(2000);
expect(additionalFields.thinking).toEqual({ type: 'adaptive' });
expect(additionalFields.thinkingBudget).toBeUndefined();
expect(additionalFields.anthropic_beta).toEqual([
'output-128k-2025-02-19',
'context-1m-2025-08-07',

View file

@ -1,4 +1,4 @@
import { replaceSpecialVars, parseCompactConvo, parseTextParts } from '../src/parsers';
import { replaceSpecialVars, parseConvo, parseCompactConvo, parseTextParts } from '../src/parsers';
import { specialVariables } from '../src/config';
import { EModelEndpoint } from '../src/schemas';
import { ContentTypes } from '../src/types/runs';
@ -262,6 +262,257 @@ describe('parseCompactConvo', () => {
});
});
describe('parseConvo - defaultParamsEndpoint', () => {
test('should strip maxOutputTokens for custom endpoint without defaultParamsEndpoint', () => {
const conversation: Partial<TConversation> = {
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
maxOutputTokens: 8192,
maxContextTokens: 50000,
};
const result = parseConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
});
expect(result).not.toBeNull();
expect(result?.temperature).toBe(0.7);
expect(result?.maxContextTokens).toBe(50000);
expect(result?.maxOutputTokens).toBeUndefined();
});
test('should preserve maxOutputTokens when defaultParamsEndpoint is anthropic', () => {
const conversation: Partial<TConversation> = {
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
maxOutputTokens: 8192,
topP: 0.9,
topK: 40,
maxContextTokens: 50000,
};
const result = parseConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
defaultParamsEndpoint: EModelEndpoint.anthropic,
});
expect(result).not.toBeNull();
expect(result?.model).toBe('anthropic/claude-opus-4.5');
expect(result?.temperature).toBe(0.7);
expect(result?.maxOutputTokens).toBe(8192);
expect(result?.topP).toBe(0.9);
expect(result?.topK).toBe(40);
expect(result?.maxContextTokens).toBe(50000);
});
test('should strip OpenAI-specific fields when defaultParamsEndpoint is anthropic', () => {
const conversation: Partial<TConversation> = {
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
max_tokens: 4096,
top_p: 0.9,
presence_penalty: 0.5,
frequency_penalty: 0.3,
};
const result = parseConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
defaultParamsEndpoint: EModelEndpoint.anthropic,
});
expect(result).not.toBeNull();
expect(result?.temperature).toBe(0.7);
expect(result?.max_tokens).toBeUndefined();
expect(result?.top_p).toBeUndefined();
expect(result?.presence_penalty).toBeUndefined();
expect(result?.frequency_penalty).toBeUndefined();
});
test('should preserve max_tokens when defaultParamsEndpoint is not set (OpenAI default)', () => {
const conversation: Partial<TConversation> = {
model: 'gpt-4o',
temperature: 0.7,
max_tokens: 4096,
top_p: 0.9,
};
const result = parseConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
});
expect(result).not.toBeNull();
expect(result?.max_tokens).toBe(4096);
expect(result?.top_p).toBe(0.9);
});
test('should preserve Google-specific fields when defaultParamsEndpoint is google', () => {
const conversation: Partial<TConversation> = {
model: 'gemini-pro',
temperature: 0.7,
maxOutputTokens: 8192,
topP: 0.9,
topK: 40,
};
const result = parseConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
defaultParamsEndpoint: EModelEndpoint.google,
});
expect(result).not.toBeNull();
expect(result?.maxOutputTokens).toBe(8192);
expect(result?.topP).toBe(0.9);
expect(result?.topK).toBe(40);
});
test('should not strip fields from non-custom endpoints that already have a schema', () => {
const conversation: Partial<TConversation> = {
model: 'gpt-4o',
temperature: 0.7,
max_tokens: 4096,
top_p: 0.9,
};
const result = parseConvo({
endpoint: EModelEndpoint.openAI,
conversation,
defaultParamsEndpoint: EModelEndpoint.anthropic,
});
expect(result).not.toBeNull();
expect(result?.max_tokens).toBe(4096);
expect(result?.top_p).toBe(0.9);
});
test('should not carry bedrock region to custom endpoint without defaultParamsEndpoint', () => {
const conversation: Partial<TConversation> = {
model: 'gpt-4o',
temperature: 0.7,
region: 'us-east-1',
};
const result = parseConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
});
expect(result).not.toBeNull();
expect(result?.temperature).toBe(0.7);
expect(result?.region).toBeUndefined();
});
test('should fall back to endpointType schema when defaultParamsEndpoint is invalid', () => {
const conversation: Partial<TConversation> = {
model: 'gpt-4o',
temperature: 0.7,
max_tokens: 4096,
maxOutputTokens: 8192,
};
const result = parseConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
defaultParamsEndpoint: 'nonexistent_endpoint',
});
expect(result).not.toBeNull();
expect(result?.max_tokens).toBe(4096);
expect(result?.maxOutputTokens).toBeUndefined();
});
});
describe('parseCompactConvo - defaultParamsEndpoint', () => {
test('should strip maxOutputTokens for custom endpoint without defaultParamsEndpoint', () => {
const conversation: Partial<TConversation> = {
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
maxOutputTokens: 8192,
};
const result = parseCompactConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
});
expect(result).not.toBeNull();
expect(result?.temperature).toBe(0.7);
expect(result?.maxOutputTokens).toBeUndefined();
});
test('should preserve maxOutputTokens when defaultParamsEndpoint is anthropic', () => {
const conversation: Partial<TConversation> = {
model: 'anthropic/claude-opus-4.5',
temperature: 0.7,
maxOutputTokens: 8192,
topP: 0.9,
maxContextTokens: 50000,
};
const result = parseCompactConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
defaultParamsEndpoint: EModelEndpoint.anthropic,
});
expect(result).not.toBeNull();
expect(result?.maxOutputTokens).toBe(8192);
expect(result?.topP).toBe(0.9);
expect(result?.maxContextTokens).toBe(50000);
});
test('should strip iconURL even when defaultParamsEndpoint is set', () => {
const conversation: Partial<TConversation> = {
model: 'anthropic/claude-opus-4.5',
iconURL: 'https://malicious.com/track.png',
maxOutputTokens: 8192,
};
const result = parseCompactConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
defaultParamsEndpoint: EModelEndpoint.anthropic,
});
expect(result).not.toBeNull();
expect(result?.['iconURL']).toBeUndefined();
expect(result?.maxOutputTokens).toBe(8192);
});
test('should fall back to endpointType when defaultParamsEndpoint is null', () => {
const conversation: Partial<TConversation> = {
model: 'gpt-4o',
max_tokens: 4096,
maxOutputTokens: 8192,
};
const result = parseCompactConvo({
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
endpointType: EModelEndpoint.custom,
conversation,
defaultParamsEndpoint: null,
});
expect(result).not.toBeNull();
expect(result?.max_tokens).toBe(4096);
expect(result?.maxOutputTokens).toBeUndefined();
});
});
describe('parseTextParts', () => {
test('should concatenate text parts', () => {
const parts: TMessageContentParts[] = [

View file

@ -283,7 +283,7 @@ class RequestExecutor {
return this;
}
async execute() {
async execute(options?: { httpAgent?: unknown; httpsAgent?: unknown }) {
const url = createURL(this.config.domain, this.path);
const headers: Record<string, string> = {
...this.authHeaders,
@ -300,10 +300,15 @@ class RequestExecutor {
*
* By setting maxRedirects: 0, we prevent this attack vector.
* The action will receive the redirect response (3xx) instead of following it.
*
* SECURITY: When httpAgent/httpsAgent are provided (SSRF-safe agents), they validate
* the DNS-resolved IP at TCP connect time, preventing TOCTOU DNS rebinding attacks.
*/
const axios = _axios.create({
maxRedirects: 0,
validateStatus: (status) => status >= 200 && status < 400, // Accept 3xx but don't follow
validateStatus: (status) => status >= 200 && status < 400,
...(options?.httpAgent != null ? { httpAgent: options.httpAgent } : {}),
...(options?.httpsAgent != null ? { httpsAgent: options.httpsAgent } : {}),
});
// Initialize separate containers for query and body parameters.

View file

@ -181,6 +181,11 @@ export const cancelMCPOAuth = (serverName: string) => {
return `${BASE_URL}/api/mcp/oauth/cancel/${serverName}`;
};
export const mcpOAuthBind = (serverName: string) => `${BASE_URL}/api/mcp/${serverName}/oauth/bind`;
export const actionOAuthBind = (actionId: string) =>
`${BASE_URL}/api/actions/${actionId}/oauth/bind`;
export const config = () => `${BASE_URL}/api/config`;
export const prompts = () => `${BASE_URL}/api/prompts`;

View file

@ -35,27 +35,34 @@ function parseOpusVersion(model: string): { major: number; minor: number } | nul
return null;
}
/** Extracts sonnet major version from both naming formats */
function parseSonnetVersion(model: string): number | null {
const nameFirst = model.match(/claude-sonnet[-.]?(\d+)/);
/** Extracts sonnet major/minor version from both naming formats.
* Uses single-digit minor capture to avoid matching date suffixes (e.g., -20250514). */
function parseSonnetVersion(model: string): { major: number; minor: number } | null {
const nameFirst = model.match(/claude-sonnet[-.]?(\d+)(?:[-.](\d)(?!\d))?/);
if (nameFirst) {
return parseInt(nameFirst[1], 10);
return {
major: parseInt(nameFirst[1], 10),
minor: nameFirst[2] != null ? parseInt(nameFirst[2], 10) : 0,
};
}
const numFirst = model.match(/claude-(\d+)(?:[-.]?\d+)?-sonnet/);
const numFirst = model.match(/claude-(\d+)(?:[-.](\d)(?!\d))?-sonnet/);
if (numFirst) {
return parseInt(numFirst[1], 10);
return {
major: parseInt(numFirst[1], 10),
minor: numFirst[2] != null ? parseInt(numFirst[2], 10) : 0,
};
}
return null;
}
/** Checks if a model supports adaptive thinking (Opus 4.6+, Sonnet 5+) */
/** Checks if a model supports adaptive thinking (Opus 4.6+, Sonnet 4.6+) */
export function supportsAdaptiveThinking(model: string): boolean {
const opus = parseOpusVersion(model);
if (opus && (opus.major > 4 || (opus.major === 4 && opus.minor >= 6))) {
return true;
}
const sonnet = parseSonnetVersion(model);
if (sonnet != null && sonnet >= 5) {
if (sonnet != null && (sonnet.major > 4 || (sonnet.major === 4 && sonnet.minor >= 6))) {
return true;
}
return false;
@ -64,7 +71,7 @@ export function supportsAdaptiveThinking(model: string): boolean {
/** Checks if a model qualifies for the context-1m beta header (Sonnet 4+, Opus 4.6+, Opus 5+) */
export function supportsContext1m(model: string): boolean {
const sonnet = parseSonnetVersion(model);
if (sonnet != null && sonnet >= 4) {
if (sonnet != null && sonnet.major >= 4) {
return true;
}
const opus = parseOpusVersion(model);

View file

@ -1133,6 +1133,7 @@ const sharedOpenAIModels = [
];
const sharedAnthropicModels = [
'claude-sonnet-4-6',
'claude-opus-4-6',
'claude-sonnet-4-5',
'claude-sonnet-4-5-20250929',
@ -1154,6 +1155,7 @@ const sharedAnthropicModels = [
];
export const bedrockModels = [
'anthropic.claude-sonnet-4-6',
'anthropic.claude-opus-4-6-v1',
'anthropic.claude-sonnet-4-5-20250929-v1:0',
'anthropic.claude-haiku-4-5-20251001-v1:0',
@ -1364,6 +1366,10 @@ export enum CacheKeys {
* Key for the config store namespace.
*/
CONFIG_STORE = 'CONFIG_STORE',
/**
* Key for the tool cache namespace (plugins, MCP tools, tool definitions).
*/
TOOL_CACHE = 'TOOL_CACHE',
/**
* Key for the roles cache.
*/
@ -1754,6 +1760,8 @@ export enum Constants {
mcp_all = 'sys__all__sys',
/** Unique value to indicate clearing MCP servers from UI state. For frontend use only. */
mcp_clear = 'sys__clear__sys',
/** Key suffix for non-spec user default tool storage */
spec_defaults_key = '__defaults__',
/**
* Unique value to indicate the MCP tool was added to an agent.
* This helps inform the UI if the mcp server was previously added.
@ -1904,3 +1912,14 @@ export function getEndpointField<
}
return config[property];
}
/** Resolves the `defaultParamsEndpoint` for a given endpoint from its custom params config */
export function getDefaultParamsEndpoint(
endpointsConfig: TEndpointsConfig | undefined | null,
endpoint: string | null | undefined,
): string | undefined {
if (!endpointsConfig || !endpoint) {
return undefined;
}
return endpointsConfig[endpoint]?.customParams?.defaultParamsEndpoint;
}

View file

@ -178,6 +178,14 @@ export const reinitializeMCPServer = (serverName: string) => {
return request.post(endpoints.mcpReinitialize(serverName));
};
export const bindMCPOAuth = (serverName: string): Promise<{ success: boolean }> => {
return request.post(endpoints.mcpOAuthBind(serverName));
};
export const bindActionOAuth = (actionId: string): Promise<{ success: boolean }> => {
return request.post(endpoints.actionOAuthBind(actionId));
};
export const getMCPConnectionStatus = (): Promise<q.MCPConnectionStatusResponse> => {
return request.get(endpoints.mcpConnectionStatus());
};

View file

@ -35,6 +35,7 @@ export type TModelSpec = {
webSearch?: boolean;
fileSearch?: boolean;
executeCode?: boolean;
artifacts?: string | boolean;
mcpServers?: string[];
};
@ -54,6 +55,7 @@ export const tModelSpecSchema = z.object({
webSearch: z.boolean().optional(),
fileSearch: z.boolean().optional(),
executeCode: z.boolean().optional(),
artifacts: z.union([z.string(), z.boolean()]).optional(),
mcpServers: z.array(z.string()).optional(),
});

View file

@ -952,6 +952,9 @@ export const paramSettings: Record<string, SettingsConfiguration | undefined> =
[`${EModelEndpoint.bedrock}-${BedrockProviders.Amazon}`]: bedrockGeneral,
[`${EModelEndpoint.bedrock}-${BedrockProviders.DeepSeek}`]: bedrockGeneral,
[`${EModelEndpoint.bedrock}-${BedrockProviders.Moonshot}`]: bedrockMoonshot,
[`${EModelEndpoint.bedrock}-${BedrockProviders.MoonshotAI}`]: bedrockMoonshot,
[`${EModelEndpoint.bedrock}-${BedrockProviders.OpenAI}`]: bedrockGeneral,
[`${EModelEndpoint.bedrock}-${BedrockProviders.ZAI}`]: bedrockGeneral,
[EModelEndpoint.google]: googleConfig,
};
@ -1000,6 +1003,12 @@ export const presetSettings: Record<
col1: bedrockMoonshotCol1,
col2: bedrockMoonshotCol2,
},
[`${EModelEndpoint.bedrock}-${BedrockProviders.MoonshotAI}`]: {
col1: bedrockMoonshotCol1,
col2: bedrockMoonshotCol2,
},
[`${EModelEndpoint.bedrock}-${BedrockProviders.OpenAI}`]: bedrockGeneralColumns,
[`${EModelEndpoint.bedrock}-${BedrockProviders.ZAI}`]: bedrockGeneralColumns,
[EModelEndpoint.google]: {
col1: googleCol1,
col2: googleCol2,

View file

@ -144,26 +144,25 @@ export const parseConvo = ({
endpointType,
conversation,
possibleValues,
defaultParamsEndpoint,
}: {
endpoint: EndpointSchemaKey;
endpointType?: EndpointSchemaKey | null;
conversation: Partial<s.TConversation | s.TPreset> | null;
possibleValues?: TPossibleValues;
// TODO: POC for default schema
// defaultSchema?: Partial<EndpointSchema>,
defaultParamsEndpoint?: string | null;
}) => {
let schema = endpointSchemas[endpoint] as EndpointSchema | undefined;
if (!schema && !endpointType) {
throw new Error(`Unknown endpoint: ${endpoint}`);
} else if (!schema && endpointType) {
schema = endpointSchemas[endpointType];
} else if (!schema) {
const overrideSchema = defaultParamsEndpoint
? endpointSchemas[defaultParamsEndpoint as EndpointSchemaKey]
: undefined;
schema = overrideSchema ?? (endpointType ? endpointSchemas[endpointType] : undefined);
}
// if (defaultSchema && schemaCreators[endpoint]) {
// schema = schemaCreators[endpoint](defaultSchema);
// }
const convo = schema?.parse(conversation) as s.TConversation | undefined;
const { models } = possibleValues ?? {};
@ -310,13 +309,13 @@ export const parseCompactConvo = ({
endpointType,
conversation,
possibleValues,
defaultParamsEndpoint,
}: {
endpoint?: EndpointSchemaKey;
endpointType?: EndpointSchemaKey | null;
conversation: Partial<s.TConversation | s.TPreset>;
possibleValues?: TPossibleValues;
// TODO: POC for default schema
// defaultSchema?: Partial<EndpointSchema>,
defaultParamsEndpoint?: string | null;
}): Omit<s.TConversation, 'iconURL'> | null => {
if (!endpoint) {
throw new Error(`undefined endpoint: ${endpoint}`);
@ -326,8 +325,11 @@ export const parseCompactConvo = ({
if (!schema && !endpointType) {
throw new Error(`Unknown endpoint: ${endpoint}`);
} else if (!schema && endpointType) {
schema = compactEndpointSchemas[endpointType];
} else if (!schema) {
const overrideSchema = defaultParamsEndpoint
? compactEndpointSchemas[defaultParamsEndpoint as EndpointSchemaKey]
: undefined;
schema = overrideSchema ?? (endpointType ? compactEndpointSchemas[endpointType] : undefined);
}
if (!schema) {

View file

@ -101,7 +101,10 @@ export enum BedrockProviders {
Meta = 'meta',
MistralAI = 'mistral',
Moonshot = 'moonshot',
MoonshotAI = 'moonshotai',
OpenAI = 'openai',
StabilityAI = 'stability',
ZAI = 'zai',
}
export const getModelKey = (endpoint: EModelEndpoint | string, model: string) => {

View file

@ -99,6 +99,7 @@ export type TEphemeralAgent = {
web_search?: boolean;
file_search?: boolean;
execute_code?: boolean;
artifacts?: string;
};
export type TPayload = Partial<TMessage> &

View file

@ -1015,4 +1015,239 @@ describe('Meilisearch Mongoose plugin', () => {
]);
});
});
describe('Missing _meiliIndex property handling in sync process', () => {
test('syncWithMeili includes documents with missing _meiliIndex', async () => {
const conversationModel = createConversationModel(mongoose) as SchemaWithMeiliMethods;
await conversationModel.deleteMany({});
mockAddDocumentsInBatches.mockClear();
// Insert documents with different _meiliIndex states
await conversationModel.collection.insertMany([
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Missing _meiliIndex',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
// _meiliIndex is not set (missing/undefined)
},
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Explicit false',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
_meiliIndex: false,
},
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Already indexed',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
_meiliIndex: true,
},
]);
// Run sync
await conversationModel.syncWithMeili();
// Should have processed 2 documents (missing and false, but not true)
expect(mockAddDocumentsInBatches).toHaveBeenCalled();
// Check that both documents without _meiliIndex=true are now indexed
const indexedCount = await conversationModel.countDocuments({
expiredAt: null,
_meiliIndex: true,
});
expect(indexedCount).toBe(3); // All 3 should now be indexed
// Verify documents with missing _meiliIndex were updated
const docsWithMissingIndex = await conversationModel.countDocuments({
expiredAt: null,
title: 'Missing _meiliIndex',
_meiliIndex: true,
});
expect(docsWithMissingIndex).toBe(1);
});
test('getSyncProgress counts documents with missing _meiliIndex as not indexed', async () => {
const messageModel = createMessageModel(mongoose) as SchemaWithMeiliMethods;
await messageModel.deleteMany({});
// Insert documents with different _meiliIndex states
await messageModel.collection.insertMany([
{
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
expiredAt: null,
_meiliIndex: true,
},
{
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
expiredAt: null,
_meiliIndex: false,
},
{
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
expiredAt: null,
// _meiliIndex is missing
},
]);
const progress = await messageModel.getSyncProgress();
// Total should be 3
expect(progress.totalDocuments).toBe(3);
// Only 1 is indexed (with _meiliIndex: true)
expect(progress.totalProcessed).toBe(1);
// Not complete since 2 documents are not indexed
expect(progress.isComplete).toBe(false);
});
test('query with _meiliIndex: { $ne: true } includes missing values', async () => {
const conversationModel = createConversationModel(mongoose) as SchemaWithMeiliMethods;
await conversationModel.deleteMany({});
// Insert documents with different _meiliIndex states
await conversationModel.collection.insertMany([
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Missing',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
// _meiliIndex is missing
},
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'False',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
_meiliIndex: false,
},
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'True',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
_meiliIndex: true,
},
]);
// Query for documents where _meiliIndex is not true (used in syncWithMeili)
const unindexedDocs = await conversationModel.find({
expiredAt: null,
_meiliIndex: { $ne: true },
});
// Should find 2 documents (missing and false, but not true)
expect(unindexedDocs.length).toBe(2);
const titles = unindexedDocs.map((doc) => doc.title).sort();
expect(titles).toEqual(['False', 'Missing']);
});
test('syncWithMeili processes all documents where _meiliIndex is not true', async () => {
const messageModel = createMessageModel(mongoose) as SchemaWithMeiliMethods;
await messageModel.deleteMany({});
mockAddDocumentsInBatches.mockClear();
// Create a mix of documents with missing and false _meiliIndex
await messageModel.collection.insertMany([
{
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
expiredAt: null,
// _meiliIndex missing
},
{
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
expiredAt: null,
_meiliIndex: false,
},
{
messageId: new mongoose.Types.ObjectId(),
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
isCreatedByUser: true,
expiredAt: null,
// _meiliIndex missing
},
]);
// Count documents that should be synced (where _meiliIndex: { $ne: true })
const toSyncCount = await messageModel.countDocuments({
expiredAt: null,
_meiliIndex: { $ne: true },
});
expect(toSyncCount).toBe(3); // All 3 should be synced
await messageModel.syncWithMeili();
// All should now be indexed
const indexedCount = await messageModel.countDocuments({
expiredAt: null,
_meiliIndex: true,
});
expect(indexedCount).toBe(3);
});
test('syncWithMeili treats missing _meiliIndex same as false', async () => {
const conversationModel = createConversationModel(mongoose) as SchemaWithMeiliMethods;
await conversationModel.deleteMany({});
mockAddDocumentsInBatches.mockClear();
// Insert one document with missing _meiliIndex and one with false
await conversationModel.collection.insertMany([
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'Missing',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
// _meiliIndex is missing
},
{
conversationId: new mongoose.Types.ObjectId(),
user: new mongoose.Types.ObjectId(),
title: 'False',
endpoint: EModelEndpoint.openAI,
expiredAt: null,
_meiliIndex: false,
},
]);
// Both should be picked up by the sync query
const toSync = await conversationModel.find({
expiredAt: null,
_meiliIndex: { $ne: true },
});
expect(toSync.length).toBe(2);
await conversationModel.syncWithMeili();
// Both should be indexed after sync
const afterSync = await conversationModel.find({
expiredAt: null,
_meiliIndex: true,
});
expect(afterSync.length).toBe(2);
});
});
});

View file

@ -162,8 +162,8 @@ const createMeiliMongooseModel = ({
/**
* Synchronizes data between the MongoDB collection and the MeiliSearch index by
* incrementally indexing only documents where `expiredAt` is `null` and `_meiliIndex` is `false`
* (i.e., non-expired documents that have not yet been indexed).
* incrementally indexing only documents where `expiredAt` is `null` and `_meiliIndex` is not `true`
* (i.e., non-expired documents that have not yet been indexed, including those with missing or null `_meiliIndex`).
* */
static async syncWithMeili(this: SchemaWithMeiliMethods): Promise<void> {
const startTime = Date.now();
@ -196,7 +196,7 @@ const createMeiliMongooseModel = ({
while (hasMore) {
const query: FilterQuery<unknown> = {
expiredAt: null,
_meiliIndex: false,
_meiliIndex: { $ne: true },
};
try {