mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-04-07 00:15:23 +02:00
Merge branch 'main' into fix-image_gen_oai-with-nova-models
This commit is contained in:
commit
cbb6b1a7d5
211 changed files with 10746 additions and 1656 deletions
|
|
@ -10,6 +10,17 @@ export default {
|
|||
],
|
||||
coverageReporters: ['text', 'cobertura'],
|
||||
testResultsProcessor: 'jest-junit',
|
||||
transform: {
|
||||
'\\.[jt]sx?$': [
|
||||
'babel-jest',
|
||||
{
|
||||
presets: [
|
||||
['@babel/preset-env', { targets: { node: 'current' } }],
|
||||
'@babel/preset-typescript',
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
moduleNameMapper: {
|
||||
'^@src/(.*)$': '<rootDir>/src/$1',
|
||||
'~/(.*)': '<rootDir>/src/$1',
|
||||
|
|
|
|||
|
|
@ -87,11 +87,11 @@
|
|||
"@google/genai": "^1.19.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/core": "^0.3.80",
|
||||
"@librechat/agents": "^3.1.38",
|
||||
"@librechat/agents": "^3.1.50",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.26.0",
|
||||
"@smithy/node-http-handler": "^4.4.5",
|
||||
"axios": "^1.12.1",
|
||||
"axios": "^1.13.5",
|
||||
"connect-redis": "^8.1.0",
|
||||
"eventsource": "^3.0.2",
|
||||
"express": "^5.1.0",
|
||||
|
|
|
|||
284
packages/api/src/agents/__tests__/initialize.test.ts
Normal file
284
packages/api/src/agents/__tests__/initialize.test.ts
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { Agent } from 'librechat-data-provider';
|
||||
import type { ServerRequest, InitializeResultBase } from '~/types';
|
||||
import type { InitializeAgentDbMethods } from '../initialize';
|
||||
|
||||
// Mock logger
|
||||
jest.mock('winston', () => ({
|
||||
createLogger: jest.fn(() => ({
|
||||
debug: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
})),
|
||||
format: {
|
||||
combine: jest.fn(),
|
||||
colorize: jest.fn(),
|
||||
simple: jest.fn(),
|
||||
},
|
||||
transports: {
|
||||
Console: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockExtractLibreChatParams = jest.fn();
|
||||
const mockGetModelMaxTokens = jest.fn();
|
||||
const mockOptionalChainWithEmptyCheck = jest.fn();
|
||||
const mockGetThreadData = jest.fn();
|
||||
|
||||
jest.mock('~/utils', () => ({
|
||||
extractLibreChatParams: (...args: unknown[]) => mockExtractLibreChatParams(...args),
|
||||
getModelMaxTokens: (...args: unknown[]) => mockGetModelMaxTokens(...args),
|
||||
optionalChainWithEmptyCheck: (...args: unknown[]) => mockOptionalChainWithEmptyCheck(...args),
|
||||
getThreadData: (...args: unknown[]) => mockGetThreadData(...args),
|
||||
}));
|
||||
|
||||
const mockGetProviderConfig = jest.fn();
|
||||
jest.mock('~/endpoints', () => ({
|
||||
getProviderConfig: (...args: unknown[]) => mockGetProviderConfig(...args),
|
||||
}));
|
||||
|
||||
jest.mock('~/files', () => ({
|
||||
filterFilesByEndpointConfig: jest.fn(() => []),
|
||||
}));
|
||||
|
||||
jest.mock('~/prompts', () => ({
|
||||
generateArtifactsPrompt: jest.fn(() => null),
|
||||
}));
|
||||
|
||||
jest.mock('../resources', () => ({
|
||||
primeResources: jest.fn().mockResolvedValue({
|
||||
attachments: [],
|
||||
tool_resources: undefined,
|
||||
}),
|
||||
}));
|
||||
|
||||
import { initializeAgent } from '../initialize';
|
||||
|
||||
/**
|
||||
* Creates minimal mock objects for initializeAgent tests.
|
||||
*/
|
||||
function createMocks(overrides?: {
|
||||
maxContextTokens?: number;
|
||||
modelDefault?: number;
|
||||
maxOutputTokens?: number;
|
||||
}) {
|
||||
const { maxContextTokens, modelDefault = 200000, maxOutputTokens = 4096 } = overrides ?? {};
|
||||
|
||||
const agent = {
|
||||
id: 'agent-1',
|
||||
model: 'test-model',
|
||||
provider: Providers.OPENAI,
|
||||
tools: [],
|
||||
model_parameters: { model: 'test-model' },
|
||||
} as unknown as Agent;
|
||||
|
||||
const req = {
|
||||
user: { id: 'user-1' },
|
||||
config: {},
|
||||
} as unknown as ServerRequest;
|
||||
|
||||
const res = {} as unknown as import('express').Response;
|
||||
|
||||
const mockGetOptions = jest.fn().mockResolvedValue({
|
||||
llmConfig: {
|
||||
model: 'test-model',
|
||||
maxTokens: maxOutputTokens,
|
||||
},
|
||||
endpointTokenConfig: undefined,
|
||||
} satisfies InitializeResultBase);
|
||||
|
||||
mockGetProviderConfig.mockReturnValue({
|
||||
getOptions: mockGetOptions,
|
||||
overrideProvider: Providers.OPENAI,
|
||||
});
|
||||
|
||||
// extractLibreChatParams returns maxContextTokens when provided in model_parameters
|
||||
mockExtractLibreChatParams.mockReturnValue({
|
||||
resendFiles: false,
|
||||
maxContextTokens,
|
||||
modelOptions: { model: 'test-model' },
|
||||
});
|
||||
|
||||
// getModelMaxTokens returns the model's default context window
|
||||
mockGetModelMaxTokens.mockReturnValue(modelDefault);
|
||||
|
||||
// Implement real optionalChainWithEmptyCheck behavior
|
||||
mockOptionalChainWithEmptyCheck.mockImplementation(
|
||||
(...values: (string | number | undefined)[]) => {
|
||||
for (const v of values) {
|
||||
if (v !== undefined && v !== null && v !== '') {
|
||||
return v;
|
||||
}
|
||||
}
|
||||
return values[values.length - 1];
|
||||
},
|
||||
);
|
||||
|
||||
const loadTools = jest.fn().mockResolvedValue({
|
||||
tools: [],
|
||||
toolContextMap: {},
|
||||
userMCPAuthMap: undefined,
|
||||
toolRegistry: undefined,
|
||||
toolDefinitions: [],
|
||||
hasDeferredTools: false,
|
||||
});
|
||||
|
||||
const db: InitializeAgentDbMethods = {
|
||||
getFiles: jest.fn().mockResolvedValue([]),
|
||||
getConvoFiles: jest.fn().mockResolvedValue([]),
|
||||
updateFilesUsage: jest.fn().mockResolvedValue([]),
|
||||
getUserKey: jest.fn().mockResolvedValue('user-1'),
|
||||
getUserKeyValues: jest.fn().mockResolvedValue([]),
|
||||
getToolFilesByIds: jest.fn().mockResolvedValue([]),
|
||||
};
|
||||
|
||||
return { agent, req, res, loadTools, db };
|
||||
}
|
||||
|
||||
describe('initializeAgent — maxContextTokens', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('uses user-configured maxContextTokens when provided via model_parameters', async () => {
|
||||
const userValue = 50000;
|
||||
const { agent, req, res, loadTools, db } = createMocks({
|
||||
maxContextTokens: userValue,
|
||||
modelDefault: 200000,
|
||||
maxOutputTokens: 4096,
|
||||
});
|
||||
|
||||
const result = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
endpointOption: {
|
||||
endpoint: EModelEndpoint.agents,
|
||||
model_parameters: { maxContextTokens: userValue },
|
||||
},
|
||||
allowedProviders: new Set([Providers.OPENAI]),
|
||||
isInitialAgent: true,
|
||||
},
|
||||
db,
|
||||
);
|
||||
|
||||
expect(result.maxContextTokens).toBe(userValue);
|
||||
});
|
||||
|
||||
it('falls back to formula when maxContextTokens is NOT provided', async () => {
|
||||
const modelDefault = 200000;
|
||||
const maxOutputTokens = 4096;
|
||||
const { agent, req, res, loadTools, db } = createMocks({
|
||||
maxContextTokens: undefined,
|
||||
modelDefault,
|
||||
maxOutputTokens,
|
||||
});
|
||||
|
||||
const result = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
endpointOption: { endpoint: EModelEndpoint.agents },
|
||||
allowedProviders: new Set([Providers.OPENAI]),
|
||||
isInitialAgent: true,
|
||||
},
|
||||
db,
|
||||
);
|
||||
|
||||
const expected = Math.round((modelDefault - maxOutputTokens) * 0.9);
|
||||
expect(result.maxContextTokens).toBe(expected);
|
||||
});
|
||||
|
||||
it('falls back to formula when maxContextTokens is 0', async () => {
|
||||
const maxOutputTokens = 4096;
|
||||
const { agent, req, res, loadTools, db } = createMocks({
|
||||
maxContextTokens: 0,
|
||||
modelDefault: 200000,
|
||||
maxOutputTokens,
|
||||
});
|
||||
|
||||
const result = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
endpointOption: {
|
||||
endpoint: EModelEndpoint.agents,
|
||||
model_parameters: { maxContextTokens: 0 },
|
||||
},
|
||||
allowedProviders: new Set([Providers.OPENAI]),
|
||||
isInitialAgent: true,
|
||||
},
|
||||
db,
|
||||
);
|
||||
|
||||
// 0 is not used as-is; the formula kicks in.
|
||||
// optionalChainWithEmptyCheck(0, 200000, 18000) returns 0 (not null/undefined),
|
||||
// then Number(0) || 18000 = 18000 (the fallback default).
|
||||
expect(result.maxContextTokens).not.toBe(0);
|
||||
const expected = Math.round((18000 - maxOutputTokens) * 0.9);
|
||||
expect(result.maxContextTokens).toBe(expected);
|
||||
});
|
||||
|
||||
it('falls back to formula when maxContextTokens is negative', async () => {
|
||||
const maxOutputTokens = 4096;
|
||||
const { agent, req, res, loadTools, db } = createMocks({
|
||||
maxContextTokens: -1,
|
||||
modelDefault: 200000,
|
||||
maxOutputTokens,
|
||||
});
|
||||
|
||||
const result = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
endpointOption: {
|
||||
endpoint: EModelEndpoint.agents,
|
||||
model_parameters: { maxContextTokens: -1 },
|
||||
},
|
||||
allowedProviders: new Set([Providers.OPENAI]),
|
||||
isInitialAgent: true,
|
||||
},
|
||||
db,
|
||||
);
|
||||
|
||||
// -1 is not used as-is; the formula kicks in
|
||||
expect(result.maxContextTokens).not.toBe(-1);
|
||||
});
|
||||
|
||||
it('preserves small user-configured value (e.g. 1000 from modelSpec)', async () => {
|
||||
const userValue = 1000;
|
||||
const { agent, req, res, loadTools, db } = createMocks({
|
||||
maxContextTokens: userValue,
|
||||
modelDefault: 128000,
|
||||
maxOutputTokens: 4096,
|
||||
});
|
||||
|
||||
const result = await initializeAgent(
|
||||
{
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
loadTools,
|
||||
endpointOption: {
|
||||
endpoint: EModelEndpoint.agents,
|
||||
model_parameters: { maxContextTokens: userValue },
|
||||
},
|
||||
allowedProviders: new Set([Providers.OPENAI]),
|
||||
isInitialAgent: true,
|
||||
},
|
||||
db,
|
||||
);
|
||||
|
||||
// Should NOT be overridden to Math.round((128000 - 4096) * 0.9) = 111,514
|
||||
expect(result.maxContextTokens).toBe(userValue);
|
||||
});
|
||||
});
|
||||
162
packages/api/src/agents/client.ts
Normal file
162
packages/api/src/agents/client.ts
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import { labelContentByAgent, getTokenCountForMessage } from '@librechat/agents';
|
||||
import type { MessageContentComplex } from '@librechat/agents';
|
||||
import type { Agent, TMessage } from 'librechat-data-provider';
|
||||
import type { BaseMessage } from '@langchain/core/messages';
|
||||
import type { ServerRequest } from '~/types';
|
||||
import Tokenizer from '~/utils/tokenizer';
|
||||
import { logAxiosError } from '~/utils';
|
||||
|
||||
export const omitTitleOptions = new Set([
|
||||
'stream',
|
||||
'thinking',
|
||||
'streaming',
|
||||
'clientOptions',
|
||||
'thinkingConfig',
|
||||
'thinkingBudget',
|
||||
'includeThoughts',
|
||||
'maxOutputTokens',
|
||||
'additionalModelRequestFields',
|
||||
]);
|
||||
|
||||
export function payloadParser({ req, endpoint }: { req: ServerRequest; endpoint: string }) {
|
||||
if (isAgentsEndpoint(endpoint)) {
|
||||
return;
|
||||
}
|
||||
return req.body?.endpointOption?.model_parameters;
|
||||
}
|
||||
|
||||
export function createTokenCounter(encoding: Parameters<typeof Tokenizer.getTokenCount>[1]) {
|
||||
return function (message: BaseMessage) {
|
||||
const countTokens = (text: string) => Tokenizer.getTokenCount(text, encoding);
|
||||
return getTokenCountForMessage(message, countTokens);
|
||||
};
|
||||
}
|
||||
|
||||
export function logToolError(_graph: unknown, error: unknown, toolId: string) {
|
||||
logAxiosError({
|
||||
error,
|
||||
message: `[api/server/controllers/agents/client.js #chatCompletion] Tool Error "${toolId}"`,
|
||||
});
|
||||
}
|
||||
|
||||
const AGENT_SUFFIX_PATTERN = /____(\d+)$/;
|
||||
|
||||
/** Finds the primary agent ID within a set of agent IDs (no suffix or lowest suffix number) */
|
||||
export function findPrimaryAgentId(agentIds: Set<string>): string | null {
|
||||
let primaryAgentId: string | null = null;
|
||||
let lowestSuffixIndex = Infinity;
|
||||
|
||||
for (const agentId of agentIds) {
|
||||
const suffixMatch = agentId.match(AGENT_SUFFIX_PATTERN);
|
||||
if (!suffixMatch) {
|
||||
return agentId;
|
||||
}
|
||||
const suffixIndex = parseInt(suffixMatch[1], 10);
|
||||
if (suffixIndex < lowestSuffixIndex) {
|
||||
lowestSuffixIndex = suffixIndex;
|
||||
primaryAgentId = agentId;
|
||||
}
|
||||
}
|
||||
|
||||
return primaryAgentId;
|
||||
}
|
||||
|
||||
type ContentPart = TMessage['content'] extends (infer U)[] | undefined ? U : never;
|
||||
|
||||
/**
|
||||
* Creates a mapMethod for getMessagesForConversation that processes agent content.
|
||||
* - Strips agentId/groupId metadata from all content
|
||||
* - For parallel agents (addedConvo with groupId): filters each group to its primary agent
|
||||
* - For handoffs (agentId without groupId): keeps all content from all agents
|
||||
* - For multi-agent: applies agent labels to content
|
||||
*
|
||||
* The key distinction:
|
||||
* - Parallel execution (addedConvo): Parts have both agentId AND groupId
|
||||
* - Handoffs: Parts only have agentId, no groupId
|
||||
*/
|
||||
export function createMultiAgentMapper(primaryAgent: Agent, agentConfigs?: Map<string, Agent>) {
|
||||
const hasMultipleAgents = (primaryAgent.edges?.length ?? 0) > 0 || (agentConfigs?.size ?? 0) > 0;
|
||||
|
||||
let agentNames: Record<string, string> | null = null;
|
||||
if (hasMultipleAgents) {
|
||||
agentNames = { [primaryAgent.id]: primaryAgent.name || 'Assistant' };
|
||||
if (agentConfigs) {
|
||||
for (const [agentId, agentConfig] of agentConfigs.entries()) {
|
||||
agentNames[agentId] = agentConfig.name || agentConfig.id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return (message: TMessage): TMessage => {
|
||||
if (message.isCreatedByUser || !Array.isArray(message.content)) {
|
||||
return message;
|
||||
}
|
||||
|
||||
const hasAgentMetadata = message.content.some(
|
||||
(part) =>
|
||||
(part as ContentPart & { agentId?: string; groupId?: number })?.agentId ||
|
||||
(part as ContentPart & { groupId?: number })?.groupId != null,
|
||||
);
|
||||
if (!hasAgentMetadata) {
|
||||
return message;
|
||||
}
|
||||
|
||||
try {
|
||||
const groupAgentMap = new Map<number, Set<string>>();
|
||||
|
||||
for (const part of message.content) {
|
||||
const p = part as ContentPart & { agentId?: string; groupId?: number };
|
||||
const groupId = p?.groupId;
|
||||
const agentId = p?.agentId;
|
||||
if (groupId != null && agentId) {
|
||||
if (!groupAgentMap.has(groupId)) {
|
||||
groupAgentMap.set(groupId, new Set());
|
||||
}
|
||||
groupAgentMap.get(groupId)!.add(agentId);
|
||||
}
|
||||
}
|
||||
|
||||
const groupPrimaryMap = new Map<number, string>();
|
||||
for (const [groupId, agentIds] of groupAgentMap) {
|
||||
const primary = findPrimaryAgentId(agentIds);
|
||||
if (primary) {
|
||||
groupPrimaryMap.set(groupId, primary);
|
||||
}
|
||||
}
|
||||
|
||||
const filteredContent: ContentPart[] = [];
|
||||
const agentIdMap: Record<number, string> = {};
|
||||
|
||||
for (const part of message.content) {
|
||||
const p = part as ContentPart & { agentId?: string; groupId?: number };
|
||||
const agentId = p?.agentId;
|
||||
const groupId = p?.groupId;
|
||||
|
||||
const isParallelPart = groupId != null;
|
||||
const groupPrimary = isParallelPart ? groupPrimaryMap.get(groupId) : null;
|
||||
const shouldInclude = !isParallelPart || !agentId || agentId === groupPrimary;
|
||||
|
||||
if (shouldInclude) {
|
||||
const newIndex = filteredContent.length;
|
||||
const { agentId: _a, groupId: _g, ...cleanPart } = p;
|
||||
filteredContent.push(cleanPart as ContentPart);
|
||||
if (agentId && hasMultipleAgents) {
|
||||
agentIdMap[newIndex] = agentId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const finalContent =
|
||||
Object.keys(agentIdMap).length > 0 && agentNames
|
||||
? labelContentByAgent(filteredContent as MessageContentComplex[], agentIdMap, agentNames)
|
||||
: filteredContent;
|
||||
|
||||
return { ...message, content: finalContent as TMessage['content'] };
|
||||
} catch (error) {
|
||||
logger.error('[AgentClient] Error processing multi-agent message:', error);
|
||||
return message;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
export * from './avatars';
|
||||
export * from './chain';
|
||||
export * from './client';
|
||||
export * from './context';
|
||||
export * from './edges';
|
||||
export * from './handlers';
|
||||
|
|
|
|||
|
|
@ -413,7 +413,10 @@ export async function initializeAgent(
|
|||
toolContextMap: toolContextMap ?? {},
|
||||
useLegacyContent: !!options.useLegacyContent,
|
||||
tools: (tools ?? []) as GenericTool[] & string[],
|
||||
maxContextTokens: Math.round((agentMaxContextNum - maxOutputTokensNum) * 0.9),
|
||||
maxContextTokens:
|
||||
maxContextTokens != null && maxContextTokens > 0
|
||||
? maxContextTokens
|
||||
: Math.round((agentMaxContextNum - maxOutputTokensNum) * 0.9),
|
||||
};
|
||||
|
||||
return initializedAgent;
|
||||
|
|
|
|||
193
packages/api/src/agents/responses/__tests__/responses-api.live.test.sh
Executable file
193
packages/api/src/agents/responses/__tests__/responses-api.live.test.sh
Executable file
|
|
@ -0,0 +1,193 @@
|
|||
#!/usr/bin/env bash
|
||||
#
|
||||
# Live integration tests for the Responses API endpoint.
|
||||
# Sends curl requests to a running LibreChat server to verify
|
||||
# multi-turn conversations with output_text / refusal blocks work.
|
||||
#
|
||||
# Usage:
|
||||
# ./responses-api.live.test.sh <BASE_URL> <API_KEY> <AGENT_ID>
|
||||
#
|
||||
# Example:
|
||||
# ./responses-api.live.test.sh http://localhost:3080 sk-abc123 agent_xyz
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BASE_URL="${1:?Usage: $0 <BASE_URL> <API_KEY> <AGENT_ID>}"
|
||||
API_KEY="${2:?Usage: $0 <BASE_URL> <API_KEY> <AGENT_ID>}"
|
||||
AGENT_ID="${3:?Usage: $0 <BASE_URL> <API_KEY> <AGENT_ID>}"
|
||||
|
||||
ENDPOINT="${BASE_URL}/v1/responses"
|
||||
PASS=0
|
||||
FAIL=0
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
post_json() {
|
||||
local label="$1"
|
||||
local body="$2"
|
||||
local stream="${3:-false}"
|
||||
|
||||
echo "──────────────────────────────────────────────"
|
||||
echo "TEST: ${label}"
|
||||
echo "──────────────────────────────────────────────"
|
||||
|
||||
local http_code
|
||||
local response
|
||||
|
||||
if [ "$stream" = "true" ]; then
|
||||
# For streaming, just check we get a 200 and some SSE data
|
||||
response=$(curl -s -w "\n%{http_code}" \
|
||||
-X POST "${ENDPOINT}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${API_KEY}" \
|
||||
-d "${body}" \
|
||||
--max-time 60)
|
||||
else
|
||||
response=$(curl -s -w "\n%{http_code}" \
|
||||
-X POST "${ENDPOINT}" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${API_KEY}" \
|
||||
-d "${body}" \
|
||||
--max-time 60)
|
||||
fi
|
||||
|
||||
http_code=$(echo "$response" | tail -1)
|
||||
local body_out
|
||||
body_out=$(echo "$response" | sed '$d')
|
||||
|
||||
if [ "$http_code" = "200" ]; then
|
||||
echo " ✓ HTTP 200"
|
||||
PASS=$((PASS + 1))
|
||||
else
|
||||
echo " ✗ HTTP ${http_code}"
|
||||
echo " Response: ${body_out}"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# Print truncated response for inspection
|
||||
echo " Response (first 300 chars): ${body_out:0:300}"
|
||||
echo ""
|
||||
|
||||
# Return the body for chaining
|
||||
echo "$body_out"
|
||||
}
|
||||
|
||||
extract_response_id() {
|
||||
# Extract "id" field from JSON response
|
||||
echo "$1" | grep -o '"id":"[^"]*"' | head -1 | cut -d'"' -f4
|
||||
}
|
||||
|
||||
# ── Test 1: Basic single-turn request ─────────────────────────────────
|
||||
|
||||
RESP1=$(post_json "Basic single-turn request" "$(cat <<EOF
|
||||
{
|
||||
"model": "${AGENT_ID}",
|
||||
"input": "Say hello in exactly 5 words.",
|
||||
"stream": false
|
||||
}
|
||||
EOF
|
||||
)")
|
||||
|
||||
# ── Test 2: Multi-turn with output_text assistant blocks ──────────────
|
||||
|
||||
post_json "Multi-turn with output_text blocks (the original bug)" "$(cat <<EOF
|
||||
{
|
||||
"model": "${AGENT_ID}",
|
||||
"input": [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "What is 2+2?"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "2+2 equals 4.", "annotations": [], "logprobs": []}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "And what is 3+3?"}]
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
EOF
|
||||
)" > /dev/null
|
||||
|
||||
# ── Test 3: Multi-turn with refusal blocks ────────────────────────────
|
||||
|
||||
post_json "Multi-turn with refusal blocks" "$(cat <<EOF
|
||||
{
|
||||
"model": "${AGENT_ID}",
|
||||
"input": [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "Do something bad"}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "refusal", "refusal": "I cannot help with that."}]
|
||||
},
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "OK, just say hello then."}]
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
EOF
|
||||
)" > /dev/null
|
||||
|
||||
# ── Test 4: Streaming request ─────────────────────────────────────────
|
||||
|
||||
post_json "Streaming single-turn request" "$(cat <<EOF
|
||||
{
|
||||
"model": "${AGENT_ID}",
|
||||
"input": "Say hi in one word.",
|
||||
"stream": true
|
||||
}
|
||||
EOF
|
||||
)" "true" > /dev/null
|
||||
|
||||
# ── Test 5: Back-and-forth using previous_response_id ─────────────────
|
||||
|
||||
RESP5=$(post_json "First turn for previous_response_id chain" "$(cat <<EOF
|
||||
{
|
||||
"model": "${AGENT_ID}",
|
||||
"input": "Remember this number: 42. Just confirm you got it.",
|
||||
"stream": false
|
||||
}
|
||||
EOF
|
||||
)")
|
||||
|
||||
RESP5_ID=$(extract_response_id "$RESP5")
|
||||
|
||||
if [ -n "$RESP5_ID" ]; then
|
||||
echo " Extracted response ID: ${RESP5_ID}"
|
||||
post_json "Follow-up using previous_response_id" "$(cat <<EOF
|
||||
{
|
||||
"model": "${AGENT_ID}",
|
||||
"input": "What number did I ask you to remember?",
|
||||
"previous_response_id": "${RESP5_ID}",
|
||||
"stream": false
|
||||
}
|
||||
EOF
|
||||
)" > /dev/null
|
||||
else
|
||||
echo " ⚠ Could not extract response ID, skipping follow-up test"
|
||||
FAIL=$((FAIL + 1))
|
||||
fi
|
||||
|
||||
# ── Summary ───────────────────────────────────────────────────────────
|
||||
|
||||
echo "══════════════════════════════════════════════"
|
||||
echo "RESULTS: ${PASS} passed, ${FAIL} failed"
|
||||
echo "══════════════════════════════════════════════"
|
||||
|
||||
if [ "$FAIL" -gt 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
333
packages/api/src/agents/responses/__tests__/service.test.ts
Normal file
333
packages/api/src/agents/responses/__tests__/service.test.ts
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
import { convertInputToMessages } from '../service';
|
||||
import type { InputItem } from '../types';
|
||||
|
||||
describe('convertInputToMessages', () => {
|
||||
// ── String input shorthand ─────────────────────────────────────────
|
||||
it('converts a string input to a single user message', () => {
|
||||
const result = convertInputToMessages('Hello');
|
||||
expect(result).toEqual([{ role: 'user', content: 'Hello' }]);
|
||||
});
|
||||
|
||||
// ── Empty input array ──────────────────────────────────────────────
|
||||
it('returns an empty array for empty input', () => {
|
||||
const result = convertInputToMessages([]);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
// ── Role mapping ───────────────────────────────────────────────────
|
||||
it('maps developer role to system', () => {
|
||||
const input: InputItem[] = [
|
||||
{ type: 'message', role: 'developer', content: 'You are helpful.' },
|
||||
];
|
||||
expect(convertInputToMessages(input)).toEqual([
|
||||
{ role: 'system', content: 'You are helpful.' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('maps system role to system', () => {
|
||||
const input: InputItem[] = [{ type: 'message', role: 'system', content: 'System prompt.' }];
|
||||
expect(convertInputToMessages(input)).toEqual([{ role: 'system', content: 'System prompt.' }]);
|
||||
});
|
||||
|
||||
it('maps user role to user', () => {
|
||||
const input: InputItem[] = [{ type: 'message', role: 'user', content: 'Hi' }];
|
||||
expect(convertInputToMessages(input)).toEqual([{ role: 'user', content: 'Hi' }]);
|
||||
});
|
||||
|
||||
it('maps assistant role to assistant', () => {
|
||||
const input: InputItem[] = [{ type: 'message', role: 'assistant', content: 'Hello!' }];
|
||||
expect(convertInputToMessages(input)).toEqual([{ role: 'assistant', content: 'Hello!' }]);
|
||||
});
|
||||
|
||||
it('defaults unknown roles to user', () => {
|
||||
const input = [
|
||||
{ type: 'message', role: 'unknown_role', content: 'test' },
|
||||
] as unknown as InputItem[];
|
||||
expect(convertInputToMessages(input)[0].role).toBe('user');
|
||||
});
|
||||
|
||||
// ── input_text content blocks ──────────────────────────────────────
|
||||
it('converts input_text blocks to text blocks', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text', text: 'Hello world' }],
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: 'Hello world' }] }]);
|
||||
});
|
||||
|
||||
// ── output_text content blocks (the original bug) ──────────────────
|
||||
it('converts output_text blocks to text blocks', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'output_text', text: 'I can help!', annotations: [], logprobs: [] }],
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{ role: 'assistant', content: [{ type: 'text', text: 'I can help!' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
// ── refusal content blocks ─────────────────────────────────────────
|
||||
it('converts refusal blocks to text blocks', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'refusal', refusal: 'I cannot do that.' }],
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{ role: 'assistant', content: [{ type: 'text', text: 'I cannot do that.' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
// ── input_image content blocks ─────────────────────────────────────
|
||||
it('converts input_image blocks to image_url blocks', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'input_image', image_url: 'https://example.com/img.png', detail: 'high' },
|
||||
],
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: { url: 'https://example.com/img.png', detail: 'high' },
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
// ── input_file content blocks ──────────────────────────────────────
|
||||
it('converts input_file blocks to text placeholders', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_file', filename: 'report.pdf', file_id: 'f_123' }],
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{ role: 'user', content: [{ type: 'text', text: '[File: report.pdf]' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('uses "unknown" for input_file without filename', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_file', file_id: 'f_123' }],
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{ role: 'user', content: [{ type: 'text', text: '[File: unknown]' }] },
|
||||
]);
|
||||
});
|
||||
|
||||
// ── Null / undefined filtering ─────────────────────────────────────
|
||||
it('filters out null elements in content arrays', () => {
|
||||
const input = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [null, { type: 'input_text', text: 'valid' }, undefined],
|
||||
},
|
||||
] as unknown as InputItem[];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: 'valid' }] }]);
|
||||
});
|
||||
|
||||
// ── Missing text field defaults to empty string ────────────────────
|
||||
it('defaults to empty string when text field is missing on input_text', () => {
|
||||
const input = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text' }],
|
||||
},
|
||||
] as unknown as InputItem[];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: '' }] }]);
|
||||
});
|
||||
|
||||
it('defaults to empty string when text field is missing on output_text', () => {
|
||||
const input = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'output_text' }],
|
||||
},
|
||||
] as unknown as InputItem[];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'assistant', content: [{ type: 'text', text: '' }] }]);
|
||||
});
|
||||
|
||||
it('defaults to empty string when refusal field is missing on refusal block', () => {
|
||||
const input = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'refusal' }],
|
||||
},
|
||||
] as unknown as InputItem[];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'assistant', content: [{ type: 'text', text: '' }] }]);
|
||||
});
|
||||
|
||||
// ── Unknown block types are filtered out ───────────────────────────
|
||||
it('filters out unknown content block types', () => {
|
||||
const input = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'input_text', text: 'keep me' },
|
||||
{ type: 'some_future_type', data: 'ignore' },
|
||||
],
|
||||
},
|
||||
] as unknown as InputItem[];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'user', content: [{ type: 'text', text: 'keep me' }] }]);
|
||||
});
|
||||
|
||||
// ── Mixed valid/invalid content in same array ──────────────────────
|
||||
it('handles mixed valid and invalid content blocks', () => {
|
||||
const input = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{ type: 'output_text', text: 'Hello', annotations: [], logprobs: [] },
|
||||
null,
|
||||
{ type: 'unknown_type' },
|
||||
{ type: 'refusal', refusal: 'No can do' },
|
||||
],
|
||||
},
|
||||
] as unknown as InputItem[];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [
|
||||
{ type: 'text', text: 'Hello' },
|
||||
{ type: 'text', text: 'No can do' },
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
// ── Non-array, non-string content defaults to empty string ─────────
|
||||
it('defaults to empty string for non-array non-string content', () => {
|
||||
const input = [{ type: 'message', role: 'user', content: 42 }] as unknown as InputItem[];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'user', content: '' }]);
|
||||
});
|
||||
|
||||
// ── Function call items ────────────────────────────────────────────
|
||||
it('converts function_call items to assistant messages with tool_calls', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'function_call',
|
||||
id: 'fc_1',
|
||||
call_id: 'call_abc',
|
||||
name: 'get_weather',
|
||||
arguments: '{"city":"NYC"}',
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
tool_calls: [
|
||||
{
|
||||
id: 'call_abc',
|
||||
type: 'function',
|
||||
function: { name: 'get_weather', arguments: '{"city":"NYC"}' },
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
// ── Function call output items ─────────────────────────────────────
|
||||
it('converts function_call_output items to tool messages', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'function_call_output',
|
||||
call_id: 'call_abc',
|
||||
output: '{"temp":72}',
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'tool',
|
||||
content: '{"temp":72}',
|
||||
tool_call_id: 'call_abc',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
// ── Item references are skipped ────────────────────────────────────
|
||||
it('skips item_reference items', () => {
|
||||
const input: InputItem[] = [
|
||||
{ type: 'item_reference', id: 'ref_123' },
|
||||
{ type: 'message', role: 'user', content: 'Hello' },
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([{ role: 'user', content: 'Hello' }]);
|
||||
});
|
||||
|
||||
// ── Multi-turn conversation (the real-world scenario) ──────────────
|
||||
it('handles a full multi-turn conversation with output_text blocks', () => {
|
||||
const input: InputItem[] = [
|
||||
{
|
||||
type: 'message',
|
||||
role: 'developer',
|
||||
content: [{ type: 'input_text', text: 'You are a helpful assistant.' }],
|
||||
},
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text', text: 'What is 2+2?' }],
|
||||
},
|
||||
{
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
content: [{ type: 'output_text', text: '2+2 is 4.', annotations: [], logprobs: [] }],
|
||||
},
|
||||
{
|
||||
type: 'message',
|
||||
role: 'user',
|
||||
content: [{ type: 'input_text', text: 'And 3+3?' }],
|
||||
},
|
||||
];
|
||||
const result = convertInputToMessages(input);
|
||||
expect(result).toEqual([
|
||||
{ role: 'system', content: [{ type: 'text', text: 'You are a helpful assistant.' }] },
|
||||
{ role: 'user', content: [{ type: 'text', text: 'What is 2+2?' }] },
|
||||
{ role: 'assistant', content: [{ type: 'text', text: '2+2 is 4.' }] },
|
||||
{ role: 'user', content: [{ type: 'text', text: 'And 3+3?' }] },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
|
@ -6,11 +6,12 @@
|
|||
*/
|
||||
import type { Response as ServerResponse } from 'express';
|
||||
import type {
|
||||
ResponseRequest,
|
||||
RequestValidationResult,
|
||||
InputItem,
|
||||
InputContent,
|
||||
ResponseRequest,
|
||||
ResponseContext,
|
||||
InputContent,
|
||||
ModelContent,
|
||||
InputItem,
|
||||
Response,
|
||||
} from './types';
|
||||
import {
|
||||
|
|
@ -134,7 +135,7 @@ export function convertInputToMessages(input: string | InputItem[]): InternalMes
|
|||
const messageItem = item as {
|
||||
type: 'message';
|
||||
role: string;
|
||||
content: string | InputContent[];
|
||||
content: string | (InputContent | ModelContent)[];
|
||||
};
|
||||
|
||||
let content: InternalMessage['content'];
|
||||
|
|
@ -142,21 +143,31 @@ export function convertInputToMessages(input: string | InputItem[]): InternalMes
|
|||
if (typeof messageItem.content === 'string') {
|
||||
content = messageItem.content;
|
||||
} else if (Array.isArray(messageItem.content)) {
|
||||
content = messageItem.content.map((part) => {
|
||||
if (part.type === 'input_text') {
|
||||
return { type: 'text', text: part.text };
|
||||
}
|
||||
if (part.type === 'input_image') {
|
||||
return {
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: (part as { image_url?: string }).image_url,
|
||||
detail: (part as { detail?: string }).detail,
|
||||
},
|
||||
};
|
||||
}
|
||||
return { type: part.type };
|
||||
});
|
||||
content = messageItem.content
|
||||
.filter((part): part is InputContent | ModelContent => part != null)
|
||||
.map((part) => {
|
||||
if (part.type === 'input_text' || part.type === 'output_text') {
|
||||
return { type: 'text', text: (part as { text?: string }).text ?? '' };
|
||||
}
|
||||
if (part.type === 'refusal') {
|
||||
return { type: 'text', text: (part as { refusal?: string }).refusal ?? '' };
|
||||
}
|
||||
if (part.type === 'input_image') {
|
||||
return {
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: (part as { image_url?: string }).image_url,
|
||||
detail: (part as { detail?: string }).detail,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (part.type === 'input_file') {
|
||||
const filePart = part as { filename?: string };
|
||||
return { type: 'text', text: `[File: ${filePart.filename ?? 'unknown'}]` };
|
||||
}
|
||||
return null;
|
||||
})
|
||||
.filter((part): part is NonNullable<typeof part> => part != null);
|
||||
} else {
|
||||
content = '';
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { getTransactionsConfig, getBalanceConfig } from './config';
|
||||
import { getTransactionsConfig, getBalanceConfig, getCustomEndpointConfig } from './config';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { FileSources } from 'librechat-data-provider';
|
||||
import type { TCustomConfig } from 'librechat-data-provider';
|
||||
import { FileSources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { TCustomConfig, TEndpoint } from 'librechat-data-provider';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
|
||||
// Helper function to create a minimal AppConfig for testing
|
||||
|
|
@ -282,3 +282,75 @@ describe('getBalanceConfig', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCustomEndpointConfig', () => {
|
||||
describe('when appConfig is not provided', () => {
|
||||
it('should throw an error', () => {
|
||||
expect(() => getCustomEndpointConfig({ endpoint: 'test' })).toThrow(
|
||||
'Config not found for the test custom endpoint.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when appConfig is provided', () => {
|
||||
it('should return undefined when no custom endpoints are configured', () => {
|
||||
const appConfig = createTestAppConfig();
|
||||
const result = getCustomEndpointConfig({ endpoint: 'test', appConfig });
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return the matching endpoint config when found', () => {
|
||||
const appConfig = createTestAppConfig({
|
||||
endpoints: {
|
||||
[EModelEndpoint.custom]: [
|
||||
{
|
||||
name: 'TestEndpoint',
|
||||
apiKey: 'test-key',
|
||||
} as TEndpoint,
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const result = getCustomEndpointConfig({ endpoint: 'TestEndpoint', appConfig });
|
||||
expect(result).toEqual({
|
||||
name: 'TestEndpoint',
|
||||
apiKey: 'test-key',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle case-insensitive matching for Ollama endpoint', () => {
|
||||
const appConfig = createTestAppConfig({
|
||||
endpoints: {
|
||||
[EModelEndpoint.custom]: [
|
||||
{
|
||||
name: 'Ollama',
|
||||
apiKey: 'ollama-key',
|
||||
} as TEndpoint,
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const result = getCustomEndpointConfig({ endpoint: 'Ollama', appConfig });
|
||||
expect(result).toEqual({
|
||||
name: 'Ollama',
|
||||
apiKey: 'ollama-key',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle mixed case endpoint names', () => {
|
||||
const appConfig = createTestAppConfig({
|
||||
endpoints: {
|
||||
[EModelEndpoint.custom]: [
|
||||
{
|
||||
name: 'CustomAI',
|
||||
apiKey: 'custom-key',
|
||||
} as TEndpoint,
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const result = getCustomEndpointConfig({ endpoint: 'customai', appConfig });
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ export const getCustomEndpointConfig = ({
|
|||
|
||||
const customEndpoints = appConfig.endpoints?.[EModelEndpoint.custom] ?? [];
|
||||
return customEndpoints.find(
|
||||
(endpointConfig) => normalizeEndpointName(endpointConfig.name) === endpoint,
|
||||
(endpointConfig) => normalizeEndpointName(endpointConfig.name) === normalizeEndpointName(endpoint),
|
||||
);
|
||||
};
|
||||
|
||||
|
|
|
|||
113
packages/api/src/auth/agent.spec.ts
Normal file
113
packages/api/src/auth/agent.spec.ts
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
jest.mock('node:dns', () => {
|
||||
const actual = jest.requireActual('node:dns');
|
||||
return {
|
||||
...actual,
|
||||
lookup: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
import dns from 'node:dns';
|
||||
import { createSSRFSafeAgents, createSSRFSafeUndiciConnect } from './agent';
|
||||
|
||||
type LookupCallback = (err: NodeJS.ErrnoException | null, address: string, family: number) => void;
|
||||
|
||||
const mockedDnsLookup = dns.lookup as jest.MockedFunction<typeof dns.lookup>;
|
||||
|
||||
function mockDnsResult(address: string, family: number): void {
|
||||
mockedDnsLookup.mockImplementation(((
|
||||
_hostname: string,
|
||||
_options: unknown,
|
||||
callback: LookupCallback,
|
||||
) => {
|
||||
callback(null, address, family);
|
||||
}) as never);
|
||||
}
|
||||
|
||||
function mockDnsError(err: NodeJS.ErrnoException): void {
|
||||
mockedDnsLookup.mockImplementation(((
|
||||
_hostname: string,
|
||||
_options: unknown,
|
||||
callback: LookupCallback,
|
||||
) => {
|
||||
callback(err, '', 0);
|
||||
}) as never);
|
||||
}
|
||||
|
||||
describe('createSSRFSafeAgents', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should return httpAgent and httpsAgent', () => {
|
||||
const agents = createSSRFSafeAgents();
|
||||
expect(agents.httpAgent).toBeDefined();
|
||||
expect(agents.httpsAgent).toBeDefined();
|
||||
});
|
||||
|
||||
it('should patch httpAgent createConnection to inject SSRF lookup', () => {
|
||||
const agents = createSSRFSafeAgents();
|
||||
const internal = agents.httpAgent as unknown as {
|
||||
createConnection: (opts: Record<string, unknown>) => unknown;
|
||||
};
|
||||
expect(internal.createConnection).toBeInstanceOf(Function);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createSSRFSafeUndiciConnect', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should return an object with a lookup function', () => {
|
||||
const connect = createSSRFSafeUndiciConnect();
|
||||
expect(connect).toHaveProperty('lookup');
|
||||
expect(connect.lookup).toBeInstanceOf(Function);
|
||||
});
|
||||
|
||||
it('lookup should block private IPs', async () => {
|
||||
mockDnsResult('10.0.0.1', 4);
|
||||
const connect = createSSRFSafeUndiciConnect();
|
||||
|
||||
const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => {
|
||||
connect.lookup('evil.example.com', {}, (err) => {
|
||||
resolve({ err });
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.err).toBeTruthy();
|
||||
expect(result.err!.code).toBe('ESSRF');
|
||||
});
|
||||
|
||||
it('lookup should allow public IPs', async () => {
|
||||
mockDnsResult('93.184.216.34', 4);
|
||||
const connect = createSSRFSafeUndiciConnect();
|
||||
|
||||
const result = await new Promise<{ err: NodeJS.ErrnoException | null; address: string }>(
|
||||
(resolve) => {
|
||||
connect.lookup('example.com', {}, (err, address) => {
|
||||
resolve({ err, address: address as string });
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
expect(result.err).toBeNull();
|
||||
expect(result.address).toBe('93.184.216.34');
|
||||
});
|
||||
|
||||
it('lookup should forward DNS errors', async () => {
|
||||
const dnsError = Object.assign(new Error('ENOTFOUND'), {
|
||||
code: 'ENOTFOUND',
|
||||
}) as NodeJS.ErrnoException;
|
||||
mockDnsError(dnsError);
|
||||
const connect = createSSRFSafeUndiciConnect();
|
||||
|
||||
const result = await new Promise<{ err: NodeJS.ErrnoException | null }>((resolve) => {
|
||||
connect.lookup('nonexistent.example.com', {}, (err) => {
|
||||
resolve({ err });
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.err).toBeTruthy();
|
||||
expect(result.err!.code).toBe('ENOTFOUND');
|
||||
});
|
||||
});
|
||||
61
packages/api/src/auth/agent.ts
Normal file
61
packages/api/src/auth/agent.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
import dns from 'node:dns';
|
||||
import http from 'node:http';
|
||||
import https from 'node:https';
|
||||
import type { LookupFunction } from 'node:net';
|
||||
import { isPrivateIP } from './domain';
|
||||
|
||||
/** DNS lookup wrapper that blocks resolution to private/reserved IP addresses */
|
||||
const ssrfSafeLookup: LookupFunction = (hostname, options, callback) => {
|
||||
dns.lookup(hostname, options, (err, address, family) => {
|
||||
if (err) {
|
||||
callback(err, '', 0);
|
||||
return;
|
||||
}
|
||||
if (typeof address === 'string' && isPrivateIP(address)) {
|
||||
const ssrfError = Object.assign(
|
||||
new Error(`SSRF protection: ${hostname} resolved to blocked address ${address}`),
|
||||
{ code: 'ESSRF' },
|
||||
) as NodeJS.ErrnoException;
|
||||
callback(ssrfError, address, family as number);
|
||||
return;
|
||||
}
|
||||
callback(null, address as string, family as number);
|
||||
});
|
||||
};
|
||||
|
||||
/** Internal agent shape exposing createConnection (exists at runtime but not in TS types) */
|
||||
type AgentInternal = {
|
||||
createConnection: (options: Record<string, unknown>, oncreate?: unknown) => unknown;
|
||||
};
|
||||
|
||||
/** Patches an agent instance to inject SSRF-safe DNS lookup at connect time */
|
||||
function withSSRFProtection<T extends http.Agent>(agent: T): T {
|
||||
const internal = agent as unknown as AgentInternal;
|
||||
const origCreate = internal.createConnection.bind(agent);
|
||||
internal.createConnection = (options: Record<string, unknown>, oncreate?: unknown) => {
|
||||
options.lookup = ssrfSafeLookup;
|
||||
return origCreate(options, oncreate);
|
||||
};
|
||||
return agent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates HTTP and HTTPS agents that block TCP connections to private/reserved IP addresses.
|
||||
* Provides TOCTOU-safe SSRF protection by validating the resolved IP at connect time,
|
||||
* preventing DNS rebinding attacks where a hostname resolves to a public IP during
|
||||
* pre-validation but to a private IP when the actual connection is made.
|
||||
*/
|
||||
export function createSSRFSafeAgents(): { httpAgent: http.Agent; httpsAgent: https.Agent } {
|
||||
return {
|
||||
httpAgent: withSSRFProtection(new http.Agent()),
|
||||
httpsAgent: withSSRFProtection(new https.Agent()),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns undici-compatible `connect` options with SSRF-safe DNS lookup.
|
||||
* Pass the result as the `connect` property when constructing an undici `Agent`.
|
||||
*/
|
||||
export function createSSRFSafeUndiciConnect(): { lookup: LookupFunction } {
|
||||
return { lookup: ssrfSafeLookup };
|
||||
}
|
||||
|
|
@ -1,12 +1,21 @@
|
|||
/* eslint-disable @typescript-eslint/ban-ts-comment */
|
||||
jest.mock('node:dns/promises', () => ({
|
||||
lookup: jest.fn(),
|
||||
}));
|
||||
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import {
|
||||
extractMCPServerDomain,
|
||||
isActionDomainAllowed,
|
||||
isEmailDomainAllowed,
|
||||
isMCPDomainAllowed,
|
||||
isPrivateIP,
|
||||
isSSRFTarget,
|
||||
resolveHostnameSSRF,
|
||||
} from './domain';
|
||||
|
||||
const mockedLookup = lookup as jest.MockedFunction<typeof lookup>;
|
||||
|
||||
describe('isEmailDomainAllowed', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
|
@ -192,7 +201,154 @@ describe('isSSRFTarget', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('isPrivateIP', () => {
|
||||
describe('IPv4 private ranges', () => {
|
||||
it('should detect loopback addresses', () => {
|
||||
expect(isPrivateIP('127.0.0.1')).toBe(true);
|
||||
expect(isPrivateIP('127.255.255.255')).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect 10.x.x.x private range', () => {
|
||||
expect(isPrivateIP('10.0.0.1')).toBe(true);
|
||||
expect(isPrivateIP('10.255.255.255')).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect 172.16-31.x.x private range', () => {
|
||||
expect(isPrivateIP('172.16.0.1')).toBe(true);
|
||||
expect(isPrivateIP('172.31.255.255')).toBe(true);
|
||||
expect(isPrivateIP('172.15.0.1')).toBe(false);
|
||||
expect(isPrivateIP('172.32.0.1')).toBe(false);
|
||||
});
|
||||
|
||||
it('should detect 192.168.x.x private range', () => {
|
||||
expect(isPrivateIP('192.168.0.1')).toBe(true);
|
||||
expect(isPrivateIP('192.168.255.255')).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect 169.254.x.x link-local range', () => {
|
||||
expect(isPrivateIP('169.254.169.254')).toBe(true);
|
||||
expect(isPrivateIP('169.254.0.1')).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect 0.0.0.0', () => {
|
||||
expect(isPrivateIP('0.0.0.0')).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow public IPs', () => {
|
||||
expect(isPrivateIP('8.8.8.8')).toBe(false);
|
||||
expect(isPrivateIP('1.1.1.1')).toBe(false);
|
||||
expect(isPrivateIP('93.184.216.34')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('IPv6 private ranges', () => {
|
||||
it('should detect loopback', () => {
|
||||
expect(isPrivateIP('::1')).toBe(true);
|
||||
expect(isPrivateIP('::')).toBe(true);
|
||||
expect(isPrivateIP('[::1]')).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect unique local (fc/fd) and link-local (fe80)', () => {
|
||||
expect(isPrivateIP('fc00::1')).toBe(true);
|
||||
expect(isPrivateIP('fd00::1')).toBe(true);
|
||||
expect(isPrivateIP('fe80::1')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('IPv4-mapped IPv6 addresses', () => {
|
||||
it('should detect private IPs in IPv4-mapped IPv6 form', () => {
|
||||
expect(isPrivateIP('::ffff:169.254.169.254')).toBe(true);
|
||||
expect(isPrivateIP('::ffff:127.0.0.1')).toBe(true);
|
||||
expect(isPrivateIP('::ffff:10.0.0.1')).toBe(true);
|
||||
expect(isPrivateIP('::ffff:192.168.1.1')).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow public IPs in IPv4-mapped IPv6 form', () => {
|
||||
expect(isPrivateIP('::ffff:8.8.8.8')).toBe(false);
|
||||
expect(isPrivateIP('::ffff:93.184.216.34')).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('resolveHostnameSSRF', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should detect domains that resolve to private IPs (nip.io bypass)', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never);
|
||||
expect(await resolveHostnameSSRF('169.254.169.254.nip.io')).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect domains that resolve to loopback', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([{ address: '127.0.0.1', family: 4 }] as never);
|
||||
expect(await resolveHostnameSSRF('loopback.example.com')).toBe(true);
|
||||
});
|
||||
|
||||
it('should detect when any resolved address is private', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([
|
||||
{ address: '93.184.216.34', family: 4 },
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
] as never);
|
||||
expect(await resolveHostnameSSRF('dual.example.com')).toBe(true);
|
||||
});
|
||||
|
||||
it('should allow domains that resolve to public IPs', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never);
|
||||
expect(await resolveHostnameSSRF('example.com')).toBe(false);
|
||||
});
|
||||
|
||||
it('should skip literal IPv4 addresses (handled by isSSRFTarget)', async () => {
|
||||
expect(await resolveHostnameSSRF('169.254.169.254')).toBe(false);
|
||||
expect(mockedLookup).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should skip literal IPv6 addresses', async () => {
|
||||
expect(await resolveHostnameSSRF('::1')).toBe(false);
|
||||
expect(mockedLookup).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should fail open on DNS resolution failure', async () => {
|
||||
mockedLookup.mockRejectedValueOnce(new Error('ENOTFOUND'));
|
||||
expect(await resolveHostnameSSRF('nonexistent.example.com')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isActionDomainAllowed - DNS resolution SSRF protection', () => {
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should block domains resolving to cloud metadata IP (169.254.169.254)', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([{ address: '169.254.169.254', family: 4 }] as never);
|
||||
expect(await isActionDomainAllowed('169.254.169.254.nip.io', null)).toBe(false);
|
||||
});
|
||||
|
||||
it('should block domains resolving to private 10.x range', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([{ address: '10.0.0.5', family: 4 }] as never);
|
||||
expect(await isActionDomainAllowed('internal.attacker.com', null)).toBe(false);
|
||||
});
|
||||
|
||||
it('should block domains resolving to 172.16.x range', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([{ address: '172.16.0.1', family: 4 }] as never);
|
||||
expect(await isActionDomainAllowed('docker.attacker.com', null)).toBe(false);
|
||||
});
|
||||
|
||||
it('should allow domains resolving to public IPs when no allowlist', async () => {
|
||||
mockedLookup.mockResolvedValueOnce([{ address: '93.184.216.34', family: 4 }] as never);
|
||||
expect(await isActionDomainAllowed('example.com', null)).toBe(true);
|
||||
});
|
||||
|
||||
it('should not perform DNS check when allowedDomains is configured', async () => {
|
||||
expect(await isActionDomainAllowed('example.com', ['example.com'])).toBe(true);
|
||||
expect(mockedLookup).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('isActionDomainAllowed', () => {
|
||||
beforeEach(() => {
|
||||
mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never);
|
||||
});
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
|
@ -541,6 +697,9 @@ describe('extractMCPServerDomain', () => {
|
|||
});
|
||||
|
||||
describe('isMCPDomainAllowed', () => {
|
||||
beforeEach(() => {
|
||||
mockedLookup.mockResolvedValue([{ address: '93.184.216.34', family: 4 }] as never);
|
||||
});
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import { lookup } from 'node:dns/promises';
|
||||
|
||||
/**
|
||||
* @param email
|
||||
* @param allowedDomains
|
||||
|
|
@ -22,6 +24,88 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] |
|
|||
return allowedDomains.some((allowedDomain) => allowedDomain?.toLowerCase() === domain);
|
||||
}
|
||||
|
||||
/** Checks if IPv4 octets fall within private, reserved, or link-local ranges */
|
||||
function isPrivateIPv4(a: number, b: number, c: number): boolean {
|
||||
if (a === 127) {
|
||||
return true;
|
||||
}
|
||||
if (a === 10) {
|
||||
return true;
|
||||
}
|
||||
if (a === 172 && b >= 16 && b <= 31) {
|
||||
return true;
|
||||
}
|
||||
if (a === 192 && b === 168) {
|
||||
return true;
|
||||
}
|
||||
if (a === 169 && b === 254) {
|
||||
return true;
|
||||
}
|
||||
if (a === 0 && b === 0 && c === 0) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if an IP address belongs to a private, reserved, or link-local range.
|
||||
* Handles IPv4, IPv6, and IPv4-mapped IPv6 addresses (::ffff:A.B.C.D).
|
||||
*/
|
||||
export function isPrivateIP(ip: string): boolean {
|
||||
const normalized = ip.toLowerCase().trim();
|
||||
|
||||
const mappedMatch = normalized.match(/^::ffff:(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
|
||||
if (mappedMatch) {
|
||||
const [, a, b, c] = mappedMatch.map(Number);
|
||||
return isPrivateIPv4(a, b, c);
|
||||
}
|
||||
|
||||
const ipv4Match = normalized.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
|
||||
if (ipv4Match) {
|
||||
const [, a, b, c] = ipv4Match.map(Number);
|
||||
return isPrivateIPv4(a, b, c);
|
||||
}
|
||||
|
||||
const ipv6 = normalized.replace(/^\[|\]$/g, '');
|
||||
if (
|
||||
ipv6 === '::1' ||
|
||||
ipv6 === '::' ||
|
||||
ipv6.startsWith('fc') ||
|
||||
ipv6.startsWith('fd') ||
|
||||
ipv6.startsWith('fe80')
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a hostname via DNS and checks if any resolved address is a private/reserved IP.
|
||||
* Detects DNS-based SSRF bypasses (e.g., nip.io wildcard DNS, attacker-controlled nameservers).
|
||||
* Fails open: returns false if DNS resolution fails, since hostname-only checks still apply
|
||||
* and the actual HTTP request would also fail.
|
||||
*/
|
||||
export async function resolveHostnameSSRF(hostname: string): Promise<boolean> {
|
||||
const normalizedHost = hostname.toLowerCase().trim();
|
||||
|
||||
if (/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/.test(normalizedHost)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ipv6Check = normalizedHost.replace(/^\[|\]$/g, '');
|
||||
if (ipv6Check.includes(':')) {
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
return addresses.some((entry) => isPrivateIP(entry.address));
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* SSRF Protection: Checks if a hostname/IP is a potentially dangerous internal target.
|
||||
* Blocks private IPs, localhost, cloud metadata IPs, and common internal hostnames.
|
||||
|
|
@ -31,7 +115,6 @@ export function isEmailDomainAllowed(email: string, allowedDomains?: string[] |
|
|||
export function isSSRFTarget(hostname: string): boolean {
|
||||
const normalizedHost = hostname.toLowerCase().trim();
|
||||
|
||||
// Block localhost variations
|
||||
if (
|
||||
normalizedHost === 'localhost' ||
|
||||
normalizedHost === 'localhost.localdomain' ||
|
||||
|
|
@ -40,51 +123,7 @@ export function isSSRFTarget(hostname: string): boolean {
|
|||
return true;
|
||||
}
|
||||
|
||||
// Check if it's an IP address and block private/internal ranges
|
||||
const ipv4Match = normalizedHost.match(/^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$/);
|
||||
if (ipv4Match) {
|
||||
const [, a, b, c] = ipv4Match.map(Number);
|
||||
|
||||
// 127.0.0.0/8 - Loopback
|
||||
if (a === 127) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 10.0.0.0/8 - Private
|
||||
if (a === 10) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 172.16.0.0/12 - Private (172.16.x.x - 172.31.x.x)
|
||||
if (a === 172 && b >= 16 && b <= 31) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 192.168.0.0/16 - Private
|
||||
if (a === 192 && b === 168) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 169.254.0.0/16 - Link-local (includes cloud metadata 169.254.169.254)
|
||||
if (a === 169 && b === 254) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 0.0.0.0 - Special
|
||||
if (a === 0 && b === 0 && c === 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// IPv6 loopback and private ranges
|
||||
const ipv6Normalized = normalizedHost.replace(/^\[|\]$/g, ''); // Remove brackets if present
|
||||
if (
|
||||
ipv6Normalized === '::1' ||
|
||||
ipv6Normalized === '::' ||
|
||||
ipv6Normalized.startsWith('fc') || // fc00::/7 - Unique local
|
||||
ipv6Normalized.startsWith('fd') || // fd00::/8 - Unique local
|
||||
ipv6Normalized.startsWith('fe80') // fe80::/10 - Link-local
|
||||
) {
|
||||
if (isPrivateIP(normalizedHost)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -257,6 +296,10 @@ async function isDomainAllowedCore(
|
|||
if (isSSRFTarget(inputSpec.hostname)) {
|
||||
return false;
|
||||
}
|
||||
/** SECURITY: Resolve hostname and block if it points to a private/reserved IP */
|
||||
if (await resolveHostnameSSRF(inputSpec.hostname)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
export * from './domain';
|
||||
export * from './openid';
|
||||
export * from './exchange';
|
||||
export * from './agent';
|
||||
|
|
|
|||
|
|
@ -215,16 +215,30 @@ describe('cacheConfig', () => {
|
|||
}).rejects.toThrow('Invalid cache keys in FORCED_IN_MEMORY_CACHE_NAMESPACES: INVALID_KEY');
|
||||
});
|
||||
|
||||
test('should handle empty string gracefully', async () => {
|
||||
test('should produce empty array when set to empty string (opt out of defaults)', async () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = '';
|
||||
|
||||
const { cacheConfig } = await import('../cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
|
||||
});
|
||||
|
||||
test('should handle undefined env var gracefully', async () => {
|
||||
test('should default to CONFIG_STORE and APP_CONFIG when env var is not set', async () => {
|
||||
const { cacheConfig } = await import('../cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual([]);
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual(['CONFIG_STORE', 'APP_CONFIG']);
|
||||
});
|
||||
|
||||
test('should accept TOOL_CACHE as a valid namespace', async () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = 'TOOL_CACHE';
|
||||
|
||||
const { cacheConfig } = await import('../cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual(['TOOL_CACHE']);
|
||||
});
|
||||
|
||||
test('should accept CONFIG_STORE and APP_CONFIG together for blue/green deployments', async () => {
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES = 'CONFIG_STORE,APP_CONFIG';
|
||||
|
||||
const { cacheConfig } = await import('../cacheConfig');
|
||||
expect(cacheConfig.FORCED_IN_MEMORY_CACHE_NAMESPACES).toEqual(['CONFIG_STORE', 'APP_CONFIG']);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
135
packages/api/src/cache/__tests__/cacheFactory/standardCache.namespace_isolation.spec.ts
vendored
Normal file
135
packages/api/src/cache/__tests__/cacheFactory/standardCache.namespace_isolation.spec.ts
vendored
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
import { CacheKeys } from 'librechat-data-provider';
|
||||
|
||||
const mockKeyvRedisInstance = {
|
||||
namespace: '',
|
||||
keyPrefixSeparator: '',
|
||||
on: jest.fn(),
|
||||
};
|
||||
|
||||
const MockKeyvRedis = jest.fn().mockReturnValue(mockKeyvRedisInstance);
|
||||
|
||||
jest.mock('@keyv/redis', () => ({
|
||||
default: MockKeyvRedis,
|
||||
}));
|
||||
|
||||
const mockKeyvRedisClient = { scanIterator: jest.fn() };
|
||||
|
||||
jest.mock('../../redisClients', () => ({
|
||||
keyvRedisClient: mockKeyvRedisClient,
|
||||
ioredisClient: null,
|
||||
}));
|
||||
|
||||
jest.mock('../../redisUtils', () => ({
|
||||
batchDeleteKeys: jest.fn(),
|
||||
scanKeys: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('standardCache - CONFIG_STORE vs TOOL_CACHE namespace isolation', () => {
|
||||
afterEach(() => {
|
||||
jest.resetModules();
|
||||
MockKeyvRedis.mockClear();
|
||||
});
|
||||
|
||||
/**
|
||||
* Core behavioral test for blue/green deployments:
|
||||
* When CONFIG_STORE and APP_CONFIG are forced in-memory,
|
||||
* TOOL_CACHE should still use Redis for cross-container sharing.
|
||||
*/
|
||||
it('should force CONFIG_STORE to in-memory while TOOL_CACHE uses Redis', async () => {
|
||||
jest.doMock('../../cacheConfig', () => ({
|
||||
cacheConfig: {
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES: [CacheKeys.CONFIG_STORE, CacheKeys.APP_CONFIG],
|
||||
REDIS_KEY_PREFIX: '',
|
||||
GLOBAL_PREFIX_SEPARATOR: '>>',
|
||||
},
|
||||
}));
|
||||
|
||||
const { standardCache } = await import('../../cacheFactory');
|
||||
|
||||
MockKeyvRedis.mockClear();
|
||||
|
||||
const configCache = standardCache(CacheKeys.CONFIG_STORE);
|
||||
expect(MockKeyvRedis).not.toHaveBeenCalled();
|
||||
expect(configCache).toBeDefined();
|
||||
|
||||
const appConfigCache = standardCache(CacheKeys.APP_CONFIG);
|
||||
expect(MockKeyvRedis).not.toHaveBeenCalled();
|
||||
expect(appConfigCache).toBeDefined();
|
||||
|
||||
const toolCache = standardCache(CacheKeys.TOOL_CACHE);
|
||||
expect(MockKeyvRedis).toHaveBeenCalledTimes(1);
|
||||
expect(MockKeyvRedis).toHaveBeenCalledWith(mockKeyvRedisClient);
|
||||
expect(toolCache).toBeDefined();
|
||||
});
|
||||
|
||||
it('CONFIG_STORE and TOOL_CACHE should be independent stores', async () => {
|
||||
jest.doMock('../../cacheConfig', () => ({
|
||||
cacheConfig: {
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES: [CacheKeys.CONFIG_STORE],
|
||||
REDIS_KEY_PREFIX: '',
|
||||
GLOBAL_PREFIX_SEPARATOR: '>>',
|
||||
},
|
||||
}));
|
||||
|
||||
const { standardCache } = await import('../../cacheFactory');
|
||||
|
||||
const configCache = standardCache(CacheKeys.CONFIG_STORE);
|
||||
const toolCache = standardCache(CacheKeys.TOOL_CACHE);
|
||||
|
||||
await configCache.set('STARTUP_CONFIG', { version: 'v2-green' });
|
||||
await toolCache.set('tools:global', { myTool: { type: 'function' } });
|
||||
|
||||
expect(await configCache.get('STARTUP_CONFIG')).toEqual({ version: 'v2-green' });
|
||||
expect(await configCache.get('tools:global')).toBeUndefined();
|
||||
|
||||
expect(await toolCache.get('STARTUP_CONFIG')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should use Redis for all namespaces when nothing is forced in-memory', async () => {
|
||||
jest.doMock('../../cacheConfig', () => ({
|
||||
cacheConfig: {
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES: [],
|
||||
REDIS_KEY_PREFIX: '',
|
||||
GLOBAL_PREFIX_SEPARATOR: '>>',
|
||||
},
|
||||
}));
|
||||
|
||||
const { standardCache } = await import('../../cacheFactory');
|
||||
|
||||
MockKeyvRedis.mockClear();
|
||||
|
||||
standardCache(CacheKeys.CONFIG_STORE);
|
||||
standardCache(CacheKeys.TOOL_CACHE);
|
||||
standardCache(CacheKeys.APP_CONFIG);
|
||||
|
||||
expect(MockKeyvRedis).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('forcing TOOL_CACHE to in-memory should not affect CONFIG_STORE', async () => {
|
||||
jest.doMock('../../cacheConfig', () => ({
|
||||
cacheConfig: {
|
||||
FORCED_IN_MEMORY_CACHE_NAMESPACES: [CacheKeys.TOOL_CACHE],
|
||||
REDIS_KEY_PREFIX: '',
|
||||
GLOBAL_PREFIX_SEPARATOR: '>>',
|
||||
},
|
||||
}));
|
||||
|
||||
const { standardCache } = await import('../../cacheFactory');
|
||||
|
||||
MockKeyvRedis.mockClear();
|
||||
|
||||
standardCache(CacheKeys.TOOL_CACHE);
|
||||
expect(MockKeyvRedis).not.toHaveBeenCalled();
|
||||
|
||||
standardCache(CacheKeys.CONFIG_STORE);
|
||||
expect(MockKeyvRedis).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
|
@ -20,6 +20,24 @@ interface ViolationData {
|
|||
};
|
||||
}
|
||||
|
||||
/** Waits for both Redis clients (ioredis + keyv/node-redis) to be ready */
|
||||
async function waitForRedisClients() {
|
||||
const redisClients = await import('../../redisClients');
|
||||
const { ioredisClient, keyvRedisClientReady } = redisClients;
|
||||
|
||||
if (ioredisClient && ioredisClient.status !== 'ready') {
|
||||
await new Promise<void>((resolve) => {
|
||||
ioredisClient.once('ready', resolve);
|
||||
});
|
||||
}
|
||||
|
||||
if (keyvRedisClientReady) {
|
||||
await keyvRedisClientReady;
|
||||
}
|
||||
|
||||
return redisClients;
|
||||
}
|
||||
|
||||
describe('violationCache', () => {
|
||||
let originalEnv: NodeJS.ProcessEnv;
|
||||
|
||||
|
|
@ -45,17 +63,9 @@ describe('violationCache', () => {
|
|||
|
||||
test('should create violation cache with Redis when USE_REDIS is true', async () => {
|
||||
const cacheFactory = await import('../../cacheFactory');
|
||||
const redisClients = await import('../../redisClients');
|
||||
const { ioredisClient } = redisClients;
|
||||
await waitForRedisClients();
|
||||
const cache = cacheFactory.violationCache('test-violations', 60000); // 60 second TTL
|
||||
|
||||
// Wait for Redis connection to be ready
|
||||
if (ioredisClient && ioredisClient.status !== 'ready') {
|
||||
await new Promise<void>((resolve) => {
|
||||
ioredisClient.once('ready', resolve);
|
||||
});
|
||||
}
|
||||
|
||||
// Verify it returns a Keyv instance
|
||||
expect(cache).toBeDefined();
|
||||
expect(cache.constructor.name).toBe('Keyv');
|
||||
|
|
@ -112,18 +122,10 @@ describe('violationCache', () => {
|
|||
|
||||
test('should respect namespace prefixing', async () => {
|
||||
const cacheFactory = await import('../../cacheFactory');
|
||||
const redisClients = await import('../../redisClients');
|
||||
const { ioredisClient } = redisClients;
|
||||
await waitForRedisClients();
|
||||
const cache1 = cacheFactory.violationCache('namespace1');
|
||||
const cache2 = cacheFactory.violationCache('namespace2');
|
||||
|
||||
// Wait for Redis connection to be ready
|
||||
if (ioredisClient && ioredisClient.status !== 'ready') {
|
||||
await new Promise<void>((resolve) => {
|
||||
ioredisClient.once('ready', resolve);
|
||||
});
|
||||
}
|
||||
|
||||
const testKey = 'shared-key';
|
||||
const value1: ViolationData = { namespace: 1 };
|
||||
const value2: ViolationData = { namespace: 2 };
|
||||
|
|
@ -146,18 +148,10 @@ describe('violationCache', () => {
|
|||
|
||||
test('should respect TTL settings', async () => {
|
||||
const cacheFactory = await import('../../cacheFactory');
|
||||
const redisClients = await import('../../redisClients');
|
||||
const { ioredisClient } = redisClients;
|
||||
await waitForRedisClients();
|
||||
const ttl = 1000; // 1 second TTL
|
||||
const cache = cacheFactory.violationCache('ttl-test', ttl);
|
||||
|
||||
// Wait for Redis connection to be ready
|
||||
if (ioredisClient && ioredisClient.status !== 'ready') {
|
||||
await new Promise<void>((resolve) => {
|
||||
ioredisClient.once('ready', resolve);
|
||||
});
|
||||
}
|
||||
|
||||
const testKey = 'ttl-key';
|
||||
const testValue: ViolationData = { data: 'expires soon' };
|
||||
|
||||
|
|
@ -178,17 +172,9 @@ describe('violationCache', () => {
|
|||
|
||||
test('should handle complex violation data structures', async () => {
|
||||
const cacheFactory = await import('../../cacheFactory');
|
||||
const redisClients = await import('../../redisClients');
|
||||
const { ioredisClient } = redisClients;
|
||||
await waitForRedisClients();
|
||||
const cache = cacheFactory.violationCache('complex-violations');
|
||||
|
||||
// Wait for Redis connection to be ready
|
||||
if (ioredisClient && ioredisClient.status !== 'ready') {
|
||||
await new Promise<void>((resolve) => {
|
||||
ioredisClient.once('ready', resolve);
|
||||
});
|
||||
}
|
||||
|
||||
const complexData: ViolationData = {
|
||||
userId: 'user123',
|
||||
violations: [
|
||||
|
|
|
|||
11
packages/api/src/cache/cacheConfig.ts
vendored
11
packages/api/src/cache/cacheConfig.ts
vendored
|
|
@ -27,9 +27,14 @@ const USE_REDIS_STREAMS =
|
|||
|
||||
// Comma-separated list of cache namespaces that should be forced to use in-memory storage
|
||||
// even when Redis is enabled. This allows selective performance optimization for specific caches.
|
||||
const FORCED_IN_MEMORY_CACHE_NAMESPACES = process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES
|
||||
? process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES.split(',').map((key) => key.trim())
|
||||
: [];
|
||||
// Defaults to CONFIG_STORE,APP_CONFIG so YAML-derived config stays per-container.
|
||||
// Set to empty string to force all namespaces through Redis.
|
||||
const FORCED_IN_MEMORY_CACHE_NAMESPACES =
|
||||
process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES !== undefined
|
||||
? process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES.split(',')
|
||||
.map((key) => key.trim())
|
||||
.filter(Boolean)
|
||||
: [CacheKeys.CONFIG_STORE, CacheKeys.APP_CONFIG];
|
||||
|
||||
// Validate against CacheKeys enum
|
||||
if (FORCED_IN_MEMORY_CACHE_NAMESPACES.length > 0) {
|
||||
|
|
|
|||
4
packages/api/src/cache/cacheFactory.ts
vendored
4
packages/api/src/cache/cacheFactory.ts
vendored
|
|
@ -120,7 +120,9 @@ export const limiterCache = (prefix: string): RedisStore | undefined => {
|
|||
if (!cacheConfig.USE_REDIS) {
|
||||
return undefined;
|
||||
}
|
||||
// TODO: The prefix is not actually applied. Also needs to account for global prefix.
|
||||
// Note: The `prefix` is applied by RedisStore internally to its key operations.
|
||||
// The global REDIS_KEY_PREFIX is applied by ioredisClient's keyPrefix setting.
|
||||
// Combined key format: `{REDIS_KEY_PREFIX}::{prefix}{identifier}`
|
||||
prefix = prefix.endsWith(':') ? prefix : `${prefix}:`;
|
||||
|
||||
try {
|
||||
|
|
|
|||
12
packages/api/src/cache/redisClients.ts
vendored
12
packages/api/src/cache/redisClients.ts
vendored
|
|
@ -29,7 +29,9 @@ if (cacheConfig.USE_REDIS) {
|
|||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 50, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
const base = Math.min(Math.pow(2, times) * 50, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
const jitter = Math.floor(Math.random() * Math.min(base, 1000));
|
||||
const delay = Math.min(base + jitter, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
|
|
@ -71,7 +73,9 @@ if (cacheConfig.USE_REDIS) {
|
|||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
const base = Math.min(Math.pow(2, times) * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
const jitter = Math.floor(Math.random() * Math.min(base, 1000));
|
||||
const delay = Math.min(base + jitter, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
|
|
@ -149,7 +153,9 @@ if (cacheConfig.USE_REDIS) {
|
|||
);
|
||||
return new Error('Max reconnection attempts reached');
|
||||
}
|
||||
const delay = Math.min(retries * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
const base = Math.min(Math.pow(2, retries) * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
const jitter = Math.floor(Math.random() * Math.min(base, 1000));
|
||||
const delay = Math.min(base + jitter, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`@keyv/redis reconnecting... attempt ${retries}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ function getClaudeHeaders(
|
|||
|
||||
/**
|
||||
* Configures reasoning-related options for Claude models.
|
||||
* Models supporting adaptive thinking (Opus 4.6+, Sonnet 5+) use effort control instead of manual budget_tokens.
|
||||
* Models supporting adaptive thinking (Opus 4.6+, Sonnet 4.6+) use effort control instead of manual budget_tokens.
|
||||
*/
|
||||
function configureReasoning(
|
||||
anthropicInput: AnthropicClientOptions & { max_tokens?: number },
|
||||
|
|
|
|||
|
|
@ -121,6 +121,39 @@ describe('getLLMConfig', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('should add "context-1m" beta header for claude-sonnet-4-6 model', () => {
|
||||
const modelOptions = {
|
||||
model: 'claude-sonnet-4-6',
|
||||
promptCache: true,
|
||||
};
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions?.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions?.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
const defaultHeaders = clientOptions?.defaultHeaders as Record<string, string>;
|
||||
expect(defaultHeaders['anthropic-beta']).toBe('context-1m-2025-08-07');
|
||||
expect(result.llmConfig.promptCache).toBe(true);
|
||||
});
|
||||
|
||||
it('should add "context-1m" beta header for claude-sonnet-4-6 model formats', () => {
|
||||
const modelVariations = [
|
||||
'claude-sonnet-4-6',
|
||||
'claude-sonnet-4-6-20260101',
|
||||
'anthropic/claude-sonnet-4-6',
|
||||
];
|
||||
|
||||
modelVariations.forEach((model) => {
|
||||
const modelOptions = { model, promptCache: true };
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions?.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions?.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
const defaultHeaders = clientOptions?.defaultHeaders as Record<string, string>;
|
||||
expect(defaultHeaders['anthropic-beta']).toBe('context-1m-2025-08-07');
|
||||
expect(result.llmConfig.promptCache).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass promptCache boolean for claude-opus-4-5 model (no beta header needed)', () => {
|
||||
const modelOptions = {
|
||||
model: 'claude-opus-4-5',
|
||||
|
|
@ -963,6 +996,51 @@ describe('getLLMConfig', () => {
|
|||
});
|
||||
});
|
||||
|
||||
it('should use adaptive thinking for Sonnet 4.6 instead of enabled + budget_tokens', () => {
|
||||
const result = getLLMConfig('test-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-sonnet-4-6',
|
||||
thinking: true,
|
||||
thinkingBudget: 10000,
|
||||
},
|
||||
});
|
||||
|
||||
expect((result.llmConfig.thinking as unknown as { type: string }).type).toBe('adaptive');
|
||||
expect(result.llmConfig.thinking).not.toHaveProperty('budget_tokens');
|
||||
expect(result.llmConfig.maxTokens).toBe(64000);
|
||||
});
|
||||
|
||||
it('should set effort via output_config for Sonnet 4.6', () => {
|
||||
const result = getLLMConfig('test-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-sonnet-4-6',
|
||||
thinking: true,
|
||||
effort: AnthropicEffort.high,
|
||||
},
|
||||
});
|
||||
|
||||
expect((result.llmConfig.thinking as unknown as { type: string }).type).toBe('adaptive');
|
||||
expect(result.llmConfig.invocationKwargs).toHaveProperty('output_config');
|
||||
expect(result.llmConfig.invocationKwargs?.output_config).toEqual({
|
||||
effort: AnthropicEffort.high,
|
||||
});
|
||||
});
|
||||
|
||||
it('should exclude topP/topK for Sonnet 4.6 with adaptive thinking', () => {
|
||||
const result = getLLMConfig('test-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-sonnet-4-6',
|
||||
thinking: true,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
},
|
||||
});
|
||||
|
||||
expect((result.llmConfig.thinking as unknown as { type: string }).type).toBe('adaptive');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
expect(result.llmConfig).not.toHaveProperty('topK');
|
||||
});
|
||||
|
||||
it('should NOT set adaptive thinking or effort for non-adaptive models', () => {
|
||||
const nonAdaptiveModels = [
|
||||
'claude-opus-4-5',
|
||||
|
|
|
|||
|
|
@ -5,5 +5,6 @@ export * from './filter';
|
|||
export * from './mistral/crud';
|
||||
export * from './ocr';
|
||||
export * from './parse';
|
||||
export * from './rag';
|
||||
export * from './validation';
|
||||
export * from './text';
|
||||
|
|
|
|||
150
packages/api/src/files/rag.spec.ts
Normal file
150
packages/api/src/files/rag.spec.ts
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/crypto/jwt', () => ({
|
||||
generateShortLivedToken: jest.fn().mockReturnValue('mock-jwt-token'),
|
||||
}));
|
||||
|
||||
jest.mock('axios', () => ({
|
||||
delete: jest.fn(),
|
||||
interceptors: {
|
||||
request: { use: jest.fn(), eject: jest.fn() },
|
||||
response: { use: jest.fn(), eject: jest.fn() },
|
||||
},
|
||||
}));
|
||||
|
||||
import axios from 'axios';
|
||||
import { deleteRagFile } from './rag';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { generateShortLivedToken } from '~/crypto/jwt';
|
||||
|
||||
const mockedAxios = axios as jest.Mocked<typeof axios>;
|
||||
const mockedLogger = logger as jest.Mocked<typeof logger>;
|
||||
const mockedGenerateShortLivedToken = generateShortLivedToken as jest.MockedFunction<
|
||||
typeof generateShortLivedToken
|
||||
>;
|
||||
|
||||
describe('deleteRagFile', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
process.env = { ...originalEnv };
|
||||
process.env.RAG_API_URL = 'http://localhost:8000';
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
describe('when file is embedded and RAG_API_URL is configured', () => {
|
||||
it('should delete the document from RAG API successfully', async () => {
|
||||
const file = { file_id: 'file-123', embedded: true };
|
||||
mockedAxios.delete.mockResolvedValueOnce({ status: 200 });
|
||||
|
||||
const result = await deleteRagFile({ userId: 'user123', file });
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockedGenerateShortLivedToken).toHaveBeenCalledWith('user123');
|
||||
expect(mockedAxios.delete).toHaveBeenCalledWith('http://localhost:8000/documents', {
|
||||
headers: {
|
||||
Authorization: 'Bearer mock-jwt-token',
|
||||
'Content-Type': 'application/json',
|
||||
accept: 'application/json',
|
||||
},
|
||||
data: ['file-123'],
|
||||
});
|
||||
expect(mockedLogger.debug).toHaveBeenCalledWith(
|
||||
'[deleteRagFile] Successfully deleted document file-123 from RAG API',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return true and log warning when document is not found (404)', async () => {
|
||||
const file = { file_id: 'file-not-found', embedded: true };
|
||||
const error = new Error('Not Found') as Error & { response?: { status?: number } };
|
||||
error.response = { status: 404 };
|
||||
mockedAxios.delete.mockRejectedValueOnce(error);
|
||||
|
||||
const result = await deleteRagFile({ userId: 'user123', file });
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockedLogger.warn).toHaveBeenCalledWith(
|
||||
'[deleteRagFile] Document file-not-found not found in RAG API, may have been deleted already',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false and log error on other errors', async () => {
|
||||
const file = { file_id: 'file-error', embedded: true };
|
||||
const error = new Error('Server Error') as Error & { response?: { status?: number } };
|
||||
error.response = { status: 500 };
|
||||
mockedAxios.delete.mockRejectedValueOnce(error);
|
||||
|
||||
const result = await deleteRagFile({ userId: 'user123', file });
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockedLogger.error).toHaveBeenCalledWith(
|
||||
'[deleteRagFile] Error deleting document from RAG API:',
|
||||
'Server Error',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('when file is not embedded', () => {
|
||||
it('should skip RAG deletion and return true', async () => {
|
||||
const file = { file_id: 'file-123', embedded: false };
|
||||
|
||||
const result = await deleteRagFile({ userId: 'user123', file });
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockedAxios.delete).not.toHaveBeenCalled();
|
||||
expect(mockedGenerateShortLivedToken).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should skip RAG deletion when embedded is undefined', async () => {
|
||||
const file = { file_id: 'file-123' };
|
||||
|
||||
const result = await deleteRagFile({ userId: 'user123', file });
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockedAxios.delete).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when RAG_API_URL is not configured', () => {
|
||||
it('should skip RAG deletion and return true', async () => {
|
||||
delete process.env.RAG_API_URL;
|
||||
const file = { file_id: 'file-123', embedded: true };
|
||||
|
||||
const result = await deleteRagFile({ userId: 'user123', file });
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockedAxios.delete).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('userId handling', () => {
|
||||
it('should return false when no userId is provided', async () => {
|
||||
const file = { file_id: 'file-123', embedded: true };
|
||||
|
||||
const result = await deleteRagFile({ userId: '', file });
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockedLogger.error).toHaveBeenCalledWith('[deleteRagFile] No user ID provided');
|
||||
expect(mockedAxios.delete).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return false when userId is undefined', async () => {
|
||||
const file = { file_id: 'file-123', embedded: true };
|
||||
|
||||
const result = await deleteRagFile({ userId: undefined as unknown as string, file });
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockedLogger.error).toHaveBeenCalledWith('[deleteRagFile] No user ID provided');
|
||||
});
|
||||
});
|
||||
});
|
||||
60
packages/api/src/files/rag.ts
Normal file
60
packages/api/src/files/rag.ts
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import axios from 'axios';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { generateShortLivedToken } from '~/crypto/jwt';
|
||||
|
||||
interface DeleteRagFileParams {
|
||||
/** The user ID. Required for authentication. If not provided, the function returns false and logs an error. */
|
||||
userId: string;
|
||||
/** The file object. Must have `embedded` and `file_id` properties. */
|
||||
file: {
|
||||
file_id: string;
|
||||
embedded?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Deletes embedded document(s) from the RAG API.
|
||||
* This is a shared utility function used by all file storage strategies
|
||||
* (S3, Azure, Firebase, Local) to delete RAG embeddings when a file is deleted.
|
||||
*
|
||||
* @param params - The parameters object.
|
||||
* @param params.userId - The user ID for authentication.
|
||||
* @param params.file - The file object. Must have `embedded` and `file_id` properties.
|
||||
* @returns Returns true if deletion was successful or skipped, false if there was an error.
|
||||
*/
|
||||
export async function deleteRagFile({ userId, file }: DeleteRagFileParams): Promise<boolean> {
|
||||
if (!file.embedded || !process.env.RAG_API_URL) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!userId) {
|
||||
logger.error('[deleteRagFile] No user ID provided');
|
||||
return false;
|
||||
}
|
||||
|
||||
const jwtToken = generateShortLivedToken(userId);
|
||||
|
||||
try {
|
||||
await axios.delete(`${process.env.RAG_API_URL}/documents`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${jwtToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
accept: 'application/json',
|
||||
},
|
||||
data: [file.file_id],
|
||||
});
|
||||
logger.debug(`[deleteRagFile] Successfully deleted document ${file.file_id} from RAG API`);
|
||||
return true;
|
||||
} catch (error) {
|
||||
const axiosError = error as { response?: { status?: number }; message?: string };
|
||||
if (axiosError.response?.status === 404) {
|
||||
logger.warn(
|
||||
`[deleteRagFile] Document ${file.file_id} not found in RAG API, may have been deleted already`,
|
||||
);
|
||||
return true;
|
||||
} else {
|
||||
logger.error('[deleteRagFile] Error deleting document from RAG API:', axiosError.message);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4,6 +4,8 @@ import { MCPConnection } from './connection';
|
|||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import type * as t from './types';
|
||||
|
||||
const CONNECT_CONCURRENCY = 3;
|
||||
|
||||
/**
|
||||
* Manages MCP connections with lazy loading and reconnection.
|
||||
* Maintains a pool of connections and handles connection lifecycle management.
|
||||
|
|
@ -73,6 +75,7 @@ export class ConnectionsRepository {
|
|||
{
|
||||
serverName,
|
||||
serverConfig,
|
||||
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
|
||||
},
|
||||
this.oauthOpts,
|
||||
);
|
||||
|
|
@ -83,9 +86,17 @@ export class ConnectionsRepository {
|
|||
|
||||
/** Gets or creates connections for multiple servers concurrently */
|
||||
async getMany(serverNames: string[]): Promise<Map<string, MCPConnection>> {
|
||||
const connectionPromises = serverNames.map(async (name) => [name, await this.get(name)]);
|
||||
const connections = await Promise.all(connectionPromises);
|
||||
return new Map((connections as [string, MCPConnection][]).filter((v) => !!v[1]));
|
||||
const results: [string, MCPConnection | null][] = [];
|
||||
for (let i = 0; i < serverNames.length; i += CONNECT_CONCURRENCY) {
|
||||
const batch = serverNames.slice(i, i + CONNECT_CONCURRENCY);
|
||||
const batchResults = await Promise.all(
|
||||
batch.map(
|
||||
async (name): Promise<[string, MCPConnection | null]> => [name, await this.get(name)],
|
||||
),
|
||||
);
|
||||
results.push(...batchResults);
|
||||
}
|
||||
return new Map(results.filter((v): v is [string, MCPConnection] => v[1] != null));
|
||||
}
|
||||
|
||||
/** Returns all currently loaded connections without creating new ones */
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ export class MCPConnectionFactory {
|
|||
protected readonly serverConfig: t.MCPOptions;
|
||||
protected readonly logPrefix: string;
|
||||
protected readonly useOAuth: boolean;
|
||||
protected readonly useSSRFProtection: boolean;
|
||||
|
||||
// OAuth-related properties (only set when useOAuth is true)
|
||||
protected readonly userId?: string;
|
||||
|
|
@ -72,6 +73,7 @@ export class MCPConnectionFactory {
|
|||
serverConfig: this.serverConfig,
|
||||
userId: this.userId,
|
||||
oauthTokens,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
|
||||
const oauthHandler = async () => {
|
||||
|
|
@ -146,6 +148,7 @@ export class MCPConnectionFactory {
|
|||
serverConfig: this.serverConfig,
|
||||
userId: this.userId,
|
||||
oauthTokens: null,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
|
||||
unauthConnection.on('oauthRequired', () => {
|
||||
|
|
@ -189,6 +192,7 @@ export class MCPConnectionFactory {
|
|||
});
|
||||
this.serverName = basic.serverName;
|
||||
this.useOAuth = !!oauth?.useOAuth;
|
||||
this.useSSRFProtection = basic.useSSRFProtection === true;
|
||||
this.connectionTimeout = oauth?.connectionTimeout;
|
||||
this.logPrefix = oauth?.user
|
||||
? `[MCP][${basic.serverName}][${oauth.user.id}]`
|
||||
|
|
@ -213,6 +217,7 @@ export class MCPConnectionFactory {
|
|||
serverConfig: this.serverConfig,
|
||||
userId: this.userId,
|
||||
oauthTokens,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
|
||||
let cleanupOAuthHandlers: (() => void) | null = null;
|
||||
|
|
@ -293,38 +298,45 @@ export class MCPConnectionFactory {
|
|||
const oauthHandler = async (data: { serverUrl?: string }) => {
|
||||
logger.info(`${this.logPrefix} oauthRequired event received`);
|
||||
|
||||
// If we just want to initiate OAuth and return, handle it differently
|
||||
if (this.returnOnOAuth) {
|
||||
try {
|
||||
const config = this.serverConfig;
|
||||
const { authorizationUrl, flowId, flowMetadata } =
|
||||
await MCPOAuthHandler.initiateOAuthFlow(
|
||||
this.serverName,
|
||||
data.serverUrl || '',
|
||||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
|
||||
const existingFlow = await this.flowManager!.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow?.status === 'PENDING') {
|
||||
logger.debug(
|
||||
`${this.logPrefix} PENDING OAuth flow already exists, skipping new initiation`,
|
||||
);
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Delete any existing flow state to ensure we start fresh
|
||||
// This prevents stale codeVerifier issues when re-authenticating
|
||||
await this.flowManager!.deleteFlow(flowId, 'mcp_oauth');
|
||||
const {
|
||||
authorizationUrl,
|
||||
flowId: newFlowId,
|
||||
flowMetadata,
|
||||
} = await MCPOAuthHandler.initiateOAuthFlow(
|
||||
this.serverName,
|
||||
data.serverUrl || '',
|
||||
this.userId!,
|
||||
config?.oauth_headers ?? {},
|
||||
config?.oauth,
|
||||
);
|
||||
|
||||
// Create the flow state so the OAuth callback can find it
|
||||
// We spawn this in the background without waiting for it
|
||||
// Pass signal so the flow can be aborted if the request is cancelled
|
||||
this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata, this.signal).catch(() => {
|
||||
// The OAuth callback will resolve this flow, so we expect it to timeout here
|
||||
// or it will be aborted if the request is cancelled - both are fine
|
||||
});
|
||||
if (existingFlow) {
|
||||
await this.flowManager!.deleteFlow(newFlowId, 'mcp_oauth');
|
||||
}
|
||||
|
||||
this.flowManager!.createFlow(newFlowId, 'mcp_oauth', flowMetadata, this.signal).catch(
|
||||
() => {},
|
||||
);
|
||||
|
||||
if (this.oauthStart) {
|
||||
logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
|
||||
await this.oauthStart(authorizationUrl);
|
||||
}
|
||||
|
||||
// Emit oauthFailed to signal that connection should not proceed
|
||||
// but OAuth was successfully initiated
|
||||
connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
|
||||
return;
|
||||
} catch (error) {
|
||||
|
|
@ -386,11 +398,9 @@ export class MCPConnectionFactory {
|
|||
logger.error(`${this.logPrefix} Failed to establish connection.`);
|
||||
}
|
||||
|
||||
// Handles connection attempts with retry logic and OAuth error handling
|
||||
private async connectTo(connection: MCPConnection): Promise<void> {
|
||||
const maxAttempts = 3;
|
||||
let attempts = 0;
|
||||
let oauthHandled = false;
|
||||
|
||||
while (attempts < maxAttempts) {
|
||||
try {
|
||||
|
|
@ -403,22 +413,6 @@ export class MCPConnectionFactory {
|
|||
attempts++;
|
||||
|
||||
if (this.useOAuth && this.isOAuthError(error)) {
|
||||
// For returnOnOAuth mode, let the event handler (handleOAuthEvents) deal with OAuth
|
||||
// We just need to stop retrying and let the error propagate
|
||||
if (this.returnOnOAuth) {
|
||||
logger.info(
|
||||
`${this.logPrefix} OAuth required (return on OAuth mode), stopping retries`,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Normal flow - wait for OAuth to complete
|
||||
if (this.oauthStart && !oauthHandled) {
|
||||
oauthHandled = true;
|
||||
logger.info(`${this.logPrefix} Handling OAuth`);
|
||||
await this.handleOAuthRequired();
|
||||
}
|
||||
// Don't retry on OAuth errors - just throw
|
||||
logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
|
||||
throw error;
|
||||
}
|
||||
|
|
@ -494,26 +488,15 @@ export class MCPConnectionFactory {
|
|||
/** Check if there's already an ongoing OAuth flow for this flowId */
|
||||
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
// If any flow exists (PENDING, COMPLETED, FAILED), cancel it and start fresh
|
||||
// This ensures the user always gets a new OAuth URL instead of waiting for stale flows
|
||||
if (existingFlow) {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cancelling to start fresh`,
|
||||
`${this.logPrefix} Found existing OAuth flow (status: ${existingFlow.status}), cleaning up to start fresh`,
|
||||
);
|
||||
try {
|
||||
if (existingFlow.status === 'PENDING') {
|
||||
await this.flowManager.failFlow(
|
||||
flowId,
|
||||
'mcp_oauth',
|
||||
new Error('Cancelled for new OAuth request'),
|
||||
);
|
||||
} else {
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
}
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
} catch (error) {
|
||||
logger.warn(`${this.logPrefix} Failed to cancel existing OAuth flow`, error);
|
||||
logger.warn(`${this.logPrefix} Failed to clean up existing OAuth flow`, error);
|
||||
}
|
||||
// Continue to start a new flow below
|
||||
}
|
||||
|
||||
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
|
||||
|
|
|
|||
|
|
@ -102,7 +102,8 @@ export class MCPManager extends UserConnectionManager {
|
|||
serverConfig.requiresOAuth || (serverConfig as t.ParsedServerConfig).oauthMetadata,
|
||||
);
|
||||
|
||||
const basic: t.BasicConnectionOptions = { serverName, serverConfig };
|
||||
const useSSRFProtection = MCPServersRegistry.getInstance().shouldEnableSSRFProtection();
|
||||
const basic: t.BasicConnectionOptions = { serverName, serverConfig, useSSRFProtection };
|
||||
|
||||
if (!useOAuth) {
|
||||
const result = await MCPConnectionFactory.discoverTools(basic);
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ export abstract class UserConnectionManager {
|
|||
{
|
||||
serverName: serverName,
|
||||
serverConfig: config,
|
||||
useSSRFProtection: MCPServersRegistry.getInstance().shouldEnableSSRFProtection(),
|
||||
},
|
||||
{
|
||||
useOAuth: true,
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ jest.mock('../connection');
|
|||
const mockRegistryInstance = {
|
||||
getServerConfig: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
};
|
||||
|
||||
jest.mock('../registry/MCPServersRegistry', () => ({
|
||||
|
|
@ -108,6 +109,7 @@ describe('ConnectionsRepository', () => {
|
|||
{
|
||||
serverName: 'server1',
|
||||
serverConfig: mockServerConfigs.server1,
|
||||
useSSRFProtection: false,
|
||||
},
|
||||
undefined,
|
||||
);
|
||||
|
|
@ -129,6 +131,7 @@ describe('ConnectionsRepository', () => {
|
|||
{
|
||||
serverName: 'server1',
|
||||
serverConfig: mockServerConfigs.server1,
|
||||
useSSRFProtection: false,
|
||||
},
|
||||
undefined,
|
||||
);
|
||||
|
|
@ -167,6 +170,7 @@ describe('ConnectionsRepository', () => {
|
|||
{
|
||||
serverName: 'server1',
|
||||
serverConfig: configWithCachedAt,
|
||||
useSSRFProtection: false,
|
||||
},
|
||||
undefined,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -84,6 +84,7 @@ describe('MCPConnectionFactory', () => {
|
|||
serverConfig: mockServerConfig,
|
||||
userId: undefined,
|
||||
oauthTokens: null,
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
expect(mockConnectionInstance.connect).toHaveBeenCalled();
|
||||
});
|
||||
|
|
@ -125,6 +126,7 @@ describe('MCPConnectionFactory', () => {
|
|||
serverConfig: mockServerConfig,
|
||||
userId: 'user123',
|
||||
oauthTokens: mockTokens,
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -184,6 +186,7 @@ describe('MCPConnectionFactory', () => {
|
|||
serverConfig: mockServerConfig,
|
||||
userId: 'user123',
|
||||
oauthTokens: null,
|
||||
useSSRFProtection: false,
|
||||
});
|
||||
expect(mockLogger.debug).toHaveBeenCalledWith(
|
||||
expect.stringContaining('No existing tokens found or error loading tokens'),
|
||||
|
|
@ -267,7 +270,54 @@ describe('MCPConnectionFactory', () => {
|
|||
);
|
||||
});
|
||||
|
||||
it('should delete existing flow before creating new OAuth flow to prevent stale codeVerifier', async () => {
|
||||
it('should skip new OAuth flow initiation when a PENDING flow already exists (returnOnOAuth)', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: mockServerConfig,
|
||||
user: mockUser,
|
||||
};
|
||||
|
||||
const oauthOptions: t.OAuthConnectionOptions = {
|
||||
user: mockUser,
|
||||
useOAuth: true,
|
||||
returnOnOAuth: true,
|
||||
oauthStart: jest.fn(),
|
||||
flowManager: mockFlowManager,
|
||||
};
|
||||
|
||||
mockFlowManager.getFlowState.mockResolvedValue({
|
||||
status: 'PENDING',
|
||||
type: 'mcp_oauth',
|
||||
metadata: { codeVerifier: 'existing-verifier' },
|
||||
createdAt: Date.now(),
|
||||
});
|
||||
mockConnectionInstance.isConnected.mockResolvedValue(false);
|
||||
|
||||
let oauthRequiredHandler: (data: Record<string, unknown>) => Promise<void>;
|
||||
mockConnectionInstance.on.mockImplementation((event, handler) => {
|
||||
if (event === 'oauthRequired') {
|
||||
oauthRequiredHandler = handler as (data: Record<string, unknown>) => Promise<void>;
|
||||
}
|
||||
return mockConnectionInstance;
|
||||
});
|
||||
|
||||
try {
|
||||
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||
} catch {
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||
|
||||
expect(mockMCPOAuthHandler.initiateOAuthFlow).not.toHaveBeenCalled();
|
||||
expect(mockFlowManager.deleteFlow).not.toHaveBeenCalled();
|
||||
expect(mockConnectionInstance.emit).toHaveBeenCalledWith(
|
||||
'oauthFailed',
|
||||
expect.objectContaining({ message: 'OAuth flow initiated - return early' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should delete stale flow and create new OAuth flow when existing flow is COMPLETED', async () => {
|
||||
const basicOptions = {
|
||||
serverName: 'test-server',
|
||||
serverConfig: mockServerConfig,
|
||||
|
|
@ -300,6 +350,12 @@ describe('MCPConnectionFactory', () => {
|
|||
},
|
||||
};
|
||||
|
||||
mockFlowManager.getFlowState.mockResolvedValue({
|
||||
status: 'COMPLETED',
|
||||
type: 'mcp_oauth',
|
||||
metadata: { codeVerifier: 'old-verifier' },
|
||||
createdAt: Date.now() - 60000,
|
||||
});
|
||||
mockMCPOAuthHandler.initiateOAuthFlow.mockResolvedValue(mockFlowData);
|
||||
mockFlowManager.deleteFlow.mockResolvedValue(true);
|
||||
mockFlowManager.createFlow.mockRejectedValue(new Error('Timeout expected'));
|
||||
|
|
@ -316,21 +372,17 @@ describe('MCPConnectionFactory', () => {
|
|||
try {
|
||||
await MCPConnectionFactory.create(basicOptions, oauthOptions);
|
||||
} catch {
|
||||
// Expected to fail due to connection not established
|
||||
// Expected to fail
|
||||
}
|
||||
|
||||
await oauthRequiredHandler!({ serverUrl: 'https://api.example.com' });
|
||||
|
||||
// Verify deleteFlow was called with correct parameters
|
||||
expect(mockFlowManager.deleteFlow).toHaveBeenCalledWith('user123:test-server', 'mcp_oauth');
|
||||
|
||||
// Verify deleteFlow was called before createFlow
|
||||
const deleteCallOrder = mockFlowManager.deleteFlow.mock.invocationCallOrder[0];
|
||||
const createCallOrder = mockFlowManager.createFlow.mock.invocationCallOrder[0];
|
||||
expect(deleteCallOrder).toBeLessThan(createCallOrder);
|
||||
|
||||
// Verify createFlow was called with fresh metadata
|
||||
// 4th arg is the abort signal (undefined in this test since no signal was provided)
|
||||
expect(mockFlowManager.createFlow).toHaveBeenCalledWith(
|
||||
'user123:test-server',
|
||||
'mcp_oauth',
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ const mockRegistryInstance = {
|
|||
getServerConfig: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
shouldEnableSSRFProtection: jest.fn().mockReturnValue(false),
|
||||
};
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({
|
||||
|
|
|
|||
|
|
@ -2,7 +2,12 @@
|
|||
// zod.spec.ts
|
||||
import { z } from 'zod';
|
||||
import type { JsonSchemaType } from '@librechat/data-schemas';
|
||||
import { resolveJsonSchemaRefs, convertJsonSchemaToZod, convertWithResolvedRefs } from '../zod';
|
||||
import {
|
||||
convertWithResolvedRefs,
|
||||
convertJsonSchemaToZod,
|
||||
resolveJsonSchemaRefs,
|
||||
normalizeJsonSchema,
|
||||
} from '../zod';
|
||||
|
||||
describe('convertJsonSchemaToZod', () => {
|
||||
describe('integer type handling', () => {
|
||||
|
|
@ -206,7 +211,7 @@ describe('convertJsonSchemaToZod', () => {
|
|||
type: 'number' as const,
|
||||
enum: [1, 2, 3, 5, 8, 13],
|
||||
};
|
||||
const zodSchema = convertWithResolvedRefs(schema as JsonSchemaType);
|
||||
const zodSchema = convertWithResolvedRefs(schema as unknown as JsonSchemaType);
|
||||
|
||||
expect(zodSchema?.parse(1)).toBe(1);
|
||||
expect(zodSchema?.parse(13)).toBe(13);
|
||||
|
|
@ -1599,6 +1604,34 @@ describe('convertJsonSchemaToZod', () => {
|
|||
expect(() => zodSchema?.parse(testData)).not.toThrow();
|
||||
});
|
||||
|
||||
it('should strip $defs from the resolved output', () => {
|
||||
const schemaWithDefs = {
|
||||
type: 'object' as const,
|
||||
properties: {
|
||||
item: { $ref: '#/$defs/Item' },
|
||||
},
|
||||
$defs: {
|
||||
Item: {
|
||||
type: 'object' as const,
|
||||
properties: {
|
||||
name: { type: 'string' as const },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const resolved = resolveJsonSchemaRefs(schemaWithDefs);
|
||||
// $defs should NOT be in the output — it was only used for resolution
|
||||
expect(resolved).not.toHaveProperty('$defs');
|
||||
// The $ref should be resolved inline
|
||||
expect(resolved.properties?.item).toEqual({
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle various edge cases safely', () => {
|
||||
// Test with null/undefined
|
||||
expect(resolveJsonSchemaRefs(null as any)).toBeNull();
|
||||
|
|
@ -2002,3 +2035,329 @@ describe('convertJsonSchemaToZod', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('normalizeJsonSchema', () => {
|
||||
it('should convert const to enum', () => {
|
||||
const schema = { type: 'string', const: 'hello' } as any;
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result).toEqual({ type: 'string', enum: ['hello'] });
|
||||
expect(result).not.toHaveProperty('const');
|
||||
});
|
||||
|
||||
it('should preserve existing enum when const is also present', () => {
|
||||
const schema = { type: 'string', const: 'hello', enum: ['hello', 'world'] } as any;
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result).toEqual({ type: 'string', enum: ['hello', 'world'] });
|
||||
expect(result).not.toHaveProperty('const');
|
||||
});
|
||||
|
||||
it('should handle non-string const values (number, boolean, null)', () => {
|
||||
expect(normalizeJsonSchema({ type: 'number', const: 42 } as any)).toEqual({
|
||||
type: 'number',
|
||||
enum: [42],
|
||||
});
|
||||
expect(normalizeJsonSchema({ type: 'boolean', const: true } as any)).toEqual({
|
||||
type: 'boolean',
|
||||
enum: [true],
|
||||
});
|
||||
expect(normalizeJsonSchema({ type: 'string', const: null } as any)).toEqual({
|
||||
type: 'string',
|
||||
enum: [null],
|
||||
});
|
||||
});
|
||||
|
||||
it('should recursively normalize nested object properties', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
mode: { type: 'string', const: 'advanced' },
|
||||
count: { type: 'number', const: 5 },
|
||||
name: { type: 'string', description: 'A name' },
|
||||
},
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result.properties.mode).toEqual({ type: 'string', enum: ['advanced'] });
|
||||
expect(result.properties.count).toEqual({ type: 'number', enum: [5] });
|
||||
expect(result.properties.name).toEqual({ type: 'string', description: 'A name' });
|
||||
});
|
||||
|
||||
it('should normalize inside oneOf/anyOf/allOf arrays', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
oneOf: [
|
||||
{ type: 'object', properties: { kind: { type: 'string', const: 'A' } } },
|
||||
{ type: 'object', properties: { kind: { type: 'string', const: 'B' } } },
|
||||
],
|
||||
anyOf: [{ type: 'string', const: 'x' }],
|
||||
allOf: [{ type: 'number', const: 1 }],
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result.oneOf[0].properties.kind).toEqual({ type: 'string', enum: ['A'] });
|
||||
expect(result.oneOf[1].properties.kind).toEqual({ type: 'string', enum: ['B'] });
|
||||
expect(result.anyOf[0]).toEqual({ type: 'string', enum: ['x'] });
|
||||
expect(result.allOf[0]).toEqual({ type: 'number', enum: [1] });
|
||||
});
|
||||
|
||||
it('should normalize array items with const', () => {
|
||||
const schema = {
|
||||
type: 'array',
|
||||
items: { type: 'string', const: 'fixed' },
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result.items).toEqual({ type: 'string', enum: ['fixed'] });
|
||||
});
|
||||
|
||||
it('should normalize additionalProperties with const', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
additionalProperties: { type: 'string', const: 'val' },
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result.additionalProperties).toEqual({ type: 'string', enum: ['val'] });
|
||||
});
|
||||
|
||||
it('should handle null, undefined, and primitive inputs safely', () => {
|
||||
expect(normalizeJsonSchema(null as any)).toBeNull();
|
||||
expect(normalizeJsonSchema(undefined as any)).toBeUndefined();
|
||||
expect(normalizeJsonSchema('string' as any)).toBe('string');
|
||||
expect(normalizeJsonSchema(42 as any)).toBe(42);
|
||||
expect(normalizeJsonSchema(true as any)).toBe(true);
|
||||
});
|
||||
|
||||
it('should be a no-op when no const is present', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string', description: 'Name' },
|
||||
age: { type: 'number' },
|
||||
tags: { type: 'array', items: { type: 'string' } },
|
||||
},
|
||||
required: ['name'],
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result).toEqual(schema);
|
||||
});
|
||||
|
||||
it('should handle a Tavily-like schema pattern with const', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: {
|
||||
type: 'string',
|
||||
description: 'The search query',
|
||||
},
|
||||
search_depth: {
|
||||
type: 'string',
|
||||
const: 'advanced',
|
||||
description: 'The depth of the search',
|
||||
},
|
||||
topic: {
|
||||
type: 'string',
|
||||
enum: ['general', 'news'],
|
||||
description: 'The search topic',
|
||||
},
|
||||
include_answer: {
|
||||
type: 'boolean',
|
||||
const: true,
|
||||
},
|
||||
max_results: {
|
||||
type: 'number',
|
||||
const: 5,
|
||||
},
|
||||
},
|
||||
required: ['query'],
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
|
||||
// const fields should be converted to enum
|
||||
expect(result.properties.search_depth).toEqual({
|
||||
type: 'string',
|
||||
enum: ['advanced'],
|
||||
description: 'The depth of the search',
|
||||
});
|
||||
expect(result.properties.include_answer).toEqual({
|
||||
type: 'boolean',
|
||||
enum: [true],
|
||||
});
|
||||
expect(result.properties.max_results).toEqual({
|
||||
type: 'number',
|
||||
enum: [5],
|
||||
});
|
||||
|
||||
// Existing enum should be preserved
|
||||
expect(result.properties.topic).toEqual({
|
||||
type: 'string',
|
||||
enum: ['general', 'news'],
|
||||
description: 'The search topic',
|
||||
});
|
||||
|
||||
// Non-const fields should be unchanged
|
||||
expect(result.properties.query).toEqual({
|
||||
type: 'string',
|
||||
description: 'The search query',
|
||||
});
|
||||
|
||||
// Top-level fields preserved
|
||||
expect(result.required).toEqual(['query']);
|
||||
expect(result.type).toBe('object');
|
||||
});
|
||||
|
||||
it('should handle arrays at the top level', () => {
|
||||
const schemas = [
|
||||
{ type: 'string', const: 'a' },
|
||||
{ type: 'number', const: 1 },
|
||||
] as any;
|
||||
|
||||
const result = normalizeJsonSchema(schemas);
|
||||
expect(result).toEqual([
|
||||
{ type: 'string', enum: ['a'] },
|
||||
{ type: 'number', enum: [1] },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should strip vendor extension fields (x-* prefixed keys)', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
travelMode: {
|
||||
type: 'string',
|
||||
enum: ['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK'],
|
||||
'x-google-enum-descriptions': ['By car', 'By bicycle', 'By public transit', 'By walking'],
|
||||
description: 'Mode of travel',
|
||||
},
|
||||
},
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result.properties.travelMode).toEqual({
|
||||
type: 'string',
|
||||
enum: ['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK'],
|
||||
description: 'Mode of travel',
|
||||
});
|
||||
expect(result.properties.travelMode).not.toHaveProperty('x-google-enum-descriptions');
|
||||
});
|
||||
|
||||
it('should strip x-* fields at all nesting levels', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
'x-custom-root': true,
|
||||
properties: {
|
||||
outer: {
|
||||
type: 'object',
|
||||
'x-custom-outer': 'value',
|
||||
properties: {
|
||||
inner: {
|
||||
type: 'string',
|
||||
'x-custom-inner': 42,
|
||||
},
|
||||
},
|
||||
},
|
||||
arr: {
|
||||
type: 'array',
|
||||
items: {
|
||||
type: 'string',
|
||||
'x-item-meta': 'something',
|
||||
},
|
||||
},
|
||||
},
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result).not.toHaveProperty('x-custom-root');
|
||||
expect(result.properties.outer).not.toHaveProperty('x-custom-outer');
|
||||
expect(result.properties.outer.properties.inner).not.toHaveProperty('x-custom-inner');
|
||||
expect(result.properties.arr.items).not.toHaveProperty('x-item-meta');
|
||||
// Standard fields should be preserved
|
||||
expect(result.type).toBe('object');
|
||||
expect(result.properties.outer.type).toBe('object');
|
||||
expect(result.properties.outer.properties.inner.type).toBe('string');
|
||||
expect(result.properties.arr.items.type).toBe('string');
|
||||
});
|
||||
|
||||
it('should strip $defs and definitions as a safety net', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
},
|
||||
$defs: {
|
||||
SomeType: { type: 'string' },
|
||||
},
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result).not.toHaveProperty('$defs');
|
||||
expect(result.type).toBe('object');
|
||||
expect(result.properties.name).toEqual({ type: 'string' });
|
||||
});
|
||||
|
||||
it('should strip x-* fields inside oneOf/anyOf/allOf', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
oneOf: [
|
||||
{ type: 'string', 'x-meta': 'a' },
|
||||
{ type: 'number', 'x-meta': 'b' },
|
||||
],
|
||||
} as any;
|
||||
|
||||
const result = normalizeJsonSchema(schema);
|
||||
expect(result.oneOf[0]).toEqual({ type: 'string' });
|
||||
expect(result.oneOf[1]).toEqual({ type: 'number' });
|
||||
});
|
||||
|
||||
it('should handle a Google Maps MCP-like schema with $defs and x-google-enum-descriptions', () => {
|
||||
const schema = {
|
||||
type: 'object',
|
||||
properties: {
|
||||
origin: { type: 'string', description: 'Starting address' },
|
||||
destination: { type: 'string', description: 'Ending address' },
|
||||
travelMode: {
|
||||
type: 'string',
|
||||
enum: ['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK'],
|
||||
'x-google-enum-descriptions': ['By car', 'By bicycle', 'By public transit', 'By walking'],
|
||||
},
|
||||
waypoints: {
|
||||
type: 'array',
|
||||
items: { $ref: '#/$defs/Waypoint' },
|
||||
},
|
||||
},
|
||||
required: ['origin', 'destination'],
|
||||
$defs: {
|
||||
Waypoint: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string' },
|
||||
stopover: { type: 'boolean' },
|
||||
},
|
||||
},
|
||||
},
|
||||
} as any;
|
||||
|
||||
// First resolve refs, then normalize
|
||||
const resolved = resolveJsonSchemaRefs(schema);
|
||||
const result = normalizeJsonSchema(resolved);
|
||||
|
||||
// $defs should be stripped (by both resolveJsonSchemaRefs and normalizeJsonSchema)
|
||||
expect(result).not.toHaveProperty('$defs');
|
||||
// x-google-enum-descriptions should be stripped
|
||||
expect(result.properties.travelMode).not.toHaveProperty('x-google-enum-descriptions');
|
||||
// $ref should be resolved inline
|
||||
expect(result.properties.waypoints.items).not.toHaveProperty('$ref');
|
||||
expect(result.properties.waypoints.items).toEqual({
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string' },
|
||||
stopover: { type: 'boolean' },
|
||||
},
|
||||
});
|
||||
// Standard fields preserved
|
||||
expect(result.properties.travelMode.enum).toEqual(['DRIVE', 'BICYCLE', 'TRANSIT', 'WALK']);
|
||||
expect(result.properties.origin).toEqual({ type: 'string', description: 'Starting address' });
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ import type {
|
|||
import type { MCPOAuthTokens } from './oauth/types';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
import type * as t from './types';
|
||||
import { createSSRFSafeUndiciConnect, resolveHostnameSSRF } from '~/auth';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
||||
|
|
@ -213,6 +214,7 @@ interface MCPConnectionParams {
|
|||
serverConfig: t.MCPOptions;
|
||||
userId?: string;
|
||||
oauthTokens?: MCPOAuthTokens | null;
|
||||
useSSRFProtection?: boolean;
|
||||
}
|
||||
|
||||
export class MCPConnection extends EventEmitter {
|
||||
|
|
@ -233,6 +235,7 @@ export class MCPConnection extends EventEmitter {
|
|||
private oauthTokens?: MCPOAuthTokens | null;
|
||||
private requestHeaders?: Record<string, string> | null;
|
||||
private oauthRequired = false;
|
||||
private readonly useSSRFProtection: boolean;
|
||||
iconPath?: string;
|
||||
timeout?: number;
|
||||
url?: string;
|
||||
|
|
@ -263,6 +266,7 @@ export class MCPConnection extends EventEmitter {
|
|||
this.options = params.serverConfig;
|
||||
this.serverName = params.serverName;
|
||||
this.userId = params.userId;
|
||||
this.useSSRFProtection = params.useSSRFProtection === true;
|
||||
this.iconPath = params.serverConfig.iconPath;
|
||||
this.timeout = params.serverConfig.timeout;
|
||||
this.lastPingTime = Date.now();
|
||||
|
|
@ -301,6 +305,7 @@ export class MCPConnection extends EventEmitter {
|
|||
getHeaders: () => Record<string, string> | null | undefined,
|
||||
timeout?: number,
|
||||
): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
|
||||
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
|
||||
return function customFetch(
|
||||
input: UndiciRequestInfo,
|
||||
init?: UndiciRequestInit,
|
||||
|
|
@ -310,6 +315,7 @@ export class MCPConnection extends EventEmitter {
|
|||
const agent = new Agent({
|
||||
bodyTimeout: effectiveTimeout,
|
||||
headersTimeout: effectiveTimeout,
|
||||
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
|
||||
});
|
||||
if (!requestHeaders) {
|
||||
return undiciFetch(input, { ...init, dispatcher: agent });
|
||||
|
|
@ -342,7 +348,7 @@ export class MCPConnection extends EventEmitter {
|
|||
logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
|
||||
}
|
||||
|
||||
private constructTransport(options: t.MCPOptions): Transport {
|
||||
private async constructTransport(options: t.MCPOptions): Promise<Transport> {
|
||||
try {
|
||||
let type: t.MCPOptions['type'];
|
||||
if (isStdioOptions(options)) {
|
||||
|
|
@ -378,6 +384,15 @@ export class MCPConnection extends EventEmitter {
|
|||
throw new Error('Invalid options for websocket transport.');
|
||||
}
|
||||
this.url = options.url;
|
||||
if (this.useSSRFProtection) {
|
||||
const wsHostname = new URL(options.url).hostname;
|
||||
const isSSRF = await resolveHostnameSSRF(wsHostname);
|
||||
if (isSSRF) {
|
||||
throw new Error(
|
||||
`SSRF protection: WebSocket host "${wsHostname}" resolved to a private/reserved IP address`,
|
||||
);
|
||||
}
|
||||
}
|
||||
return new WebSocketClientTransport(new URL(options.url));
|
||||
|
||||
case 'sse': {
|
||||
|
|
@ -402,6 +417,7 @@ export class MCPConnection extends EventEmitter {
|
|||
* The connect timeout is extended because proxies may delay initial response.
|
||||
*/
|
||||
const sseTimeout = this.timeout || SSE_CONNECT_TIMEOUT;
|
||||
const ssrfConnect = this.useSSRFProtection ? createSSRFSafeUndiciConnect() : undefined;
|
||||
const transport = new SSEClientTransport(url, {
|
||||
requestInit: {
|
||||
/** User/OAuth headers override SSE defaults */
|
||||
|
|
@ -420,6 +436,7 @@ export class MCPConnection extends EventEmitter {
|
|||
/** Extended keep-alive for long-lived SSE connections */
|
||||
keepAliveTimeout: sseTimeout,
|
||||
keepAliveMaxTimeout: sseTimeout * 2,
|
||||
...(ssrfConnect != null ? { connect: ssrfConnect } : {}),
|
||||
});
|
||||
return undiciFetch(url, {
|
||||
...init,
|
||||
|
|
@ -542,7 +559,11 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
|
||||
this.isReconnecting = true;
|
||||
const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);
|
||||
const backoffDelay = (attempt: number) => {
|
||||
const base = Math.min(1000 * Math.pow(2, attempt), 30000);
|
||||
const jitter = Math.floor(Math.random() * 1000); // up to 1s of random jitter
|
||||
return base + jitter;
|
||||
};
|
||||
|
||||
try {
|
||||
while (
|
||||
|
|
@ -629,7 +650,7 @@ export class MCPConnection extends EventEmitter {
|
|||
}
|
||||
}
|
||||
|
||||
this.transport = this.constructTransport(this.options);
|
||||
this.transport = await this.constructTransport(this.options);
|
||||
this.setupTransportDebugHandlers();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 120000;
|
||||
|
|
|
|||
|
|
@ -336,6 +336,69 @@ describe('OAuthReconnectionManager', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('reconnection staggering', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
|
||||
beforeEach(async () => {
|
||||
jest.useFakeTimers();
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('should stagger reconnection attempts for multiple servers', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2', 'server3']);
|
||||
(mockRegistryInstance.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// All servers have valid tokens and are not connected
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() + 3600000),
|
||||
} as unknown as MCPOAuthTokens;
|
||||
});
|
||||
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
(mockRegistryInstance.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Only the first server should have been attempted immediately
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(1);
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server1' }),
|
||||
);
|
||||
|
||||
// After advancing all timers, all servers should have been attempted
|
||||
await jest.runAllTimersAsync();
|
||||
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledTimes(3);
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server2' }),
|
||||
);
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ serverName: 'server3' }),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reconnection timeout behavior', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import { MCPManager } from '~/mcp/MCPManager';
|
|||
import { MCPServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
|
||||
const RECONNECT_STAGGER_MS = 500; // ms between each server reconnection
|
||||
|
||||
export class OAuthReconnectionManager {
|
||||
private static instance: OAuthReconnectionManager | null = null;
|
||||
|
|
@ -84,9 +85,14 @@ export class OAuthReconnectionManager {
|
|||
this.reconnectionsTracker.setActive(userId, serverName);
|
||||
}
|
||||
|
||||
// 3. attempt to reconnect the servers
|
||||
for (const serverName of serversToReconnect) {
|
||||
void this.tryReconnect(userId, serverName);
|
||||
// 3. attempt to reconnect the servers with staggered delays to avoid connection storms
|
||||
for (let i = 0; i < serversToReconnect.length; i++) {
|
||||
const serverName = serversToReconnect[i];
|
||||
if (i === 0) {
|
||||
void this.tryReconnect(userId, serverName);
|
||||
} else {
|
||||
setTimeout(() => void this.tryReconnect(userId, serverName), i * RECONNECT_STAGGER_MS);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ export class MCPServerInspector {
|
|||
private readonly serverName: string,
|
||||
private readonly config: t.ParsedServerConfig,
|
||||
private connection: MCPConnection | undefined,
|
||||
private readonly useSSRFProtection: boolean = false,
|
||||
) {}
|
||||
|
||||
/**
|
||||
|
|
@ -42,8 +43,9 @@ export class MCPServerInspector {
|
|||
throw new MCPDomainNotAllowedError(domain ?? 'unknown');
|
||||
}
|
||||
|
||||
const useSSRFProtection = !Array.isArray(allowedDomains) || allowedDomains.length === 0;
|
||||
const start = Date.now();
|
||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
|
||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection, useSSRFProtection);
|
||||
await inspector.inspectServer();
|
||||
inspector.config.initDuration = Date.now() - start;
|
||||
return inspector.config;
|
||||
|
|
@ -59,6 +61,7 @@ export class MCPServerInspector {
|
|||
this.connection = await MCPConnectionFactory.create({
|
||||
serverName: this.serverName,
|
||||
serverConfig: this.config,
|
||||
useSSRFProtection: this.useSSRFProtection,
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,15 @@ export class MCPServersRegistry {
|
|||
return MCPServersRegistry.instance;
|
||||
}
|
||||
|
||||
public getAllowedDomains(): string[] | null | undefined {
|
||||
return this.allowedDomains;
|
||||
}
|
||||
|
||||
/** Returns true when no explicit allowedDomains allowlist is configured, enabling SSRF TOCTOU protection */
|
||||
public shouldEnableSSRFProtection(): boolean {
|
||||
return !Array.isArray(this.allowedDomains) || this.allowedDomains.length === 0;
|
||||
}
|
||||
|
||||
public async getServerConfig(
|
||||
serverName: string,
|
||||
userId?: string,
|
||||
|
|
|
|||
|
|
@ -276,6 +276,7 @@ describe('MCPServerInspector', () => {
|
|||
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
|
||||
serverName: 'test_server',
|
||||
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
|
||||
useSSRFProtection: true,
|
||||
});
|
||||
|
||||
// Verify temporary connection was disconnected
|
||||
|
|
|
|||
|
|
@ -166,6 +166,7 @@ export type AddServerResult = {
|
|||
export interface BasicConnectionOptions {
|
||||
serverName: string;
|
||||
serverConfig: MCPOptions;
|
||||
useSSRFProtection?: boolean;
|
||||
}
|
||||
|
||||
export interface OAuthConnectionOptions {
|
||||
|
|
|
|||
|
|
@ -203,9 +203,9 @@ export function resolveJsonSchemaRefs<T extends Record<string, unknown>>(
|
|||
const result: Record<string, unknown> = {};
|
||||
|
||||
for (const [key, value] of Object.entries(schema)) {
|
||||
// Skip $defs/definitions at root level to avoid infinite recursion
|
||||
if ((key === '$defs' || key === 'definitions') && !visited.size) {
|
||||
result[key] = value;
|
||||
// Skip $defs/definitions — they are only used for resolving $ref and
|
||||
// should not appear in the resolved output (e.g. Google/Gemini API rejects them).
|
||||
if (key === '$defs' || key === 'definitions') {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
@ -248,6 +248,80 @@ export function resolveJsonSchemaRefs<T extends Record<string, unknown>>(
|
|||
return result as T;
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively normalizes a JSON schema for LLM API compatibility.
|
||||
*
|
||||
* Transformations applied:
|
||||
* - Converts `const` values to `enum` arrays (Gemini/Vertex AI rejects `const`)
|
||||
* - Strips vendor extension fields (`x-*` prefixed keys, e.g. `x-google-enum-descriptions`)
|
||||
* - Strips leftover `$defs`/`definitions` blocks that may survive ref resolution
|
||||
*
|
||||
* @param schema - The JSON schema to normalize
|
||||
* @returns The normalized schema
|
||||
*/
|
||||
export function normalizeJsonSchema<T extends Record<string, unknown>>(schema: T): T {
|
||||
if (!schema || typeof schema !== 'object') {
|
||||
return schema;
|
||||
}
|
||||
|
||||
if (Array.isArray(schema)) {
|
||||
return schema.map((item) =>
|
||||
item && typeof item === 'object' ? normalizeJsonSchema(item) : item,
|
||||
) as unknown as T;
|
||||
}
|
||||
|
||||
const result: Record<string, unknown> = {};
|
||||
|
||||
for (const [key, value] of Object.entries(schema)) {
|
||||
// Strip vendor extension fields (e.g. x-google-enum-descriptions) —
|
||||
// these are valid in JSON Schema but rejected by Google/Gemini API.
|
||||
if (key.startsWith('x-')) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Strip leftover $defs/definitions (should already be resolved by resolveJsonSchemaRefs,
|
||||
// but strip as a safety net for schemas that bypass ref resolution).
|
||||
if (key === '$defs' || key === 'definitions') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (key === 'const' && !('enum' in schema)) {
|
||||
result['enum'] = [value];
|
||||
continue;
|
||||
}
|
||||
|
||||
if (key === 'const' && 'enum' in schema) {
|
||||
// Skip `const` when `enum` already exists
|
||||
continue;
|
||||
}
|
||||
|
||||
if (key === 'properties' && value && typeof value === 'object' && !Array.isArray(value)) {
|
||||
const newProps: Record<string, unknown> = {};
|
||||
for (const [propKey, propValue] of Object.entries(value as Record<string, unknown>)) {
|
||||
newProps[propKey] =
|
||||
propValue && typeof propValue === 'object'
|
||||
? normalizeJsonSchema(propValue as Record<string, unknown>)
|
||||
: propValue;
|
||||
}
|
||||
result[key] = newProps;
|
||||
} else if (
|
||||
(key === 'items' || key === 'additionalProperties') &&
|
||||
value &&
|
||||
typeof value === 'object'
|
||||
) {
|
||||
result[key] = normalizeJsonSchema(value as Record<string, unknown>);
|
||||
} else if ((key === 'oneOf' || key === 'anyOf' || key === 'allOf') && Array.isArray(value)) {
|
||||
result[key] = value.map((item) =>
|
||||
item && typeof item === 'object' ? normalizeJsonSchema(item) : item,
|
||||
);
|
||||
} else {
|
||||
result[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return result as T;
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a JSON Schema to a Zod schema.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -0,0 +1,258 @@
|
|||
import type { Redis, Cluster } from 'ioredis';
|
||||
|
||||
/**
|
||||
* Integration tests for concurrency middleware atomic Lua scripts.
|
||||
*
|
||||
* Tests that the Lua-based check-and-increment / decrement operations
|
||||
* are truly atomic and eliminate the INCR+check+DECR race window.
|
||||
*
|
||||
* Run with: USE_REDIS=true npx jest --config packages/api/jest.config.js concurrency.cache_integration
|
||||
*/
|
||||
describe('Concurrency Middleware Integration Tests', () => {
|
||||
let originalEnv: NodeJS.ProcessEnv;
|
||||
let ioredisClient: Redis | Cluster | null = null;
|
||||
let checkAndIncrementPendingRequest: (
|
||||
userId: string,
|
||||
) => Promise<{ allowed: boolean; pendingRequests: number; limit: number }>;
|
||||
let decrementPendingRequest: (userId: string) => Promise<void>;
|
||||
const testPrefix = 'Concurrency-Integration-Test';
|
||||
|
||||
beforeAll(async () => {
|
||||
originalEnv = { ...process.env };
|
||||
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'false';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX = testPrefix;
|
||||
process.env.REDIS_PING_INTERVAL = '0';
|
||||
process.env.REDIS_RETRY_MAX_ATTEMPTS = '5';
|
||||
process.env.LIMIT_CONCURRENT_MESSAGES = 'true';
|
||||
process.env.CONCURRENT_MESSAGE_MAX = '2';
|
||||
|
||||
jest.resetModules();
|
||||
|
||||
const { ioredisClient: client } = await import('../../cache/redisClients');
|
||||
ioredisClient = client;
|
||||
|
||||
if (!ioredisClient) {
|
||||
console.warn('Redis not available, skipping integration tests');
|
||||
return;
|
||||
}
|
||||
|
||||
// Import concurrency module after Redis client is available
|
||||
const concurrency = await import('../concurrency');
|
||||
checkAndIncrementPendingRequest = concurrency.checkAndIncrementPendingRequest;
|
||||
decrementPendingRequest = concurrency.decrementPendingRequest;
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const keys = await ioredisClient.keys(`${testPrefix}*`);
|
||||
if (keys.length > 0) {
|
||||
await Promise.all(keys.map((key) => ioredisClient!.del(key)));
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Error cleaning up test keys:', error);
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
if (ioredisClient) {
|
||||
try {
|
||||
await ioredisClient.quit();
|
||||
} catch {
|
||||
try {
|
||||
ioredisClient.disconnect();
|
||||
} catch {
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
describe('Atomic Check and Increment', () => {
|
||||
test('should allow requests within the concurrency limit', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-allow-${Date.now()}`;
|
||||
|
||||
// First request - should be allowed (count = 1, limit = 2)
|
||||
const result1 = await checkAndIncrementPendingRequest(userId);
|
||||
expect(result1.allowed).toBe(true);
|
||||
expect(result1.pendingRequests).toBe(1);
|
||||
expect(result1.limit).toBe(2);
|
||||
|
||||
// Second request - should be allowed (count = 2, limit = 2)
|
||||
const result2 = await checkAndIncrementPendingRequest(userId);
|
||||
expect(result2.allowed).toBe(true);
|
||||
expect(result2.pendingRequests).toBe(2);
|
||||
});
|
||||
|
||||
test('should reject requests over the concurrency limit', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-reject-${Date.now()}`;
|
||||
|
||||
// Fill up to the limit
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
|
||||
// Third request - should be rejected (count would be 3, limit = 2)
|
||||
const result = await checkAndIncrementPendingRequest(userId);
|
||||
expect(result.allowed).toBe(false);
|
||||
expect(result.pendingRequests).toBe(3); // Reports the count that was over-limit
|
||||
});
|
||||
|
||||
test('should not leave stale counter after rejection (atomic rollback)', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-rollback-${Date.now()}`;
|
||||
|
||||
// Fill up to the limit
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
|
||||
// Attempt over-limit (should be rejected and atomically rolled back)
|
||||
const rejected = await checkAndIncrementPendingRequest(userId);
|
||||
expect(rejected.allowed).toBe(false);
|
||||
|
||||
// The key value should still be 2, not 3 — verify the Lua script decremented back
|
||||
const key = `PENDING_REQ:${userId}`;
|
||||
const rawValue = await ioredisClient.get(key);
|
||||
expect(rawValue).toBe('2');
|
||||
});
|
||||
|
||||
test('should handle concurrent requests atomically (no over-admission)', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-concurrent-${Date.now()}`;
|
||||
|
||||
// Fire 20 concurrent requests for the same user (limit = 2)
|
||||
const results = await Promise.all(
|
||||
Array.from({ length: 20 }, () => checkAndIncrementPendingRequest(userId)),
|
||||
);
|
||||
|
||||
const allowed = results.filter((r) => r.allowed);
|
||||
const rejected = results.filter((r) => !r.allowed);
|
||||
|
||||
// Exactly 2 should be allowed (the concurrency limit)
|
||||
expect(allowed.length).toBe(2);
|
||||
expect(rejected.length).toBe(18);
|
||||
|
||||
// The key value should be exactly 2 after all atomic operations
|
||||
const key = `PENDING_REQ:${userId}`;
|
||||
const rawValue = await ioredisClient.get(key);
|
||||
expect(rawValue).toBe('2');
|
||||
|
||||
// Clean up
|
||||
await decrementPendingRequest(userId);
|
||||
await decrementPendingRequest(userId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Atomic Decrement', () => {
|
||||
test('should decrement pending requests', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-decrement-${Date.now()}`;
|
||||
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
|
||||
// Decrement once
|
||||
await decrementPendingRequest(userId);
|
||||
|
||||
const key = `PENDING_REQ:${userId}`;
|
||||
const rawValue = await ioredisClient.get(key);
|
||||
expect(rawValue).toBe('1');
|
||||
});
|
||||
|
||||
test('should clean up key when count reaches zero', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-cleanup-${Date.now()}`;
|
||||
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
await decrementPendingRequest(userId);
|
||||
|
||||
// Key should be deleted (not left as "0")
|
||||
const key = `PENDING_REQ:${userId}`;
|
||||
const exists = await ioredisClient.exists(key);
|
||||
expect(exists).toBe(0);
|
||||
});
|
||||
|
||||
test('should clean up key on double-decrement (negative protection)', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-double-decr-${Date.now()}`;
|
||||
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
await decrementPendingRequest(userId);
|
||||
await decrementPendingRequest(userId); // Double-decrement
|
||||
|
||||
// Key should be deleted, not negative
|
||||
const key = `PENDING_REQ:${userId}`;
|
||||
const exists = await ioredisClient.exists(key);
|
||||
expect(exists).toBe(0);
|
||||
});
|
||||
|
||||
test('should allow new requests after decrement frees a slot', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-free-slot-${Date.now()}`;
|
||||
|
||||
// Fill to limit
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
|
||||
// Verify at limit
|
||||
const atLimit = await checkAndIncrementPendingRequest(userId);
|
||||
expect(atLimit.allowed).toBe(false);
|
||||
|
||||
// Free a slot
|
||||
await decrementPendingRequest(userId);
|
||||
|
||||
// Should now be allowed again
|
||||
const allowed = await checkAndIncrementPendingRequest(userId);
|
||||
expect(allowed.allowed).toBe(true);
|
||||
expect(allowed.pendingRequests).toBe(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('TTL Behavior', () => {
|
||||
test('should set TTL on the concurrency key', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const userId = `user-ttl-${Date.now()}`;
|
||||
await checkAndIncrementPendingRequest(userId);
|
||||
|
||||
const key = `PENDING_REQ:${userId}`;
|
||||
const ttl = await ioredisClient.ttl(key);
|
||||
expect(ttl).toBeGreaterThan(0);
|
||||
expect(ttl).toBeLessThanOrEqual(60);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -9,6 +9,40 @@ const LIMIT_CONCURRENT_MESSAGES = process.env.LIMIT_CONCURRENT_MESSAGES;
|
|||
const CONCURRENT_MESSAGE_MAX = math(process.env.CONCURRENT_MESSAGE_MAX, 2);
|
||||
const CONCURRENT_VIOLATION_SCORE = math(process.env.CONCURRENT_VIOLATION_SCORE, 1);
|
||||
|
||||
/**
|
||||
* Lua script for atomic check-and-increment.
|
||||
* Increments the key, sets TTL, and if over limit decrements back.
|
||||
* Returns positive count if allowed, negative count if rejected.
|
||||
* Single round-trip, fully atomic — eliminates the INCR/check/DECR race window.
|
||||
*/
|
||||
const CHECK_AND_INCREMENT_SCRIPT = `
|
||||
local key = KEYS[1]
|
||||
local limit = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local current = redis.call('INCR', key)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
if current > limit then
|
||||
redis.call('DECR', key)
|
||||
return -current
|
||||
end
|
||||
return current
|
||||
`;
|
||||
|
||||
/**
|
||||
* Lua script for atomic decrement-and-cleanup.
|
||||
* Decrements the key and deletes it if the count reaches zero or below.
|
||||
* Eliminates the DECR-then-DEL race window.
|
||||
*/
|
||||
const DECREMENT_SCRIPT = `
|
||||
local key = KEYS[1]
|
||||
local current = redis.call('DECR', key)
|
||||
if current <= 0 then
|
||||
redis.call('DEL', key)
|
||||
return 0
|
||||
end
|
||||
return current
|
||||
`;
|
||||
|
||||
/** Lazily initialized cache for pending requests (used only for in-memory fallback) */
|
||||
let pendingReqCache: ReturnType<typeof standardCache> | null = null;
|
||||
|
||||
|
|
@ -80,36 +114,28 @@ export async function checkAndIncrementPendingRequest(
|
|||
return { allowed: true, pendingRequests: 0, limit };
|
||||
}
|
||||
|
||||
// Use atomic Redis INCR when available to prevent race conditions
|
||||
// Use atomic Lua script when Redis is available to prevent race conditions.
|
||||
// A single EVAL round-trip atomically increments, checks, and decrements if over-limit.
|
||||
if (USE_REDIS && ioredisClient) {
|
||||
const key = buildKey(userId);
|
||||
try {
|
||||
// Pipeline ensures INCR and EXPIRE execute atomically in one round-trip
|
||||
// This prevents edge cases where crash between operations leaves key without TTL
|
||||
const pipeline = ioredisClient.pipeline();
|
||||
pipeline.incr(key);
|
||||
pipeline.expire(key, 60);
|
||||
const results = await pipeline.exec();
|
||||
const result = (await ioredisClient.eval(
|
||||
CHECK_AND_INCREMENT_SCRIPT,
|
||||
1,
|
||||
key,
|
||||
limit,
|
||||
60,
|
||||
)) as number;
|
||||
|
||||
if (!results || results[0][0]) {
|
||||
throw new Error('Pipeline execution failed');
|
||||
if (result < 0) {
|
||||
// Negative return means over-limit (absolute value is the count before decrement)
|
||||
const count = -result;
|
||||
logger.debug(`[concurrency] User ${userId} exceeded concurrent limit: ${count}/${limit}`);
|
||||
return { allowed: false, pendingRequests: count, limit };
|
||||
}
|
||||
|
||||
const newCount = results[0][1] as number;
|
||||
|
||||
if (newCount > limit) {
|
||||
// Over limit - decrement back and reject
|
||||
await ioredisClient.decr(key);
|
||||
logger.debug(
|
||||
`[concurrency] User ${userId} exceeded concurrent limit: ${newCount}/${limit}`,
|
||||
);
|
||||
return { allowed: false, pendingRequests: newCount, limit };
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`[concurrency] User ${userId} incremented pending requests: ${newCount}/${limit}`,
|
||||
);
|
||||
return { allowed: true, pendingRequests: newCount, limit };
|
||||
logger.debug(`[concurrency] User ${userId} incremented pending requests: ${result}/${limit}`);
|
||||
return { allowed: true, pendingRequests: result, limit };
|
||||
} catch (error) {
|
||||
logger.error('[concurrency] Redis atomic increment failed:', error);
|
||||
// On Redis error, allow the request to proceed (fail-open)
|
||||
|
|
@ -164,18 +190,12 @@ export async function decrementPendingRequest(userId: string): Promise<void> {
|
|||
return;
|
||||
}
|
||||
|
||||
// Use atomic Redis DECR when available
|
||||
// Use atomic Lua script to decrement and clean up zero/negative keys in one round-trip
|
||||
if (USE_REDIS && ioredisClient) {
|
||||
const key = buildKey(userId);
|
||||
try {
|
||||
const newCount = await ioredisClient.decr(key);
|
||||
if (newCount < 0) {
|
||||
// Counter went negative - reset to 0 and delete
|
||||
await ioredisClient.del(key);
|
||||
logger.debug(`[concurrency] User ${userId} pending requests cleared (was negative)`);
|
||||
} else if (newCount === 0) {
|
||||
// Clean up zero-value keys
|
||||
await ioredisClient.del(key);
|
||||
const newCount = (await ioredisClient.eval(DECREMENT_SCRIPT, 1, key)) as number;
|
||||
if (newCount === 0) {
|
||||
logger.debug(`[concurrency] User ${userId} pending requests cleared`);
|
||||
} else {
|
||||
logger.debug(`[concurrency] User ${userId} decremented pending requests: ${newCount}`);
|
||||
|
|
|
|||
99
packages/api/src/oauth/csrf.spec.ts
Normal file
99
packages/api/src/oauth/csrf.spec.ts
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
import { shouldUseSecureCookie } from './csrf';
|
||||
|
||||
describe('shouldUseSecureCookie', () => {
|
||||
const originalEnv = process.env;
|
||||
|
||||
beforeEach(() => {
|
||||
process.env = { ...originalEnv };
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
it('should return true in production with a non-localhost domain', () => {
|
||||
process.env.NODE_ENV = 'production';
|
||||
process.env.DOMAIN_SERVER = 'https://myapp.example.com';
|
||||
expect(shouldUseSecureCookie()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false in development regardless of domain', () => {
|
||||
process.env.NODE_ENV = 'development';
|
||||
process.env.DOMAIN_SERVER = 'https://myapp.example.com';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false when NODE_ENV is not set', () => {
|
||||
delete process.env.NODE_ENV;
|
||||
process.env.DOMAIN_SERVER = 'https://myapp.example.com';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
describe('localhost detection in production', () => {
|
||||
beforeEach(() => {
|
||||
process.env.NODE_ENV = 'production';
|
||||
});
|
||||
|
||||
it('should return false for http://localhost:3080', () => {
|
||||
process.env.DOMAIN_SERVER = 'http://localhost:3080';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for https://localhost:3080', () => {
|
||||
process.env.DOMAIN_SERVER = 'https://localhost:3080';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for http://localhost (no port)', () => {
|
||||
process.env.DOMAIN_SERVER = 'http://localhost';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for http://127.0.0.1:3080', () => {
|
||||
process.env.DOMAIN_SERVER = 'http://127.0.0.1:3080';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for http://[::1]:3080 (IPv6 loopback — not detected due to URL bracket parsing)', () => {
|
||||
// Known limitation: new URL('http://[::1]:3080').hostname returns '[::1]' (with brackets)
|
||||
// but the check compares against '::1' (without brackets). IPv6 localhost is rare in practice.
|
||||
process.env.DOMAIN_SERVER = 'http://[::1]:3080';
|
||||
expect(shouldUseSecureCookie()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for subdomain of localhost', () => {
|
||||
process.env.DOMAIN_SERVER = 'http://app.localhost:3080';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for a domain containing "localhost" as a substring but not as hostname', () => {
|
||||
process.env.DOMAIN_SERVER = 'https://notlocalhost.example.com';
|
||||
expect(shouldUseSecureCookie()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for a regular production domain', () => {
|
||||
process.env.DOMAIN_SERVER = 'https://chat.example.com';
|
||||
expect(shouldUseSecureCookie()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true when DOMAIN_SERVER is empty (conservative default)', () => {
|
||||
process.env.DOMAIN_SERVER = '';
|
||||
expect(shouldUseSecureCookie()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true when DOMAIN_SERVER is not set (conservative default)', () => {
|
||||
delete process.env.DOMAIN_SERVER;
|
||||
expect(shouldUseSecureCookie()).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle DOMAIN_SERVER without protocol prefix', () => {
|
||||
process.env.DOMAIN_SERVER = 'localhost:3080';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle case-insensitive hostnames', () => {
|
||||
process.env.DOMAIN_SERVER = 'http://LOCALHOST:3080';
|
||||
expect(shouldUseSecureCookie()).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
119
packages/api/src/oauth/csrf.ts
Normal file
119
packages/api/src/oauth/csrf.ts
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
import crypto from 'crypto';
|
||||
import type { Request, Response, NextFunction } from 'express';
|
||||
|
||||
export const OAUTH_CSRF_COOKIE = 'oauth_csrf';
|
||||
export const OAUTH_CSRF_MAX_AGE = 10 * 60 * 1000;
|
||||
|
||||
export const OAUTH_SESSION_COOKIE = 'oauth_session';
|
||||
export const OAUTH_SESSION_MAX_AGE = 24 * 60 * 60 * 1000;
|
||||
export const OAUTH_SESSION_COOKIE_PATH = '/api';
|
||||
|
||||
/**
|
||||
* Determines if secure cookies should be used.
|
||||
* Returns `true` in production unless the server is running on localhost (HTTP).
|
||||
* This allows cookies to work on `http://localhost` during local development
|
||||
* even when `NODE_ENV=production` (common in Docker Compose setups).
|
||||
*/
|
||||
export function shouldUseSecureCookie(): boolean {
|
||||
const isProduction = process.env.NODE_ENV === 'production';
|
||||
const domainServer = process.env.DOMAIN_SERVER || '';
|
||||
|
||||
let hostname = '';
|
||||
if (domainServer) {
|
||||
try {
|
||||
const normalized = /^https?:\/\//i.test(domainServer)
|
||||
? domainServer
|
||||
: `http://${domainServer}`;
|
||||
const url = new URL(normalized);
|
||||
hostname = (url.hostname || '').toLowerCase();
|
||||
} catch {
|
||||
hostname = domainServer.toLowerCase();
|
||||
}
|
||||
}
|
||||
|
||||
const isLocalhost =
|
||||
hostname === 'localhost' ||
|
||||
hostname === '127.0.0.1' ||
|
||||
hostname === '::1' ||
|
||||
hostname.endsWith('.localhost');
|
||||
|
||||
return isProduction && !isLocalhost;
|
||||
}
|
||||
|
||||
/** Generates an HMAC-based token for OAuth CSRF protection */
|
||||
export function generateOAuthCsrfToken(flowId: string, secret?: string): string {
|
||||
const key = secret || process.env.JWT_SECRET;
|
||||
if (!key) {
|
||||
throw new Error('JWT_SECRET is required for OAuth CSRF token generation');
|
||||
}
|
||||
return crypto.createHmac('sha256', key).update(flowId).digest('hex').slice(0, 32);
|
||||
}
|
||||
|
||||
/** Sets a SameSite=Lax CSRF cookie bound to a specific OAuth flow */
|
||||
export function setOAuthCsrfCookie(res: Response, flowId: string, cookiePath: string): void {
|
||||
res.cookie(OAUTH_CSRF_COOKIE, generateOAuthCsrfToken(flowId), {
|
||||
httpOnly: true,
|
||||
secure: shouldUseSecureCookie(),
|
||||
sameSite: 'lax',
|
||||
maxAge: OAUTH_CSRF_MAX_AGE,
|
||||
path: cookiePath,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates the per-flow CSRF cookie against the expected HMAC.
|
||||
* Uses timing-safe comparison and always clears the cookie to prevent replay.
|
||||
*/
|
||||
export function validateOAuthCsrf(
|
||||
req: Request,
|
||||
res: Response,
|
||||
flowId: string,
|
||||
cookiePath: string,
|
||||
): boolean {
|
||||
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_CSRF_COOKIE];
|
||||
res.clearCookie(OAUTH_CSRF_COOKIE, { path: cookiePath });
|
||||
if (!cookie) {
|
||||
return false;
|
||||
}
|
||||
const expected = generateOAuthCsrfToken(flowId);
|
||||
if (cookie.length !== expected.length) {
|
||||
return false;
|
||||
}
|
||||
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
|
||||
}
|
||||
|
||||
/**
|
||||
* Express middleware that sets the OAuth session cookie after JWT authentication.
|
||||
* Chain after requireJwtAuth on routes that precede an OAuth redirect (e.g., reinitialize, bind).
|
||||
*/
|
||||
export function setOAuthSession(req: Request, res: Response, next: NextFunction): void {
|
||||
const user = (req as Request & { user?: { id?: string } }).user;
|
||||
if (user?.id && !(req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE]) {
|
||||
setOAuthSessionCookie(res, user.id);
|
||||
}
|
||||
next();
|
||||
}
|
||||
|
||||
/** Sets a SameSite=Lax session cookie that binds the browser to the authenticated userId */
|
||||
export function setOAuthSessionCookie(res: Response, userId: string): void {
|
||||
res.cookie(OAUTH_SESSION_COOKIE, generateOAuthCsrfToken(userId), {
|
||||
httpOnly: true,
|
||||
secure: shouldUseSecureCookie(),
|
||||
sameSite: 'lax',
|
||||
maxAge: OAUTH_SESSION_MAX_AGE,
|
||||
path: OAUTH_SESSION_COOKIE_PATH,
|
||||
});
|
||||
}
|
||||
|
||||
/** Validates the session cookie against the expected userId using timing-safe comparison */
|
||||
export function validateOAuthSession(req: Request, userId: string): boolean {
|
||||
const cookie = (req.cookies as Record<string, string> | undefined)?.[OAUTH_SESSION_COOKIE];
|
||||
if (!cookie) {
|
||||
return false;
|
||||
}
|
||||
const expected = generateOAuthCsrfToken(userId);
|
||||
if (cookie.length !== expected.length) {
|
||||
return false;
|
||||
}
|
||||
return crypto.timingSafeEqual(Buffer.from(cookie), Buffer.from(expected));
|
||||
}
|
||||
|
|
@ -1 +1,2 @@
|
|||
export * from './csrf';
|
||||
export * from './tokens';
|
||||
|
|
|
|||
|
|
@ -745,7 +745,6 @@ class GenerationJobManagerClass {
|
|||
const subscription = this.eventTransport.subscribe(streamId, {
|
||||
onChunk: (event) => {
|
||||
const e = event as t.ServerSentEvent;
|
||||
// Filter out internal events
|
||||
if (!(e as Record<string, unknown>)._internal) {
|
||||
onChunk(e);
|
||||
}
|
||||
|
|
@ -754,14 +753,15 @@ class GenerationJobManagerClass {
|
|||
onError,
|
||||
});
|
||||
|
||||
// Check if this is the first subscriber
|
||||
if (subscription.ready) {
|
||||
await subscription.ready;
|
||||
}
|
||||
|
||||
const isFirst = this.eventTransport.isFirstSubscriber(streamId);
|
||||
|
||||
// First subscriber: replay buffered events and mark as connected
|
||||
if (!runtime.hasSubscriber) {
|
||||
runtime.hasSubscriber = true;
|
||||
|
||||
// Replay any events that were emitted before subscriber connected
|
||||
if (runtime.earlyEventBuffer.length > 0) {
|
||||
logger.debug(
|
||||
`[GenerationJobManager] Replaying ${runtime.earlyEventBuffer.length} buffered events for ${streamId}`,
|
||||
|
|
@ -771,6 +771,8 @@ class GenerationJobManagerClass {
|
|||
}
|
||||
runtime.earlyEventBuffer = [];
|
||||
}
|
||||
|
||||
this.eventTransport.syncReorderBuffer?.(streamId);
|
||||
}
|
||||
|
||||
if (isFirst) {
|
||||
|
|
@ -823,12 +825,13 @@ class GenerationJobManagerClass {
|
|||
}
|
||||
}
|
||||
|
||||
// Buffer early events if no subscriber yet (replay when first subscriber connects)
|
||||
if (!runtime.hasSubscriber) {
|
||||
runtime.earlyEventBuffer.push(event);
|
||||
if (!this._isRedis) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Await the transport emit - critical for Redis mode to maintain event order
|
||||
await this.eventTransport.emitChunk(streamId, event);
|
||||
}
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -19,8 +19,11 @@ describe('RedisEventTransport Integration Tests', () => {
|
|||
originalEnv = { ...process.env };
|
||||
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'false';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX = testPrefix;
|
||||
process.env.REDIS_PING_INTERVAL = '0';
|
||||
process.env.REDIS_RETRY_MAX_ATTEMPTS = '5';
|
||||
|
||||
jest.resetModules();
|
||||
|
||||
|
|
@ -890,4 +893,121 @@ describe('RedisEventTransport Integration Tests', () => {
|
|||
subscriber.disconnect();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Publish Error Propagation', () => {
|
||||
test('should swallow emitChunk publish errors (callers fire-and-forget)', async () => {
|
||||
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
|
||||
|
||||
const mockPublisher = {
|
||||
publish: jest.fn().mockRejectedValue(new Error('Redis connection lost')),
|
||||
};
|
||||
const mockSubscriber = {
|
||||
on: jest.fn(),
|
||||
subscribe: jest.fn().mockResolvedValue(undefined),
|
||||
unsubscribe: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const transport = new RedisEventTransport(
|
||||
mockPublisher as unknown as Redis,
|
||||
mockSubscriber as unknown as Redis,
|
||||
);
|
||||
|
||||
const streamId = `error-prop-chunk-${Date.now()}`;
|
||||
|
||||
// emitChunk swallows errors because callers often fire-and-forget (no await).
|
||||
// Throwing would cause unhandled promise rejections.
|
||||
await expect(transport.emitChunk(streamId, { data: 'test' })).resolves.toBeUndefined();
|
||||
|
||||
transport.destroy();
|
||||
});
|
||||
|
||||
test('should throw when emitDone publish fails', async () => {
|
||||
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
|
||||
|
||||
const mockPublisher = {
|
||||
publish: jest.fn().mockRejectedValue(new Error('Redis connection lost')),
|
||||
};
|
||||
const mockSubscriber = {
|
||||
on: jest.fn(),
|
||||
subscribe: jest.fn().mockResolvedValue(undefined),
|
||||
unsubscribe: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const transport = new RedisEventTransport(
|
||||
mockPublisher as unknown as Redis,
|
||||
mockSubscriber as unknown as Redis,
|
||||
);
|
||||
|
||||
const streamId = `error-prop-done-${Date.now()}`;
|
||||
|
||||
await expect(transport.emitDone(streamId, { finished: true })).rejects.toThrow(
|
||||
'Redis connection lost',
|
||||
);
|
||||
|
||||
transport.destroy();
|
||||
});
|
||||
|
||||
test('should throw when emitError publish fails', async () => {
|
||||
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
|
||||
|
||||
const mockPublisher = {
|
||||
publish: jest.fn().mockRejectedValue(new Error('Redis connection lost')),
|
||||
};
|
||||
const mockSubscriber = {
|
||||
on: jest.fn(),
|
||||
subscribe: jest.fn().mockResolvedValue(undefined),
|
||||
unsubscribe: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const transport = new RedisEventTransport(
|
||||
mockPublisher as unknown as Redis,
|
||||
mockSubscriber as unknown as Redis,
|
||||
);
|
||||
|
||||
const streamId = `error-prop-error-${Date.now()}`;
|
||||
|
||||
await expect(transport.emitError(streamId, 'some error')).rejects.toThrow(
|
||||
'Redis connection lost',
|
||||
);
|
||||
|
||||
transport.destroy();
|
||||
});
|
||||
|
||||
test('should still deliver events successfully when publish succeeds', async () => {
|
||||
if (!ioredisClient) {
|
||||
console.warn('Redis not available, skipping test');
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisEventTransport } = await import('../implementations/RedisEventTransport');
|
||||
|
||||
const subscriber = (ioredisClient as Redis).duplicate();
|
||||
const transport = new RedisEventTransport(ioredisClient, subscriber);
|
||||
|
||||
const streamId = `error-prop-success-${Date.now()}`;
|
||||
const receivedChunks: unknown[] = [];
|
||||
let doneEvent: unknown = null;
|
||||
|
||||
transport.subscribe(streamId, {
|
||||
onChunk: (event) => receivedChunks.push(event),
|
||||
onDone: (event) => {
|
||||
doneEvent = event;
|
||||
},
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
|
||||
// These should NOT throw
|
||||
await transport.emitChunk(streamId, { text: 'hello' });
|
||||
await transport.emitDone(streamId, { finished: true });
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
|
||||
expect(receivedChunks.length).toBe(1);
|
||||
expect(doneEvent).toEqual({ finished: true });
|
||||
|
||||
transport.destroy();
|
||||
subscriber.disconnect();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -24,8 +24,11 @@ describe('RedisJobStore Integration Tests', () => {
|
|||
|
||||
// Set up test environment
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.USE_REDIS_CLUSTER = process.env.USE_REDIS_CLUSTER ?? 'false';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX = testPrefix;
|
||||
process.env.REDIS_PING_INTERVAL = '0';
|
||||
process.env.REDIS_RETRY_MAX_ATTEMPTS = '5';
|
||||
|
||||
jest.resetModules();
|
||||
|
||||
|
|
@ -880,6 +883,67 @@ describe('RedisJobStore Integration Tests', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('Race Condition: updateJob after deleteJob', () => {
|
||||
test('should not re-create job hash when updateJob runs after deleteJob', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisJobStore } = await import('../implementations/RedisJobStore');
|
||||
const store = new RedisJobStore(ioredisClient);
|
||||
await store.initialize();
|
||||
|
||||
const streamId = `race-condition-${Date.now()}`;
|
||||
await store.createJob(streamId, 'user-1', streamId);
|
||||
|
||||
const jobKey = `stream:{${streamId}}:job`;
|
||||
const ttlBefore = await ioredisClient.ttl(jobKey);
|
||||
expect(ttlBefore).toBeGreaterThan(0);
|
||||
|
||||
await store.deleteJob(streamId);
|
||||
|
||||
const afterDelete = await ioredisClient.exists(jobKey);
|
||||
expect(afterDelete).toBe(0);
|
||||
|
||||
await store.updateJob(streamId, { finalEvent: JSON.stringify({ final: true }) });
|
||||
|
||||
const afterUpdate = await ioredisClient.exists(jobKey);
|
||||
expect(afterUpdate).toBe(0);
|
||||
|
||||
await store.destroy();
|
||||
});
|
||||
|
||||
test('should not leave orphan keys from concurrent emitDone and deleteJob', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisJobStore } = await import('../implementations/RedisJobStore');
|
||||
const store = new RedisJobStore(ioredisClient);
|
||||
await store.initialize();
|
||||
|
||||
const streamId = `concurrent-race-${Date.now()}`;
|
||||
await store.createJob(streamId, 'user-1', streamId);
|
||||
|
||||
const jobKey = `stream:{${streamId}}:job`;
|
||||
|
||||
await Promise.all([
|
||||
store.updateJob(streamId, { finalEvent: JSON.stringify({ final: true }) }),
|
||||
store.deleteJob(streamId),
|
||||
]);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
const exists = await ioredisClient.exists(jobKey);
|
||||
const ttl = exists ? await ioredisClient.ttl(jobKey) : -2;
|
||||
|
||||
expect(ttl === -2 || ttl > 0).toBe(true);
|
||||
expect(ttl).not.toBe(-1);
|
||||
|
||||
await store.destroy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Local Graph Cache Optimization', () => {
|
||||
test('should use local cache when available', async () => {
|
||||
if (!ioredisClient) {
|
||||
|
|
@ -972,4 +1036,196 @@ describe('RedisJobStore Integration Tests', () => {
|
|||
await instance2.destroy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Batched Cleanup', () => {
|
||||
test('should clean up many stale jobs in parallel batches', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisJobStore } = await import('../implementations/RedisJobStore');
|
||||
// Very short TTL so jobs are immediately stale
|
||||
const store = new RedisJobStore(ioredisClient, { runningTtl: 1 });
|
||||
await store.initialize();
|
||||
|
||||
const jobCount = 75; // More than one batch of 50
|
||||
const veryOldTimestamp = Date.now() - 10000; // 10 seconds ago
|
||||
|
||||
// Create many stale jobs directly in Redis
|
||||
for (let i = 0; i < jobCount; i++) {
|
||||
const streamId = `batch-cleanup-${Date.now()}-${i}`;
|
||||
const jobKey = `stream:{${streamId}}:job`;
|
||||
await ioredisClient.hmset(jobKey, {
|
||||
streamId,
|
||||
userId: 'batch-user',
|
||||
status: 'running',
|
||||
createdAt: veryOldTimestamp.toString(),
|
||||
syncSent: '0',
|
||||
});
|
||||
await ioredisClient.sadd('stream:running', streamId);
|
||||
}
|
||||
|
||||
// Verify jobs are in the running set
|
||||
const runningBefore = await ioredisClient.scard('stream:running');
|
||||
expect(runningBefore).toBeGreaterThanOrEqual(jobCount);
|
||||
|
||||
// Run cleanup - should process in batches of 50
|
||||
const cleaned = await store.cleanup();
|
||||
expect(cleaned).toBeGreaterThanOrEqual(jobCount);
|
||||
|
||||
await store.destroy();
|
||||
});
|
||||
|
||||
test('should not clean up valid running jobs during batch cleanup', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisJobStore } = await import('../implementations/RedisJobStore');
|
||||
const store = new RedisJobStore(ioredisClient, { runningTtl: 1200 });
|
||||
await store.initialize();
|
||||
|
||||
// Create a mix of valid and stale jobs
|
||||
const validStreamId = `valid-job-${Date.now()}`;
|
||||
await store.createJob(validStreamId, 'user-1', validStreamId);
|
||||
|
||||
const staleStreamId = `stale-job-${Date.now()}`;
|
||||
const jobKey = `stream:{${staleStreamId}}:job`;
|
||||
await ioredisClient.hmset(jobKey, {
|
||||
streamId: staleStreamId,
|
||||
userId: 'user-1',
|
||||
status: 'running',
|
||||
createdAt: (Date.now() - 2000000).toString(), // Very old
|
||||
syncSent: '0',
|
||||
});
|
||||
await ioredisClient.sadd('stream:running', staleStreamId);
|
||||
|
||||
const cleaned = await store.cleanup();
|
||||
expect(cleaned).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// Valid job should still exist
|
||||
const validJob = await store.getJob(validStreamId);
|
||||
expect(validJob).not.toBeNull();
|
||||
expect(validJob?.status).toBe('running');
|
||||
|
||||
await store.destroy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('appendChunk TTL Refresh', () => {
|
||||
test('should set TTL on the chunk stream', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisJobStore } = await import('../implementations/RedisJobStore');
|
||||
const store = new RedisJobStore(ioredisClient, { runningTtl: 120 });
|
||||
await store.initialize();
|
||||
|
||||
const streamId = `append-ttl-${Date.now()}`;
|
||||
await store.createJob(streamId, 'user-1', streamId);
|
||||
|
||||
await store.appendChunk(streamId, {
|
||||
event: 'on_message_delta',
|
||||
data: { id: 'step-1', type: 'text', text: 'first' },
|
||||
});
|
||||
|
||||
const chunkKey = `stream:{${streamId}}:chunks`;
|
||||
const ttl = await ioredisClient.ttl(chunkKey);
|
||||
expect(ttl).toBeGreaterThan(0);
|
||||
expect(ttl).toBeLessThanOrEqual(120);
|
||||
|
||||
await store.destroy();
|
||||
});
|
||||
|
||||
test('should refresh TTL on subsequent chunks (not just first)', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisJobStore } = await import('../implementations/RedisJobStore');
|
||||
const store = new RedisJobStore(ioredisClient, { runningTtl: 120 });
|
||||
await store.initialize();
|
||||
|
||||
const streamId = `append-refresh-${Date.now()}`;
|
||||
await store.createJob(streamId, 'user-1', streamId);
|
||||
|
||||
// Append first chunk
|
||||
await store.appendChunk(streamId, {
|
||||
event: 'on_message_delta',
|
||||
data: { id: 'step-1', type: 'text', text: 'first' },
|
||||
});
|
||||
|
||||
const chunkKey = `stream:{${streamId}}:chunks`;
|
||||
const ttl1 = await ioredisClient.ttl(chunkKey);
|
||||
expect(ttl1).toBeGreaterThan(0);
|
||||
|
||||
// Manually reduce TTL to simulate time passing
|
||||
await ioredisClient.expire(chunkKey, 30);
|
||||
const reducedTtl = await ioredisClient.ttl(chunkKey);
|
||||
expect(reducedTtl).toBeLessThanOrEqual(30);
|
||||
|
||||
// Append another chunk - TTL should be refreshed back to running TTL
|
||||
await store.appendChunk(streamId, {
|
||||
event: 'on_message_delta',
|
||||
data: { id: 'step-1', type: 'text', text: 'second' },
|
||||
});
|
||||
|
||||
const ttl2 = await ioredisClient.ttl(chunkKey);
|
||||
// Should be refreshed to ~120, not still ~30
|
||||
expect(ttl2).toBeGreaterThan(30);
|
||||
expect(ttl2).toBeLessThanOrEqual(120);
|
||||
|
||||
await store.destroy();
|
||||
});
|
||||
|
||||
test('should store chunks correctly via pipeline', async () => {
|
||||
if (!ioredisClient) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { RedisJobStore } = await import('../implementations/RedisJobStore');
|
||||
const store = new RedisJobStore(ioredisClient);
|
||||
await store.initialize();
|
||||
|
||||
const streamId = `append-pipeline-${Date.now()}`;
|
||||
await store.createJob(streamId, 'user-1', streamId);
|
||||
|
||||
const chunks = [
|
||||
{
|
||||
event: 'on_run_step',
|
||||
data: {
|
||||
id: 'step-1',
|
||||
runId: 'run-1',
|
||||
index: 0,
|
||||
stepDetails: { type: 'message_creation' },
|
||||
},
|
||||
},
|
||||
{
|
||||
event: 'on_message_delta',
|
||||
data: { id: 'step-1', delta: { content: { type: 'text', text: 'Hello ' } } },
|
||||
},
|
||||
{
|
||||
event: 'on_message_delta',
|
||||
data: { id: 'step-1', delta: { content: { type: 'text', text: 'world!' } } },
|
||||
},
|
||||
];
|
||||
|
||||
for (const chunk of chunks) {
|
||||
await store.appendChunk(streamId, chunk);
|
||||
}
|
||||
|
||||
// Verify all chunks were stored
|
||||
const chunkKey = `stream:{${streamId}}:chunks`;
|
||||
const len = await ioredisClient.xlen(chunkKey);
|
||||
expect(len).toBe(3);
|
||||
|
||||
// Verify content can be reconstructed
|
||||
const content = await store.getContentParts(streamId);
|
||||
expect(content).not.toBeNull();
|
||||
expect(content!.content.length).toBeGreaterThan(0);
|
||||
|
||||
await store.destroy();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ describe('CollectedUsage - GenerationJobManager', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `manager-test-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
@ -179,7 +179,7 @@ describe('CollectedUsage - GenerationJobManager', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `no-usage-test-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
@ -202,7 +202,7 @@ describe('CollectedUsage - GenerationJobManager', () => {
|
|||
isRedis: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const collectedUsage: UsageMetadata[] = [
|
||||
{ input_tokens: 100, output_tokens: 50, model: 'gpt-4' },
|
||||
|
|
@ -235,7 +235,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `text-extract-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
@ -267,7 +267,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `empty-text-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
@ -291,7 +291,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `full-abort-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
@ -328,7 +328,7 @@ describe('AbortJob - Text and CollectedUsage', () => {
|
|||
isRedis: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const abortResult = await GenerationJobManager.abortJob('non-existent-job');
|
||||
|
||||
|
|
@ -365,7 +365,7 @@ describe('Real-world Scenarios', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `parallel-abort-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
@ -419,7 +419,7 @@ describe('Real-world Scenarios', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `cache-abort-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
@ -459,7 +459,7 @@ describe('Real-world Scenarios', () => {
|
|||
cleanupOnComplete: false,
|
||||
});
|
||||
|
||||
await GenerationJobManager.initialize();
|
||||
GenerationJobManager.initialize();
|
||||
|
||||
const streamId = `sequential-abort-${Date.now()}`;
|
||||
await GenerationJobManager.createJob(streamId, 'user-1');
|
||||
|
|
|
|||
|
|
@ -0,0 +1,450 @@
|
|||
import type { Redis, Cluster } from 'ioredis';
|
||||
import { RedisEventTransport } from '~/stream/implementations/RedisEventTransport';
|
||||
import { GenerationJobManagerClass } from '~/stream/GenerationJobManager';
|
||||
import { createStreamServices } from '~/stream/createStreamServices';
|
||||
import {
|
||||
ioredisClient as staticRedisClient,
|
||||
keyvRedisClient as staticKeyvClient,
|
||||
keyvRedisClientReady,
|
||||
} from '~/cache/redisClients';
|
||||
|
||||
/**
|
||||
* Regression tests for the reconnect reorder buffer desync bug.
|
||||
*
|
||||
* Bug: When a user disconnects and reconnects to a stream multiple times,
|
||||
* the second+ reconnect lost chunks because the transport deleted stream state
|
||||
* on last unsubscribe, destroying the allSubscribersLeftCallbacks registered
|
||||
* by createJob(). This prevented hasSubscriber from being reset, which in turn
|
||||
* prevented syncReorderBuffer from being called on reconnect.
|
||||
*
|
||||
* Fix: Preserve stream state (callbacks, abort handlers) across reconnect cycles
|
||||
* instead of deleting it. The state is fully cleaned up by cleanup() when the
|
||||
* job completes.
|
||||
*
|
||||
* Run with: USE_REDIS=true npx jest reconnect-reorder-desync
|
||||
*/
|
||||
describe('Reconnect Reorder Buffer Desync (Regression)', () => {
|
||||
describe('Callback preservation across reconnect cycles (Unit)', () => {
|
||||
test('allSubscribersLeft callback fires on every disconnect, not just the first', () => {
|
||||
const mockPublisher = {
|
||||
publish: jest.fn().mockResolvedValue(1),
|
||||
};
|
||||
const mockSubscriber = {
|
||||
on: jest.fn(),
|
||||
subscribe: jest.fn().mockResolvedValue(undefined),
|
||||
unsubscribe: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const transport = new RedisEventTransport(
|
||||
mockPublisher as unknown as Redis,
|
||||
mockSubscriber as unknown as Redis,
|
||||
);
|
||||
|
||||
const streamId = 'callback-persist-test';
|
||||
let callbackFireCount = 0;
|
||||
|
||||
// Register callback (simulates what createJob does)
|
||||
transport.onAllSubscribersLeft(streamId, () => {
|
||||
callbackFireCount++;
|
||||
});
|
||||
|
||||
// First subscribe/unsubscribe cycle
|
||||
const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
|
||||
sub1.unsubscribe();
|
||||
|
||||
expect(callbackFireCount).toBe(1);
|
||||
|
||||
// Second subscribe/unsubscribe cycle — callback must still fire
|
||||
const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
|
||||
sub2.unsubscribe();
|
||||
|
||||
expect(callbackFireCount).toBe(2);
|
||||
|
||||
// Third cycle — continues to work
|
||||
const sub3 = transport.subscribe(streamId, { onChunk: () => {} });
|
||||
sub3.unsubscribe();
|
||||
|
||||
expect(callbackFireCount).toBe(3);
|
||||
|
||||
transport.destroy();
|
||||
});
|
||||
|
||||
test('abort callback survives across reconnect cycles', () => {
|
||||
const mockPublisher = {
|
||||
publish: jest.fn().mockResolvedValue(1),
|
||||
};
|
||||
const mockSubscriber = {
|
||||
on: jest.fn(),
|
||||
subscribe: jest.fn().mockResolvedValue(undefined),
|
||||
unsubscribe: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const transport = new RedisEventTransport(
|
||||
mockPublisher as unknown as Redis,
|
||||
mockSubscriber as unknown as Redis,
|
||||
);
|
||||
|
||||
const streamId = 'abort-callback-persist-test';
|
||||
let abortCallbackFired = false;
|
||||
|
||||
// Register abort callback (simulates what createJob does)
|
||||
transport.onAbort(streamId, () => {
|
||||
abortCallbackFired = true;
|
||||
});
|
||||
|
||||
// Subscribe/unsubscribe cycle
|
||||
const sub1 = transport.subscribe(streamId, { onChunk: () => {} });
|
||||
sub1.unsubscribe();
|
||||
|
||||
// Re-subscribe and receive an abort signal
|
||||
const sub2 = transport.subscribe(streamId, { onChunk: () => {} });
|
||||
|
||||
const messageHandler = mockSubscriber.on.mock.calls.find(
|
||||
(call) => call[0] === 'message',
|
||||
)?.[1] as (channel: string, message: string) => void;
|
||||
|
||||
const channel = `stream:{${streamId}}:events`;
|
||||
messageHandler(channel, JSON.stringify({ type: 'abort' }));
|
||||
|
||||
// Abort callback should fire — it was preserved across the reconnect
|
||||
expect(abortCallbackFired).toBe(true);
|
||||
|
||||
sub2.unsubscribe();
|
||||
transport.destroy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Reorder buffer sync on reconnect (Unit)', () => {
|
||||
/**
|
||||
* After the fix, the allSubscribersLeft callback fires on every disconnect,
|
||||
* which resets hasSubscriber. GenerationJobManager.subscribe() then enters
|
||||
* the if (!runtime.hasSubscriber) block and calls syncReorderBuffer.
|
||||
*
|
||||
* This test verifies at the transport level that when syncReorderBuffer IS
|
||||
* called (as it now will be on every reconnect), messages are delivered
|
||||
* immediately regardless of how many reconnect cycles have occurred.
|
||||
*/
|
||||
test('syncReorderBuffer works correctly on third+ reconnect', async () => {
|
||||
const mockPublisher = {
|
||||
publish: jest.fn().mockResolvedValue(1),
|
||||
};
|
||||
const mockSubscriber = {
|
||||
on: jest.fn(),
|
||||
subscribe: jest.fn().mockResolvedValue(undefined),
|
||||
unsubscribe: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const transport = new RedisEventTransport(
|
||||
mockPublisher as unknown as Redis,
|
||||
mockSubscriber as unknown as Redis,
|
||||
);
|
||||
|
||||
const streamId = 'reorder-multi-reconnect-test';
|
||||
|
||||
transport.onAllSubscribersLeft(streamId, () => {
|
||||
// Simulates the callback from createJob
|
||||
});
|
||||
|
||||
const messageHandler = mockSubscriber.on.mock.calls.find(
|
||||
(call) => call[0] === 'message',
|
||||
)?.[1] as (channel: string, message: string) => void;
|
||||
|
||||
const channel = `stream:{${streamId}}:events`;
|
||||
|
||||
// Run 3 full subscribe/emit/unsubscribe cycles
|
||||
for (let cycle = 0; cycle < 3; cycle++) {
|
||||
const chunks: unknown[] = [];
|
||||
const sub = transport.subscribe(streamId, {
|
||||
onChunk: (event) => chunks.push(event),
|
||||
});
|
||||
|
||||
// Sync reorder buffer (as GenerationJobManager.subscribe does)
|
||||
transport.syncReorderBuffer(streamId);
|
||||
|
||||
const baseSeq = cycle * 10;
|
||||
|
||||
// Emit 10 chunks (advances publisher sequence)
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await transport.emitChunk(streamId, { index: baseSeq + i });
|
||||
}
|
||||
|
||||
// Deliver messages via pub/sub handler
|
||||
for (let i = 0; i < 10; i++) {
|
||||
messageHandler(
|
||||
channel,
|
||||
JSON.stringify({ type: 'chunk', seq: baseSeq + i, data: { index: baseSeq + i } }),
|
||||
);
|
||||
}
|
||||
|
||||
// Messages should be delivered immediately on every cycle
|
||||
expect(chunks.length).toBe(10);
|
||||
expect(chunks.map((c) => (c as { index: number }).index)).toEqual(
|
||||
Array.from({ length: 10 }, (_, i) => baseSeq + i),
|
||||
);
|
||||
|
||||
sub.unsubscribe();
|
||||
}
|
||||
|
||||
transport.destroy();
|
||||
});
|
||||
|
||||
test('reorder buffer works correctly when syncReorderBuffer IS called', async () => {
|
||||
const mockPublisher = {
|
||||
publish: jest.fn().mockResolvedValue(1),
|
||||
};
|
||||
const mockSubscriber = {
|
||||
on: jest.fn(),
|
||||
subscribe: jest.fn().mockResolvedValue(undefined),
|
||||
unsubscribe: jest.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
|
||||
const transport = new RedisEventTransport(
|
||||
mockPublisher as unknown as Redis,
|
||||
mockSubscriber as unknown as Redis,
|
||||
);
|
||||
|
||||
const streamId = 'reorder-sync-test';
|
||||
|
||||
// Emit 20 chunks to advance publisher sequence
|
||||
for (let i = 0; i < 20; i++) {
|
||||
await transport.emitChunk(streamId, { index: i });
|
||||
}
|
||||
|
||||
// Subscribe and sync the reorder buffer
|
||||
const chunks: unknown[] = [];
|
||||
const sub = transport.subscribe(streamId, {
|
||||
onChunk: (event) => chunks.push(event),
|
||||
});
|
||||
|
||||
// This is the critical call - sync nextSeq to match publisher
|
||||
transport.syncReorderBuffer(streamId);
|
||||
|
||||
// Deliver messages starting at seq 20
|
||||
const messageHandler = mockSubscriber.on.mock.calls.find(
|
||||
(call) => call[0] === 'message',
|
||||
)?.[1] as (channel: string, message: string) => void;
|
||||
|
||||
const channel = `stream:{${streamId}}:events`;
|
||||
|
||||
for (let i = 20; i < 25; i++) {
|
||||
messageHandler(channel, JSON.stringify({ type: 'chunk', seq: i, data: { index: i } }));
|
||||
}
|
||||
|
||||
// Messages should be delivered IMMEDIATELY (no 500ms wait)
|
||||
// because nextSeq was synced to 20
|
||||
expect(chunks.length).toBe(5);
|
||||
expect(chunks.map((c) => (c as { index: number }).index)).toEqual([20, 21, 22, 23, 24]);
|
||||
|
||||
sub.unsubscribe();
|
||||
transport.destroy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('End-to-end reconnect with GenerationJobManager (Integration)', () => {
|
||||
let originalEnv: NodeJS.ProcessEnv;
|
||||
let ioredisClient: Redis | Cluster | null = null;
|
||||
let dynamicKeyvClient: unknown = null;
|
||||
let dynamicKeyvReady: Promise<unknown> | null = null;
|
||||
const testPrefix = 'ReconnectDesync-Test';
|
||||
|
||||
beforeAll(async () => {
|
||||
originalEnv = { ...process.env };
|
||||
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX = testPrefix;
|
||||
|
||||
jest.resetModules();
|
||||
|
||||
const redisModule = await import('~/cache/redisClients');
|
||||
ioredisClient = redisModule.ioredisClient;
|
||||
dynamicKeyvClient = redisModule.keyvRedisClient;
|
||||
dynamicKeyvReady = redisModule.keyvRedisClientReady;
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
jest.resetModules();
|
||||
|
||||
if (ioredisClient) {
|
||||
try {
|
||||
const keys = await ioredisClient.keys(`${testPrefix}*`);
|
||||
const streamKeys = await ioredisClient.keys('stream:*');
|
||||
const allKeys = [...keys, ...streamKeys];
|
||||
await Promise.all(allKeys.map((key) => ioredisClient!.del(key)));
|
||||
} catch {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
for (const ready of [keyvRedisClientReady, dynamicKeyvReady]) {
|
||||
if (ready) {
|
||||
await ready.catch(() => {});
|
||||
}
|
||||
}
|
||||
|
||||
const clients = [ioredisClient, staticRedisClient, staticKeyvClient, dynamicKeyvClient];
|
||||
for (const client of clients) {
|
||||
if (!client) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
await (client as { disconnect: () => void | Promise<void> }).disconnect();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
}
|
||||
|
||||
process.env = originalEnv;
|
||||
});
|
||||
|
||||
/**
|
||||
* Verifies that all reconnect cycles deliver chunks immediately —
|
||||
* not just the first reconnect.
|
||||
*/
|
||||
test('chunks are delivered immediately on every reconnect cycle', async () => {
|
||||
if (!ioredisClient) {
|
||||
console.warn('Redis not available, skipping test');
|
||||
return;
|
||||
}
|
||||
|
||||
const manager = new GenerationJobManagerClass();
|
||||
const services = createStreamServices({
|
||||
useRedis: true,
|
||||
redisClient: ioredisClient,
|
||||
});
|
||||
|
||||
manager.configure(services);
|
||||
manager.initialize();
|
||||
|
||||
const streamId = `reconnect-fixed-${Date.now()}`;
|
||||
await manager.createJob(streamId, 'user-1');
|
||||
|
||||
// Run 3 subscribe/emit/unsubscribe cycles
|
||||
for (let cycle = 0; cycle < 3; cycle++) {
|
||||
const chunks: unknown[] = [];
|
||||
const sub = await manager.subscribe(streamId, (event) => chunks.push(event));
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Emit 10 chunks
|
||||
for (let i = 0; i < 10; i++) {
|
||||
await manager.emitChunk(streamId, {
|
||||
event: 'on_message_delta',
|
||||
data: {
|
||||
delta: { content: { type: 'text', text: `c${cycle}-${i}` } },
|
||||
index: cycle * 10 + i,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// Chunks should arrive within 200ms (well under the 500ms force-flush timeout)
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
|
||||
expect(chunks.length).toBe(10);
|
||||
|
||||
sub!.unsubscribe();
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
}
|
||||
|
||||
await manager.destroy();
|
||||
});
|
||||
|
||||
/**
|
||||
* Verifies that syncSent is correctly reset on every disconnect,
|
||||
* proving the onAllSubscribersLeft callback survives reconnect cycles.
|
||||
*/
|
||||
test('onAllSubscribersLeft callback resets state on every disconnect', async () => {
|
||||
if (!ioredisClient) {
|
||||
console.warn('Redis not available, skipping test');
|
||||
return;
|
||||
}
|
||||
|
||||
const manager = new GenerationJobManagerClass();
|
||||
const services = createStreamServices({
|
||||
useRedis: true,
|
||||
redisClient: ioredisClient,
|
||||
});
|
||||
|
||||
manager.configure(services);
|
||||
manager.initialize();
|
||||
|
||||
const streamId = `callback-persist-integ-${Date.now()}`;
|
||||
await manager.createJob(streamId, 'user-1');
|
||||
|
||||
for (let cycle = 0; cycle < 3; cycle++) {
|
||||
const sub = await manager.subscribe(streamId, () => {});
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
// Mark sync as sent
|
||||
manager.markSyncSent(streamId);
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
|
||||
let syncSent = await manager.wasSyncSent(streamId);
|
||||
expect(syncSent).toBe(true);
|
||||
|
||||
// Disconnect
|
||||
sub!.unsubscribe();
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Callback should reset syncSent on every disconnect
|
||||
syncSent = await manager.wasSyncSent(streamId);
|
||||
expect(syncSent).toBe(false);
|
||||
}
|
||||
|
||||
await manager.destroy();
|
||||
});
|
||||
|
||||
/**
|
||||
* Verifies all reconnect cycles deliver chunks immediately with no
|
||||
* increasing gap pattern.
|
||||
*/
|
||||
test('no increasing gap pattern across reconnect cycles', async () => {
|
||||
if (!ioredisClient) {
|
||||
console.warn('Redis not available, skipping test');
|
||||
return;
|
||||
}
|
||||
|
||||
const manager = new GenerationJobManagerClass();
|
||||
const services = createStreamServices({
|
||||
useRedis: true,
|
||||
redisClient: ioredisClient,
|
||||
});
|
||||
|
||||
manager.configure(services);
|
||||
manager.initialize();
|
||||
|
||||
const streamId = `no-gaps-${Date.now()}`;
|
||||
await manager.createJob(streamId, 'user-1');
|
||||
|
||||
const chunksPerCycle = 15;
|
||||
|
||||
for (let cycle = 0; cycle < 4; cycle++) {
|
||||
const chunks: unknown[] = [];
|
||||
const sub = await manager.subscribe(streamId, (event) => chunks.push(event));
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Emit chunks
|
||||
for (let i = 0; i < chunksPerCycle; i++) {
|
||||
await manager.emitChunk(streamId, {
|
||||
event: 'on_message_delta',
|
||||
data: {
|
||||
delta: { content: { type: 'text', text: `c${cycle}-${i}` } },
|
||||
index: cycle * chunksPerCycle + i,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
// All chunks should arrive within 200ms on every cycle
|
||||
await new Promise((resolve) => setTimeout(resolve, 200));
|
||||
expect(chunks.length).toBe(chunksPerCycle);
|
||||
|
||||
sub!.unsubscribe();
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
}
|
||||
|
||||
await manager.destroy();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -32,7 +32,7 @@ export class InMemoryEventTransport implements IEventTransport {
|
|||
onDone?: (event: unknown) => void;
|
||||
onError?: (error: string) => void;
|
||||
},
|
||||
): { unsubscribe: () => void } {
|
||||
): { unsubscribe: () => void; ready?: Promise<void> } {
|
||||
const state = this.getOrCreateStream(streamId);
|
||||
|
||||
const chunkHandler = (event: unknown) => handlers.onChunk(event);
|
||||
|
|
@ -58,9 +58,11 @@ export class InMemoryEventTransport implements IEventTransport {
|
|||
// Check if all subscribers left - cleanup and notify
|
||||
if (currentState.emitter.listenerCount('chunk') === 0) {
|
||||
currentState.allSubscribersLeftCallback?.();
|
||||
// Auto-cleanup the stream entry when no subscribers remain
|
||||
/* Remove all EventEmitter listeners but preserve stream state
|
||||
* (including allSubscribersLeftCallback) for reconnection.
|
||||
* State is fully cleaned up by cleanup() when the job completes.
|
||||
*/
|
||||
currentState.emitter.removeAllListeners();
|
||||
this.streams.delete(streamId);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -92,8 +92,8 @@ export class RedisEventTransport implements IEventTransport {
|
|||
private subscriber: Redis | Cluster;
|
||||
/** Track subscribers per stream */
|
||||
private streams = new Map<string, StreamSubscribers>();
|
||||
/** Track which channels we're subscribed to */
|
||||
private subscribedChannels = new Set<string>();
|
||||
/** Track channel subscription state: resolved promise = active, pending = in-flight */
|
||||
private channelSubscriptions = new Map<string, Promise<void>>();
|
||||
/** Counter for generating unique subscriber IDs */
|
||||
private subscriberIdCounter = 0;
|
||||
/** Sequence counters per stream for publishing (ensures ordered delivery in cluster mode) */
|
||||
|
|
@ -122,9 +122,32 @@ export class RedisEventTransport implements IEventTransport {
|
|||
return current;
|
||||
}
|
||||
|
||||
/** Reset sequence counter for a stream */
|
||||
private resetSequence(streamId: string): void {
|
||||
/** Reset publish sequence counter and subscriber reorder state for a stream (full cleanup only) */
|
||||
resetSequence(streamId: string): void {
|
||||
this.sequenceCounters.delete(streamId);
|
||||
const state = this.streams.get(streamId);
|
||||
if (state) {
|
||||
if (state.reorderBuffer.flushTimeout) {
|
||||
clearTimeout(state.reorderBuffer.flushTimeout);
|
||||
state.reorderBuffer.flushTimeout = null;
|
||||
}
|
||||
state.reorderBuffer.nextSeq = 0;
|
||||
state.reorderBuffer.pending.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/** Advance subscriber reorder buffer to current publisher sequence without resetting publisher (cross-replica safe) */
|
||||
syncReorderBuffer(streamId: string): void {
|
||||
const currentSeq = this.sequenceCounters.get(streamId) ?? 0;
|
||||
const state = this.streams.get(streamId);
|
||||
if (state) {
|
||||
if (state.reorderBuffer.flushTimeout) {
|
||||
clearTimeout(state.reorderBuffer.flushTimeout);
|
||||
state.reorderBuffer.flushTimeout = null;
|
||||
}
|
||||
state.reorderBuffer.nextSeq = currentSeq;
|
||||
state.reorderBuffer.pending.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -331,7 +354,7 @@ export class RedisEventTransport implements IEventTransport {
|
|||
onDone?: (event: unknown) => void;
|
||||
onError?: (error: string) => void;
|
||||
},
|
||||
): { unsubscribe: () => void } {
|
||||
): { unsubscribe: () => void; ready?: Promise<void> } {
|
||||
const channel = CHANNELS.events(streamId);
|
||||
const subscriberId = `sub_${++this.subscriberIdCounter}`;
|
||||
|
||||
|
|
@ -354,16 +377,23 @@ export class RedisEventTransport implements IEventTransport {
|
|||
streamState.count++;
|
||||
streamState.handlers.set(subscriberId, handlers);
|
||||
|
||||
// Subscribe to Redis channel if this is first subscriber
|
||||
if (!this.subscribedChannels.has(channel)) {
|
||||
this.subscribedChannels.add(channel);
|
||||
this.subscriber.subscribe(channel).catch((err) => {
|
||||
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
|
||||
});
|
||||
let readyPromise = this.channelSubscriptions.get(channel);
|
||||
|
||||
if (!readyPromise) {
|
||||
readyPromise = this.subscriber
|
||||
.subscribe(channel)
|
||||
.then(() => {
|
||||
logger.debug(`[RedisEventTransport] Subscription active for channel ${channel}`);
|
||||
})
|
||||
.catch((err) => {
|
||||
this.channelSubscriptions.delete(channel);
|
||||
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
|
||||
});
|
||||
this.channelSubscriptions.set(channel, readyPromise);
|
||||
}
|
||||
|
||||
// Return unsubscribe function
|
||||
return {
|
||||
ready: readyPromise,
|
||||
unsubscribe: () => {
|
||||
const state = this.streams.get(streamId);
|
||||
if (!state) {
|
||||
|
|
@ -385,7 +415,7 @@ export class RedisEventTransport implements IEventTransport {
|
|||
this.subscriber.unsubscribe(channel).catch((err) => {
|
||||
logger.error(`[RedisEventTransport] Failed to unsubscribe from ${channel}:`, err);
|
||||
});
|
||||
this.subscribedChannels.delete(channel);
|
||||
this.channelSubscriptions.delete(channel);
|
||||
|
||||
// Call all-subscribers-left callbacks
|
||||
for (const callback of state.allSubscribersLeftCallbacks) {
|
||||
|
|
@ -395,8 +425,15 @@ export class RedisEventTransport implements IEventTransport {
|
|||
logger.error(`[RedisEventTransport] Error in allSubscribersLeft callback:`, err);
|
||||
}
|
||||
}
|
||||
|
||||
this.streams.delete(streamId);
|
||||
/**
|
||||
* Preserve stream state (callbacks, abort handlers) for reconnection.
|
||||
* Previously this deleted the entire state, which lost the
|
||||
* allSubscribersLeftCallbacks and abortCallbacks registered by
|
||||
* GenerationJobManager.createJob(). On the next subscribe() call,
|
||||
* fresh state was created without those callbacks, causing
|
||||
* hasSubscriber to never reset and syncReorderBuffer to be skipped.
|
||||
* State is fully cleaned up by cleanup() when the job completes.
|
||||
*/
|
||||
}
|
||||
},
|
||||
};
|
||||
|
|
@ -431,6 +468,7 @@ export class RedisEventTransport implements IEventTransport {
|
|||
await this.publisher.publish(channel, JSON.stringify(message));
|
||||
} catch (err) {
|
||||
logger.error(`[RedisEventTransport] Failed to publish done:`, err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -447,6 +485,7 @@ export class RedisEventTransport implements IEventTransport {
|
|||
await this.publisher.publish(channel, JSON.stringify(message));
|
||||
} catch (err) {
|
||||
logger.error(`[RedisEventTransport] Failed to publish error:`, err);
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -532,12 +571,15 @@ export class RedisEventTransport implements IEventTransport {
|
|||
|
||||
state.abortCallbacks.push(callback);
|
||||
|
||||
// Subscribe to Redis channel if not already subscribed
|
||||
if (!this.subscribedChannels.has(channel)) {
|
||||
this.subscribedChannels.add(channel);
|
||||
this.subscriber.subscribe(channel).catch((err) => {
|
||||
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
|
||||
});
|
||||
if (!this.channelSubscriptions.has(channel)) {
|
||||
const ready = this.subscriber
|
||||
.subscribe(channel)
|
||||
.then(() => {})
|
||||
.catch((err) => {
|
||||
this.channelSubscriptions.delete(channel);
|
||||
logger.error(`[RedisEventTransport] Failed to subscribe to ${channel}:`, err);
|
||||
});
|
||||
this.channelSubscriptions.set(channel, ready);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -571,12 +613,11 @@ export class RedisEventTransport implements IEventTransport {
|
|||
// Reset sequence counter for this stream
|
||||
this.resetSequence(streamId);
|
||||
|
||||
// Unsubscribe from Redis channel
|
||||
if (this.subscribedChannels.has(channel)) {
|
||||
if (this.channelSubscriptions.has(channel)) {
|
||||
this.subscriber.unsubscribe(channel).catch((err) => {
|
||||
logger.error(`[RedisEventTransport] Failed to cleanup ${channel}:`, err);
|
||||
});
|
||||
this.subscribedChannels.delete(channel);
|
||||
this.channelSubscriptions.delete(channel);
|
||||
}
|
||||
|
||||
this.streams.delete(streamId);
|
||||
|
|
@ -595,18 +636,20 @@ export class RedisEventTransport implements IEventTransport {
|
|||
state.reorderBuffer.pending.clear();
|
||||
}
|
||||
|
||||
// Unsubscribe from all channels
|
||||
for (const channel of this.subscribedChannels) {
|
||||
this.subscriber.unsubscribe(channel).catch(() => {
|
||||
// Ignore errors during shutdown
|
||||
});
|
||||
for (const channel of this.channelSubscriptions.keys()) {
|
||||
this.subscriber.unsubscribe(channel).catch(() => {});
|
||||
}
|
||||
|
||||
this.subscribedChannels.clear();
|
||||
this.channelSubscriptions.clear();
|
||||
this.streams.clear();
|
||||
this.sequenceCounters.clear();
|
||||
|
||||
// Note: Don't close Redis connections - they may be shared
|
||||
try {
|
||||
this.subscriber.disconnect();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
|
||||
logger.info('[RedisEventTransport] Destroyed');
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -156,13 +156,13 @@ export class RedisJobStore implements IJobStore {
|
|||
// For cluster mode, we can't pipeline keys on different slots
|
||||
// The job key uses hash tag {streamId}, runningJobs and userJobs are on different slots
|
||||
if (this.isCluster) {
|
||||
await this.redis.hmset(key, this.serializeJob(job));
|
||||
await this.redis.hset(key, this.serializeJob(job));
|
||||
await this.redis.expire(key, this.ttl.running);
|
||||
await this.redis.sadd(KEYS.runningJobs, streamId);
|
||||
await this.redis.sadd(userJobsKey, streamId);
|
||||
} else {
|
||||
const pipeline = this.redis.pipeline();
|
||||
pipeline.hmset(key, this.serializeJob(job));
|
||||
pipeline.hset(key, this.serializeJob(job));
|
||||
pipeline.expire(key, this.ttl.running);
|
||||
pipeline.sadd(KEYS.runningJobs, streamId);
|
||||
pipeline.sadd(userJobsKey, streamId);
|
||||
|
|
@ -183,17 +183,23 @@ export class RedisJobStore implements IJobStore {
|
|||
|
||||
async updateJob(streamId: string, updates: Partial<SerializableJobData>): Promise<void> {
|
||||
const key = KEYS.job(streamId);
|
||||
const exists = await this.redis.exists(key);
|
||||
if (!exists) {
|
||||
return;
|
||||
}
|
||||
|
||||
const serialized = this.serializeJob(updates as SerializableJobData);
|
||||
if (Object.keys(serialized).length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
await this.redis.hmset(key, serialized);
|
||||
const fields = Object.entries(serialized).flat();
|
||||
const updated = await this.redis.eval(
|
||||
'if redis.call("EXISTS", KEYS[1]) == 1 then redis.call("HSET", KEYS[1], unpack(ARGV)) return 1 else return 0 end',
|
||||
1,
|
||||
key,
|
||||
...fields,
|
||||
);
|
||||
|
||||
if (updated === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If status changed to complete/error/aborted, update TTL and remove from running set
|
||||
// Note: userJobs cleanup is handled lazily via self-healing in getActiveJobIdsByUser
|
||||
|
|
@ -296,32 +302,46 @@ export class RedisJobStore implements IJobStore {
|
|||
}
|
||||
}
|
||||
|
||||
for (const streamId of streamIds) {
|
||||
const job = await this.getJob(streamId);
|
||||
// Process in batches of 50 to avoid sequential per-job round-trips
|
||||
const BATCH_SIZE = 50;
|
||||
for (let i = 0; i < streamIds.length; i += BATCH_SIZE) {
|
||||
const batch = streamIds.slice(i, i + BATCH_SIZE);
|
||||
const results = await Promise.allSettled(
|
||||
batch.map(async (streamId) => {
|
||||
const job = await this.getJob(streamId);
|
||||
|
||||
// Job no longer exists (TTL expired) - remove from set
|
||||
if (!job) {
|
||||
await this.redis.srem(KEYS.runningJobs, streamId);
|
||||
this.localGraphCache.delete(streamId);
|
||||
this.localCollectedUsageCache.delete(streamId);
|
||||
cleaned++;
|
||||
continue;
|
||||
}
|
||||
// Job no longer exists (TTL expired) - remove from set
|
||||
if (!job) {
|
||||
await this.redis.srem(KEYS.runningJobs, streamId);
|
||||
this.localGraphCache.delete(streamId);
|
||||
this.localCollectedUsageCache.delete(streamId);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Job completed but still in running set (shouldn't happen, but handle it)
|
||||
if (job.status !== 'running') {
|
||||
await this.redis.srem(KEYS.runningJobs, streamId);
|
||||
this.localGraphCache.delete(streamId);
|
||||
this.localCollectedUsageCache.delete(streamId);
|
||||
cleaned++;
|
||||
continue;
|
||||
}
|
||||
// Job completed but still in running set (shouldn't happen, but handle it)
|
||||
if (job.status !== 'running') {
|
||||
await this.redis.srem(KEYS.runningJobs, streamId);
|
||||
this.localGraphCache.delete(streamId);
|
||||
this.localCollectedUsageCache.delete(streamId);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Stale running job (failsafe - running for > configured TTL)
|
||||
if (now - job.createdAt > this.ttl.running * 1000) {
|
||||
logger.warn(`[RedisJobStore] Cleaning up stale job: ${streamId}`);
|
||||
await this.deleteJob(streamId);
|
||||
cleaned++;
|
||||
// Stale running job (failsafe - running for > configured TTL)
|
||||
if (now - job.createdAt > this.ttl.running * 1000) {
|
||||
logger.warn(`[RedisJobStore] Cleaning up stale job: ${streamId}`);
|
||||
await this.deleteJob(streamId);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}),
|
||||
);
|
||||
for (const result of results) {
|
||||
if (result.status === 'fulfilled') {
|
||||
cleaned += result.value;
|
||||
} else {
|
||||
logger.warn(`[RedisJobStore] Cleanup failed for a job:`, result.reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -586,16 +606,14 @@ export class RedisJobStore implements IJobStore {
|
|||
*/
|
||||
async appendChunk(streamId: string, event: unknown): Promise<void> {
|
||||
const key = KEYS.chunks(streamId);
|
||||
const added = await this.redis.xadd(key, '*', 'event', JSON.stringify(event));
|
||||
|
||||
// Set TTL on first chunk (when stream is created)
|
||||
// Subsequent chunks inherit the stream's TTL
|
||||
if (added) {
|
||||
const len = await this.redis.xlen(key);
|
||||
if (len === 1) {
|
||||
await this.redis.expire(key, this.ttl.running);
|
||||
}
|
||||
}
|
||||
// Pipeline XADD + EXPIRE in a single round-trip.
|
||||
// EXPIRE is O(1) and idempotent — refreshing TTL on every chunk is better than
|
||||
// only setting it once, since the original approach could let the TTL expire
|
||||
// during long-running streams.
|
||||
const pipeline = this.redis.pipeline();
|
||||
pipeline.xadd(key, '*', 'event', JSON.stringify(event));
|
||||
pipeline.expire(key, this.ttl.running);
|
||||
await pipeline.exec();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -286,7 +286,7 @@ export interface IJobStore {
|
|||
* Implementations can use EventEmitter, Redis Pub/Sub, etc.
|
||||
*/
|
||||
export interface IEventTransport {
|
||||
/** Subscribe to events for a stream */
|
||||
/** Subscribe to events for a stream. `ready` resolves once the transport can receive messages. */
|
||||
subscribe(
|
||||
streamId: string,
|
||||
handlers: {
|
||||
|
|
@ -294,7 +294,7 @@ export interface IEventTransport {
|
|||
onDone?: (event: unknown) => void;
|
||||
onError?: (error: string) => void;
|
||||
},
|
||||
): { unsubscribe: () => void };
|
||||
): { unsubscribe: () => void; ready?: Promise<void> };
|
||||
|
||||
/** Publish a chunk event - returns Promise in Redis mode for ordered delivery */
|
||||
emitChunk(streamId: string, event: unknown): void | Promise<void>;
|
||||
|
|
@ -329,6 +329,12 @@ export interface IEventTransport {
|
|||
/** Listen for all subscribers leaving */
|
||||
onAllSubscribersLeft(streamId: string, callback: () => void): void;
|
||||
|
||||
/** Reset publish sequence counter for a stream (used during full stream cleanup) */
|
||||
resetSequence?(streamId: string): void;
|
||||
|
||||
/** Advance subscriber reorder buffer to match publisher sequence (cross-replica safe: doesn't reset publisher counter) */
|
||||
syncReorderBuffer?(streamId: string): void;
|
||||
|
||||
/** Cleanup transport resources for a specific stream */
|
||||
cleanup(streamId: string): void;
|
||||
|
||||
|
|
|
|||
|
|
@ -8,9 +8,10 @@
|
|||
import { Constants, actionDelimiter } from 'librechat-data-provider';
|
||||
import type { AgentToolOptions } from 'librechat-data-provider';
|
||||
import type { LCToolRegistry, JsonSchemaType, LCTool, GenericTool } from '@librechat/agents';
|
||||
import { buildToolClassification, type ToolDefinition } from './classification';
|
||||
import type { ToolDefinition } from './classification';
|
||||
import { resolveJsonSchemaRefs, normalizeJsonSchema } from '~/mcp/zod';
|
||||
import { buildToolClassification } from './classification';
|
||||
import { getToolDefinition } from './registry/definitions';
|
||||
import { resolveJsonSchemaRefs } from '~/mcp/zod';
|
||||
|
||||
export interface MCPServerTool {
|
||||
function?: {
|
||||
|
|
@ -138,7 +139,7 @@ export async function loadToolDefinitions(
|
|||
name: actualToolName,
|
||||
description: toolDef.function.description,
|
||||
parameters: toolDef.function.parameters
|
||||
? resolveJsonSchemaRefs(toolDef.function.parameters)
|
||||
? normalizeJsonSchema(resolveJsonSchemaRefs(toolDef.function.parameters))
|
||||
: undefined,
|
||||
serverName,
|
||||
});
|
||||
|
|
@ -153,7 +154,7 @@ export async function loadToolDefinitions(
|
|||
name: toolName,
|
||||
description: toolDef.function.description,
|
||||
parameters: toolDef.function.parameters
|
||||
? resolveJsonSchemaRefs(toolDef.function.parameters)
|
||||
? normalizeJsonSchema(resolveJsonSchemaRefs(toolDef.function.parameters))
|
||||
: undefined,
|
||||
serverName,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import type { Request } from 'express';
|
||||
import type { IUser, AppConfig } from '@librechat/data-schemas';
|
||||
import type { TEndpointOption } from 'librechat-data-provider';
|
||||
import type { Request } from 'express';
|
||||
|
||||
/**
|
||||
* LibreChat-specific request body type that extends Express Request body
|
||||
|
|
@ -11,8 +12,10 @@ export type RequestBody = {
|
|||
conversationId?: string;
|
||||
parentMessageId?: string;
|
||||
endpoint?: string;
|
||||
endpointType?: string;
|
||||
model?: string;
|
||||
key?: string;
|
||||
endpointOption?: Partial<TEndpointOption>;
|
||||
};
|
||||
|
||||
export type ServerRequest = Request<unknown, unknown, RequestBody> & {
|
||||
|
|
|
|||
|
|
@ -427,6 +427,35 @@ describe('OpenID Token Utilities', () => {
|
|||
expect(result).toContain('User:');
|
||||
});
|
||||
|
||||
it('should resolve LIBRECHAT_OPENID_ID_TOKEN and LIBRECHAT_OPENID_ACCESS_TOKEN to different values', () => {
|
||||
const user: Partial<TUser> = {
|
||||
id: 'user-123',
|
||||
provider: 'openid',
|
||||
openidId: 'oidc-sub-456',
|
||||
email: 'test@example.com',
|
||||
name: 'Test User',
|
||||
federatedTokens: {
|
||||
access_token: 'my-access-token',
|
||||
id_token: 'my-id-token',
|
||||
refresh_token: 'my-refresh-token',
|
||||
expires_at: Math.floor(Date.now() / 1000) + 3600,
|
||||
},
|
||||
};
|
||||
|
||||
const tokenInfo = extractOpenIDTokenInfo(user);
|
||||
expect(tokenInfo).not.toBeNull();
|
||||
expect(tokenInfo!.accessToken).toBe('my-access-token');
|
||||
expect(tokenInfo!.idToken).toBe('my-id-token');
|
||||
expect(tokenInfo!.accessToken).not.toBe(tokenInfo!.idToken);
|
||||
|
||||
const input = 'ACCESS={{LIBRECHAT_OPENID_ACCESS_TOKEN}}, ID={{LIBRECHAT_OPENID_ID_TOKEN}}';
|
||||
const result = processOpenIDPlaceholders(input, tokenInfo!);
|
||||
|
||||
expect(result).toBe('ACCESS=my-access-token, ID=my-id-token');
|
||||
// Verify they are not the same value (the reported bug)
|
||||
expect(result).not.toBe('ACCESS=my-access-token, ID=my-access-token');
|
||||
});
|
||||
|
||||
it('should handle expired tokens correctly', () => {
|
||||
const user: Partial<TUser> = {
|
||||
id: 'user-123',
|
||||
|
|
|
|||
|
|
@ -148,6 +148,7 @@ const anthropicModels = {
|
|||
'claude-3.5-sonnet-latest': 200000,
|
||||
'claude-haiku-4-5': 200000,
|
||||
'claude-sonnet-4': 1000000,
|
||||
'claude-sonnet-4-6': 1000000,
|
||||
'claude-4': 200000,
|
||||
'claude-opus-4': 200000,
|
||||
'claude-opus-4-5': 200000,
|
||||
|
|
@ -197,6 +198,8 @@ const moonshotModels = {
|
|||
'moonshot.kimi-k2.5': 262144,
|
||||
'moonshot.kimi-k2-thinking': 262144,
|
||||
'moonshot.kimi-k2-0711': 131072,
|
||||
'moonshotai.kimi': 262144,
|
||||
'moonshotai.kimi-k2.5': 262144,
|
||||
};
|
||||
|
||||
const metaModels = {
|
||||
|
|
@ -308,6 +311,11 @@ const amazonModels = {
|
|||
'nova-premier': 995000, // -5000 from max
|
||||
};
|
||||
|
||||
const openAIBedrockModels = {
|
||||
'openai.gpt-oss-20b': 128000,
|
||||
'openai.gpt-oss-120b': 128000,
|
||||
};
|
||||
|
||||
const bedrockModels = {
|
||||
...anthropicModels,
|
||||
...mistralModels,
|
||||
|
|
@ -317,6 +325,7 @@ const bedrockModels = {
|
|||
...metaModels,
|
||||
...ai21Models,
|
||||
...amazonModels,
|
||||
...openAIBedrockModels,
|
||||
};
|
||||
|
||||
const xAIModels = {
|
||||
|
|
@ -393,6 +402,7 @@ const anthropicMaxOutputs = {
|
|||
'claude-3-opus': 4096,
|
||||
'claude-haiku-4-5': 64000,
|
||||
'claude-sonnet-4': 64000,
|
||||
'claude-sonnet-4-6': 64000,
|
||||
'claude-opus-4': 32000,
|
||||
'claude-opus-4-5': 64000,
|
||||
'claude-opus-4-6': 128000,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { forwardRef, ReactNode, Ref } from 'react';
|
||||
import { forwardRef, isValidElement, ReactNode, Ref } from 'react';
|
||||
import {
|
||||
OGDialogTitle,
|
||||
OGDialogClose,
|
||||
|
|
@ -19,13 +19,39 @@ type SelectionProps = {
|
|||
isLoading?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Type guard to check if selection is a legacy SelectionProps object
|
||||
*/
|
||||
function isSelectionProps(selection: unknown): selection is SelectionProps {
|
||||
return (
|
||||
typeof selection === 'object' &&
|
||||
selection !== null &&
|
||||
!isValidElement(selection) &&
|
||||
('selectHandler' in selection ||
|
||||
'selectClasses' in selection ||
|
||||
'selectText' in selection ||
|
||||
'isLoading' in selection)
|
||||
);
|
||||
}
|
||||
|
||||
type DialogTemplateProps = {
|
||||
title: string;
|
||||
description?: string;
|
||||
main?: ReactNode;
|
||||
buttons?: ReactNode;
|
||||
leftButtons?: ReactNode;
|
||||
selection?: SelectionProps;
|
||||
/**
|
||||
* Selection button configuration. Can be either:
|
||||
* - An object with selectHandler, selectClasses, selectText, isLoading (legacy)
|
||||
* - A ReactNode for custom selection component
|
||||
* @example
|
||||
* // Legacy usage
|
||||
* selection={{ selectHandler: () => {}, selectText: 'Confirm' }}
|
||||
* @example
|
||||
* // Custom component
|
||||
* selection={<Button onClick={handleConfirm}>Confirm</Button>}
|
||||
*/
|
||||
selection?: SelectionProps | ReactNode;
|
||||
className?: string;
|
||||
overlayClassName?: string;
|
||||
headerClassName?: string;
|
||||
|
|
@ -49,14 +75,39 @@ const OGDialogTemplate = forwardRef((props: DialogTemplateProps, ref: Ref<HTMLDi
|
|||
mainClassName,
|
||||
headerClassName,
|
||||
footerClassName,
|
||||
showCloseButton,
|
||||
showCloseButton = false,
|
||||
overlayClassName,
|
||||
showCancelButton = true,
|
||||
} = props;
|
||||
const { selectHandler, selectClasses, selectText, isLoading } = selection || {};
|
||||
const isLegacySelection = isSelectionProps(selection);
|
||||
const { selectHandler, selectClasses, selectText, isLoading } = isLegacySelection
|
||||
? selection
|
||||
: {};
|
||||
|
||||
const defaultSelect =
|
||||
'bg-gray-800 text-white transition-colors hover:bg-gray-700 disabled:cursor-not-allowed disabled:opacity-50 dark:bg-gray-200 dark:text-gray-800 dark:hover:bg-gray-200';
|
||||
|
||||
let selectionContent = null;
|
||||
if (isLegacySelection) {
|
||||
selectionContent = (
|
||||
<OGDialogClose
|
||||
onClick={selectHandler}
|
||||
disabled={isLoading}
|
||||
className={`${
|
||||
selectClasses ?? defaultSelect
|
||||
} flex h-10 items-center justify-center rounded-lg border-none px-4 py-2 text-sm disabled:opacity-80 max-sm:order-first max-sm:w-full sm:order-none`}
|
||||
>
|
||||
{isLoading === true ? (
|
||||
<Spinner className="size-4 text-text-primary" />
|
||||
) : (
|
||||
(selectText as React.JSX.Element)
|
||||
)}
|
||||
</OGDialogClose>
|
||||
);
|
||||
} else if (selection) {
|
||||
selectionContent = selection;
|
||||
}
|
||||
|
||||
return (
|
||||
<OGDialogContent
|
||||
overlayClassName={overlayClassName}
|
||||
|
|
@ -75,38 +126,18 @@ const OGDialogTemplate = forwardRef((props: DialogTemplateProps, ref: Ref<HTMLDi
|
|||
</OGDialogHeader>
|
||||
<div className={cn('px-0 py-2', mainClassName)}>{main != null ? main : null}</div>
|
||||
<OGDialogFooter className={footerClassName}>
|
||||
<div>
|
||||
{leftButtons != null ? (
|
||||
<div className="mt-3 flex h-auto gap-3 max-sm:w-full max-sm:flex-col sm:mt-0 sm:flex-row">
|
||||
{leftButtons}
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="flex h-auto gap-3 max-sm:w-full max-sm:flex-col sm:flex-row">
|
||||
{showCancelButton && (
|
||||
<OGDialogClose asChild>
|
||||
<Button variant="outline" aria-label={localize('com_ui_cancel')}>
|
||||
{localize('com_ui_cancel')}
|
||||
</Button>
|
||||
</OGDialogClose>
|
||||
)}
|
||||
{buttons != null ? buttons : null}
|
||||
{selection ? (
|
||||
<OGDialogClose
|
||||
onClick={selectHandler}
|
||||
disabled={isLoading}
|
||||
className={`${
|
||||
selectClasses ?? defaultSelect
|
||||
} flex h-10 items-center justify-center rounded-lg border-none px-4 py-2 text-sm disabled:opacity-80 max-sm:order-first max-sm:w-full sm:order-none`}
|
||||
>
|
||||
{isLoading === true ? (
|
||||
<Spinner className="size-4 text-white" />
|
||||
) : (
|
||||
(selectText as React.JSX.Element)
|
||||
)}
|
||||
</OGDialogClose>
|
||||
) : null}
|
||||
</div>
|
||||
{leftButtons != null ? (
|
||||
<div className="mr-auto flex flex-row gap-2">{leftButtons}</div>
|
||||
) : null}
|
||||
{showCancelButton && (
|
||||
<OGDialogClose asChild>
|
||||
<Button variant="outline" aria-label={localize('com_ui_cancel')}>
|
||||
{localize('com_ui_cancel')}
|
||||
</Button>
|
||||
</OGDialogClose>
|
||||
)}
|
||||
{buttons != null ? buttons : null}
|
||||
{selectionContent}
|
||||
</OGDialogFooter>
|
||||
</OGDialogContent>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ interface RadioProps {
|
|||
disabled?: boolean;
|
||||
className?: string;
|
||||
fullWidth?: boolean;
|
||||
'aria-labelledby'?: string;
|
||||
}
|
||||
|
||||
const Radio = memo(function Radio({
|
||||
|
|
@ -23,6 +24,7 @@ const Radio = memo(function Radio({
|
|||
disabled = false,
|
||||
className = '',
|
||||
fullWidth = false,
|
||||
'aria-labelledby': ariaLabelledBy,
|
||||
}: RadioProps) {
|
||||
const localize = useLocalize();
|
||||
const buttonRefs = useRef<(HTMLButtonElement | null)[]>([]);
|
||||
|
|
@ -79,6 +81,7 @@ const Radio = memo(function Radio({
|
|||
<div
|
||||
className="relative inline-flex items-center rounded-lg bg-muted p-1 opacity-50"
|
||||
role="radiogroup"
|
||||
aria-labelledby={ariaLabelledBy}
|
||||
>
|
||||
<span className="px-4 py-2 text-xs text-muted-foreground">
|
||||
{localize('com_ui_no_options')}
|
||||
|
|
@ -93,6 +96,7 @@ const Radio = memo(function Radio({
|
|||
<div
|
||||
className={`relative ${fullWidth ? 'flex' : 'inline-flex'} items-center rounded-lg bg-muted p-1 ${className}`}
|
||||
role="radiogroup"
|
||||
aria-labelledby={ariaLabelledBy}
|
||||
>
|
||||
{selectedIndex >= 0 && isMounted && (
|
||||
<div
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ const Theme = ({ theme, onChange }: { theme: string; onChange: (value: string) =
|
|||
|
||||
const themeIcons: Record<ThemeType, JSX.Element> = {
|
||||
system: <Monitor aria-hidden="true" />,
|
||||
dark: <Moon color="white" aria-hidden="true" />,
|
||||
dark: <Moon aria-hidden="true" />,
|
||||
light: <Sun aria-hidden="true" />,
|
||||
};
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ const Theme = ({ theme, onChange }: { theme: string; onChange: (value: string) =
|
|||
|
||||
return (
|
||||
<button
|
||||
className="flex items-center gap-2 rounded-lg p-2 transition-colors hover:bg-surface-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-blue-600 focus-visible:ring-offset-2 dark:focus-visible:ring-0"
|
||||
className="flex items-center gap-2 rounded-lg p-2 text-text-primary transition-colors hover:bg-surface-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-blue-600 focus-visible:ring-offset-2 dark:focus-visible:ring-0"
|
||||
aria-label={localize('com_ui_toggle_theme')}
|
||||
aria-keyshortcuts="Ctrl+Shift+T"
|
||||
onClick={(e) => {
|
||||
|
|
@ -77,13 +77,6 @@ const ThemeSelector = ({ returnThemeOnly }: { returnThemeOnly?: boolean }) => {
|
|||
[setTheme, localize],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (theme === 'system') {
|
||||
const prefersDarkScheme = window.matchMedia('(prefers-color-scheme: dark)').matches;
|
||||
setTheme(prefersDarkScheme ? 'dark' : 'light');
|
||||
}
|
||||
}, [theme, setTheme]);
|
||||
|
||||
useEffect(() => {
|
||||
if (announcement) {
|
||||
const timeout = setTimeout(() => setAnnouncement(''), 1000);
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
// This file is kept for backward compatibility but is no longer used internally.
|
||||
// Theme state is now managed via React useState + localStorage in ThemeProvider.
|
||||
|
||||
import { atomWithStorage } from 'jotai/utils';
|
||||
import { IThemeRGB } from '../types';
|
||||
|
||||
/**
|
||||
* Atom for storing the theme mode (light/dark/system) in localStorage
|
||||
* Key: 'color-theme'
|
||||
* @deprecated Use ThemeContext instead. This atom is no longer used internally.
|
||||
*/
|
||||
export const themeModeAtom = atomWithStorage<string>('color-theme', 'system', undefined, {
|
||||
getOnInit: true,
|
||||
});
|
||||
|
||||
/**
|
||||
* Atom for storing custom theme colors in localStorage
|
||||
* Key: 'theme-colors'
|
||||
* @deprecated Use ThemeContext instead. This atom is no longer used internally.
|
||||
*/
|
||||
export const themeColorsAtom = atomWithStorage<IThemeRGB | undefined>(
|
||||
'theme-colors',
|
||||
|
|
@ -23,8 +24,7 @@ export const themeColorsAtom = atomWithStorage<IThemeRGB | undefined>(
|
|||
);
|
||||
|
||||
/**
|
||||
* Atom for storing the theme name in localStorage
|
||||
* Key: 'theme-name'
|
||||
* @deprecated Use ThemeContext instead. This atom is no longer used internally.
|
||||
*/
|
||||
export const themeNameAtom = atomWithStorage<string | undefined>(
|
||||
'theme-name',
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
import React, { createContext, useContext, useEffect, useMemo, useCallback, useRef } from 'react';
|
||||
import { useAtom } from 'jotai';
|
||||
import React, { createContext, useContext, useEffect, useMemo, useCallback, useState, useRef } from 'react';
|
||||
import { IThemeRGB } from '../types';
|
||||
import applyTheme from '../utils/applyTheme';
|
||||
import { themeModeAtom, themeColorsAtom, themeNameAtom } from '../atoms/themeAtoms';
|
||||
|
||||
const THEME_KEY = 'color-theme';
|
||||
const THEME_COLORS_KEY = 'theme-colors';
|
||||
const THEME_NAME_KEY = 'theme-name';
|
||||
|
||||
type ThemeContextType = {
|
||||
theme: string; // 'light' | 'dark' | 'system'
|
||||
|
|
@ -40,6 +42,70 @@ export const isDark = (theme: string): boolean => {
|
|||
return theme === 'dark';
|
||||
};
|
||||
|
||||
/**
|
||||
* Validate that a parsed value looks like an IThemeRGB object
|
||||
*/
|
||||
const isValidThemeColors = (value: unknown): value is IThemeRGB => {
|
||||
if (typeof value !== 'object' || value === null || Array.isArray(value)) {
|
||||
return false;
|
||||
}
|
||||
for (const key of Object.keys(value)) {
|
||||
const val = (value as Record<string, unknown>)[key];
|
||||
if (val !== undefined && typeof val !== 'string') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get initial theme from localStorage or default to 'system'
|
||||
*/
|
||||
const getInitialTheme = (): string => {
|
||||
if (typeof window === 'undefined') return 'system';
|
||||
try {
|
||||
const stored = localStorage.getItem(THEME_KEY);
|
||||
if (stored && ['light', 'dark', 'system'].includes(stored)) {
|
||||
return stored;
|
||||
}
|
||||
} catch {
|
||||
// localStorage not available
|
||||
}
|
||||
return 'system';
|
||||
};
|
||||
|
||||
/**
|
||||
* Get initial theme colors from localStorage
|
||||
*/
|
||||
const getInitialThemeColors = (): IThemeRGB | undefined => {
|
||||
if (typeof window === 'undefined') return undefined;
|
||||
try {
|
||||
const stored = localStorage.getItem(THEME_COLORS_KEY);
|
||||
if (stored) {
|
||||
const parsed = JSON.parse(stored);
|
||||
if (isValidThemeColors(parsed)) {
|
||||
return parsed;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// localStorage not available or invalid JSON
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get initial theme name from localStorage
|
||||
*/
|
||||
const getInitialThemeName = (): string | undefined => {
|
||||
if (typeof window === 'undefined') return undefined;
|
||||
try {
|
||||
return localStorage.getItem(THEME_NAME_KEY) || undefined;
|
||||
} catch {
|
||||
// localStorage not available
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
/**
|
||||
* ThemeProvider component that handles both dark/light mode switching
|
||||
* and dynamic color themes via CSS variables with localStorage persistence
|
||||
|
|
@ -50,102 +116,128 @@ export function ThemeProvider({
|
|||
themeName: propThemeName,
|
||||
initialTheme,
|
||||
}: ThemeProviderProps) {
|
||||
// Use jotai atoms for persistent state
|
||||
const [theme, setTheme] = useAtom(themeModeAtom);
|
||||
const [storedThemeRGB, setStoredThemeRGB] = useAtom(themeColorsAtom);
|
||||
const [storedThemeName, setStoredThemeName] = useAtom(themeNameAtom);
|
||||
const [theme, setThemeState] = useState<string>(getInitialTheme);
|
||||
const [themeRGB, setThemeRGBState] = useState<IThemeRGB | undefined>(getInitialThemeColors);
|
||||
const [themeName, setThemeNameState] = useState<string | undefined>(getInitialThemeName);
|
||||
|
||||
// Track if props have been initialized
|
||||
const propsInitialized = useRef(false);
|
||||
const initialized = useRef(false);
|
||||
|
||||
const setTheme = useCallback((newTheme: string) => {
|
||||
setThemeState(newTheme);
|
||||
if (typeof window === 'undefined') return;
|
||||
try {
|
||||
localStorage.setItem(THEME_KEY, newTheme);
|
||||
} catch {
|
||||
// localStorage not available
|
||||
}
|
||||
}, []);
|
||||
|
||||
const setThemeRGB = useCallback((colors?: IThemeRGB) => {
|
||||
setThemeRGBState(colors);
|
||||
if (typeof window === 'undefined') return;
|
||||
try {
|
||||
if (colors) {
|
||||
localStorage.setItem(THEME_COLORS_KEY, JSON.stringify(colors));
|
||||
} else {
|
||||
localStorage.removeItem(THEME_COLORS_KEY);
|
||||
}
|
||||
} catch {
|
||||
// localStorage not available
|
||||
}
|
||||
}, []);
|
||||
|
||||
const setThemeName = useCallback((name?: string) => {
|
||||
setThemeNameState(name);
|
||||
if (typeof window === 'undefined') return;
|
||||
try {
|
||||
if (name) {
|
||||
localStorage.setItem(THEME_NAME_KEY, name);
|
||||
} else {
|
||||
localStorage.removeItem(THEME_NAME_KEY);
|
||||
}
|
||||
} catch {
|
||||
// localStorage not available
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Initialize from props only once on mount
|
||||
useEffect(() => {
|
||||
if (!propsInitialized.current) {
|
||||
propsInitialized.current = true;
|
||||
if (initialized.current) return;
|
||||
initialized.current = true;
|
||||
|
||||
// Set initial theme if provided
|
||||
if (initialTheme) {
|
||||
setTheme(initialTheme);
|
||||
}
|
||||
|
||||
// Set initial theme colors if provided
|
||||
if (propThemeRGB) {
|
||||
setStoredThemeRGB(propThemeRGB);
|
||||
}
|
||||
|
||||
// Set initial theme name if provided
|
||||
if (propThemeName) {
|
||||
setStoredThemeName(propThemeName);
|
||||
}
|
||||
// Set initial theme if provided
|
||||
if (initialTheme) {
|
||||
setTheme(initialTheme);
|
||||
}
|
||||
}, [initialTheme, propThemeRGB, propThemeName, setTheme, setStoredThemeRGB, setStoredThemeName]);
|
||||
|
||||
// Set initial theme colors if provided
|
||||
if (propThemeRGB) {
|
||||
setThemeRGB(propThemeRGB);
|
||||
}
|
||||
|
||||
// Set initial theme name if provided
|
||||
if (propThemeName) {
|
||||
setThemeName(propThemeName);
|
||||
}
|
||||
}, [initialTheme, propThemeRGB, propThemeName, setTheme, setThemeRGB, setThemeName]);
|
||||
|
||||
// Apply class-based dark mode
|
||||
const applyThemeMode = useCallback((rawTheme: string) => {
|
||||
const applyThemeMode = useCallback((currentTheme: string) => {
|
||||
const root = window.document.documentElement;
|
||||
const darkMode = isDark(rawTheme);
|
||||
const darkMode = isDark(currentTheme);
|
||||
|
||||
root.classList.remove(darkMode ? 'light' : 'dark');
|
||||
root.classList.add(darkMode ? 'dark' : 'light');
|
||||
}, []);
|
||||
|
||||
// Handle system theme changes
|
||||
useEffect(() => {
|
||||
const mediaQuery = window.matchMedia('(prefers-color-scheme: dark)');
|
||||
const changeThemeOnSystemChange = () => {
|
||||
if (theme === 'system') {
|
||||
applyThemeMode('system');
|
||||
}
|
||||
};
|
||||
|
||||
mediaQuery.addEventListener('change', changeThemeOnSystemChange);
|
||||
return () => {
|
||||
mediaQuery.removeEventListener('change', changeThemeOnSystemChange);
|
||||
};
|
||||
}, [theme, applyThemeMode]);
|
||||
|
||||
// Apply dark/light mode class
|
||||
// Apply theme mode whenever theme changes
|
||||
useEffect(() => {
|
||||
applyThemeMode(theme);
|
||||
}, [theme, applyThemeMode]);
|
||||
|
||||
// Listen for system theme changes when theme is 'system'
|
||||
useEffect(() => {
|
||||
if (theme !== 'system') return;
|
||||
|
||||
const mediaQuery = window.matchMedia('(prefers-color-scheme: dark)');
|
||||
const handleChange = () => {
|
||||
applyThemeMode('system');
|
||||
};
|
||||
|
||||
mediaQuery.addEventListener('change', handleChange);
|
||||
return () => mediaQuery.removeEventListener('change', handleChange);
|
||||
}, [theme, applyThemeMode]);
|
||||
|
||||
// Apply dynamic color theme
|
||||
useEffect(() => {
|
||||
if (storedThemeRGB) {
|
||||
applyTheme(storedThemeRGB);
|
||||
if (themeRGB) {
|
||||
applyTheme(themeRGB);
|
||||
}
|
||||
}, [storedThemeRGB]);
|
||||
}, [themeRGB]);
|
||||
|
||||
// Reset theme function
|
||||
const resetTheme = useCallback(() => {
|
||||
setTheme('system');
|
||||
setStoredThemeRGB(undefined);
|
||||
setStoredThemeName(undefined);
|
||||
setThemeRGB(undefined);
|
||||
setThemeName(undefined);
|
||||
// Remove any custom CSS variables
|
||||
const root = document.documentElement;
|
||||
const customProps = Array.from(root.style).filter((prop) => prop.startsWith('--'));
|
||||
customProps.forEach((prop) => root.style.removeProperty(prop));
|
||||
}, [setTheme, setStoredThemeRGB, setStoredThemeName]);
|
||||
}, [setTheme, setThemeRGB, setThemeName]);
|
||||
|
||||
const value = useMemo(
|
||||
() => ({
|
||||
theme,
|
||||
setTheme,
|
||||
themeRGB: storedThemeRGB,
|
||||
setThemeRGB: setStoredThemeRGB,
|
||||
themeName: storedThemeName,
|
||||
setThemeName: setStoredThemeName,
|
||||
themeRGB,
|
||||
setThemeRGB,
|
||||
themeName,
|
||||
setThemeName,
|
||||
resetTheme,
|
||||
}),
|
||||
[
|
||||
theme,
|
||||
setTheme,
|
||||
storedThemeRGB,
|
||||
setStoredThemeRGB,
|
||||
storedThemeName,
|
||||
setStoredThemeName,
|
||||
resetTheme,
|
||||
],
|
||||
[theme, setTheme, themeRGB, setThemeRGB, themeName, setThemeName, resetTheme],
|
||||
);
|
||||
|
||||
return <ThemeContext.Provider value={value}>{children}</ThemeContext.Provider>;
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@
|
|||
},
|
||||
"homepage": "https://librechat.ai",
|
||||
"dependencies": {
|
||||
"axios": "^1.12.1",
|
||||
"axios": "^1.13.5",
|
||||
"dayjs": "^1.11.13",
|
||||
"js-yaml": "^4.1.1",
|
||||
"zod": "^3.22.4"
|
||||
|
|
|
|||
18
packages/data-provider/react-query/package-lock.json
generated
18
packages/data-provider/react-query/package-lock.json
generated
|
|
@ -6,7 +6,7 @@
|
|||
"": {
|
||||
"name": "librechat-data-provider/react-query",
|
||||
"dependencies": {
|
||||
"axios": "^1.12.1"
|
||||
"axios": "^1.13.5"
|
||||
}
|
||||
},
|
||||
"node_modules/asynckit": {
|
||||
|
|
@ -16,13 +16,13 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/axios": {
|
||||
"version": "1.12.1",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.12.1.tgz",
|
||||
"integrity": "sha512-Kn4kbSXpkFHCGE6rBFNwIv0GQs4AvDT80jlveJDKFxjbTYMUeB4QtsdPCv6H8Cm19Je7IU6VFtRl2zWZI0rudQ==",
|
||||
"version": "1.13.5",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.13.5.tgz",
|
||||
"integrity": "sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"follow-redirects": "^1.15.6",
|
||||
"form-data": "^4.0.4",
|
||||
"follow-redirects": "^1.15.11",
|
||||
"form-data": "^4.0.5",
|
||||
"proxy-from-env": "^1.1.0"
|
||||
}
|
||||
},
|
||||
|
|
@ -140,9 +140,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/form-data": {
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz",
|
||||
"integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==",
|
||||
"version": "4.0.5",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz",
|
||||
"integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"asynckit": "^0.4.0",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,6 @@
|
|||
"module": "./index.es.js",
|
||||
"types": "../dist/types/react-query/index.d.ts",
|
||||
"dependencies": {
|
||||
"axios": "^1.12.1"
|
||||
"axios": "^1.13.5"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -459,6 +459,82 @@ describe('ActionRequest', () => {
|
|||
await expect(actionRequest.execute()).rejects.toThrow('Unsupported HTTP method: invalid');
|
||||
});
|
||||
|
||||
describe('SSRF-safe agent passthrough', () => {
|
||||
beforeEach(() => {
|
||||
mockedAxios.get.mockResolvedValue({ data: { success: true } });
|
||||
mockedAxios.post.mockResolvedValue({ data: { success: true } });
|
||||
});
|
||||
|
||||
it('should pass httpAgent and httpsAgent to axios.create when provided', async () => {
|
||||
const mockHttpAgent = { keepAlive: true };
|
||||
const mockHttpsAgent = { keepAlive: true };
|
||||
|
||||
const actionRequest = new ActionRequest(
|
||||
'https://example.com',
|
||||
'/test',
|
||||
'GET',
|
||||
'testOp',
|
||||
false,
|
||||
'application/json',
|
||||
);
|
||||
const executor = actionRequest.createExecutor();
|
||||
executor.setParams({ key: 'value' });
|
||||
await executor.execute({ httpAgent: mockHttpAgent, httpsAgent: mockHttpsAgent });
|
||||
|
||||
expect(mockedAxios.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
httpAgent: mockHttpAgent,
|
||||
httpsAgent: mockHttpsAgent,
|
||||
maxRedirects: 0,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not include agent keys when no options are provided', async () => {
|
||||
const actionRequest = new ActionRequest(
|
||||
'https://example.com',
|
||||
'/test',
|
||||
'GET',
|
||||
'testOp',
|
||||
false,
|
||||
'application/json',
|
||||
);
|
||||
const executor = actionRequest.createExecutor();
|
||||
executor.setParams({ key: 'value' });
|
||||
await executor.execute();
|
||||
|
||||
const createArg = mockedAxios.create.mock.calls[
|
||||
mockedAxios.create.mock.calls.length - 1
|
||||
][0] as Record<string, unknown>;
|
||||
expect(createArg).not.toHaveProperty('httpAgent');
|
||||
expect(createArg).not.toHaveProperty('httpsAgent');
|
||||
});
|
||||
|
||||
it('should pass agents through for POST requests', async () => {
|
||||
const mockAgent = { ssrf: true };
|
||||
|
||||
const actionRequest = new ActionRequest(
|
||||
'https://example.com',
|
||||
'/test',
|
||||
'POST',
|
||||
'testOp',
|
||||
false,
|
||||
'application/json',
|
||||
);
|
||||
const executor = actionRequest.createExecutor();
|
||||
executor.setParams({ body: 'data' });
|
||||
await executor.execute({ httpAgent: mockAgent, httpsAgent: mockAgent });
|
||||
|
||||
expect(mockedAxios.create).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
httpAgent: mockAgent,
|
||||
httpsAgent: mockAgent,
|
||||
}),
|
||||
);
|
||||
expect(mockedAxios.post).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('ActionRequest Concurrent Execution', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
|
|
|||
|
|
@ -46,6 +46,30 @@ describe('supportsAdaptiveThinking', () => {
|
|||
expect(supportsAdaptiveThinking('claude-opus-4-0')).toBe(false);
|
||||
});
|
||||
|
||||
test('should return true for claude-sonnet-4-6', () => {
|
||||
expect(supportsAdaptiveThinking('claude-sonnet-4-6')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for claude-sonnet-4.6', () => {
|
||||
expect(supportsAdaptiveThinking('claude-sonnet-4.6')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for claude-sonnet-4-7 (future)', () => {
|
||||
expect(supportsAdaptiveThinking('claude-sonnet-4-7')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for anthropic.claude-sonnet-4-6 (Bedrock)', () => {
|
||||
expect(supportsAdaptiveThinking('anthropic.claude-sonnet-4-6')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for us.anthropic.claude-sonnet-4-6 (cross-region Bedrock)', () => {
|
||||
expect(supportsAdaptiveThinking('us.anthropic.claude-sonnet-4-6')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for claude-4-6-sonnet (alternate naming)', () => {
|
||||
expect(supportsAdaptiveThinking('claude-4-6-sonnet')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false for claude-sonnet-4-5', () => {
|
||||
expect(supportsAdaptiveThinking('claude-sonnet-4-5')).toBe(false);
|
||||
});
|
||||
|
|
@ -104,6 +128,14 @@ describe('supportsContext1m', () => {
|
|||
expect(supportsContext1m('claude-sonnet-4-5')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for claude-sonnet-4-6', () => {
|
||||
expect(supportsContext1m('claude-sonnet-4-6')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for anthropic.claude-sonnet-4-6 (Bedrock)', () => {
|
||||
expect(supportsContext1m('anthropic.claude-sonnet-4-6')).toBe(true);
|
||||
});
|
||||
|
||||
test('should return true for claude-sonnet-5 (future)', () => {
|
||||
expect(supportsContext1m('claude-sonnet-5')).toBe(true);
|
||||
});
|
||||
|
|
@ -237,14 +269,42 @@ describe('bedrockInputParser', () => {
|
|||
]);
|
||||
});
|
||||
|
||||
test('should match anthropic.claude-4-7-sonnet model with 1M context header', () => {
|
||||
test('should match anthropic.claude-sonnet-4-6 with adaptive thinking and 1M context header', () => {
|
||||
const input = {
|
||||
model: 'anthropic.claude-sonnet-4-6',
|
||||
};
|
||||
const result = bedrockInputParser.parse(input) as Record<string, unknown>;
|
||||
const additionalFields = result.additionalModelRequestFields as Record<string, unknown>;
|
||||
expect(additionalFields.thinking).toEqual({ type: 'adaptive' });
|
||||
expect(additionalFields.thinkingBudget).toBeUndefined();
|
||||
expect(additionalFields.anthropic_beta).toEqual([
|
||||
'output-128k-2025-02-19',
|
||||
'context-1m-2025-08-07',
|
||||
]);
|
||||
});
|
||||
|
||||
test('should match us.anthropic.claude-sonnet-4-6 with adaptive thinking and 1M context header', () => {
|
||||
const input = {
|
||||
model: 'us.anthropic.claude-sonnet-4-6',
|
||||
};
|
||||
const result = bedrockInputParser.parse(input) as Record<string, unknown>;
|
||||
const additionalFields = result.additionalModelRequestFields as Record<string, unknown>;
|
||||
expect(additionalFields.thinking).toEqual({ type: 'adaptive' });
|
||||
expect(additionalFields.thinkingBudget).toBeUndefined();
|
||||
expect(additionalFields.anthropic_beta).toEqual([
|
||||
'output-128k-2025-02-19',
|
||||
'context-1m-2025-08-07',
|
||||
]);
|
||||
});
|
||||
|
||||
test('should match anthropic.claude-4-7-sonnet model with adaptive thinking and 1M context header', () => {
|
||||
const input = {
|
||||
model: 'anthropic.claude-4-7-sonnet',
|
||||
};
|
||||
const result = bedrockInputParser.parse(input) as Record<string, unknown>;
|
||||
const additionalFields = result.additionalModelRequestFields as Record<string, unknown>;
|
||||
expect(additionalFields.thinking).toBe(true);
|
||||
expect(additionalFields.thinkingBudget).toBe(2000);
|
||||
expect(additionalFields.thinking).toEqual({ type: 'adaptive' });
|
||||
expect(additionalFields.thinkingBudget).toBeUndefined();
|
||||
expect(additionalFields.anthropic_beta).toEqual([
|
||||
'output-128k-2025-02-19',
|
||||
'context-1m-2025-08-07',
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
import { replaceSpecialVars, parseCompactConvo, parseTextParts } from '../src/parsers';
|
||||
import { replaceSpecialVars, parseConvo, parseCompactConvo, parseTextParts } from '../src/parsers';
|
||||
import { specialVariables } from '../src/config';
|
||||
import { EModelEndpoint } from '../src/schemas';
|
||||
import { ContentTypes } from '../src/types/runs';
|
||||
|
|
@ -262,6 +262,257 @@ describe('parseCompactConvo', () => {
|
|||
});
|
||||
});
|
||||
|
||||
describe('parseConvo - defaultParamsEndpoint', () => {
|
||||
test('should strip maxOutputTokens for custom endpoint without defaultParamsEndpoint', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'anthropic/claude-opus-4.5',
|
||||
temperature: 0.7,
|
||||
maxOutputTokens: 8192,
|
||||
maxContextTokens: 50000,
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.temperature).toBe(0.7);
|
||||
expect(result?.maxContextTokens).toBe(50000);
|
||||
expect(result?.maxOutputTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should preserve maxOutputTokens when defaultParamsEndpoint is anthropic', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'anthropic/claude-opus-4.5',
|
||||
temperature: 0.7,
|
||||
maxOutputTokens: 8192,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
maxContextTokens: 50000,
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
defaultParamsEndpoint: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.model).toBe('anthropic/claude-opus-4.5');
|
||||
expect(result?.temperature).toBe(0.7);
|
||||
expect(result?.maxOutputTokens).toBe(8192);
|
||||
expect(result?.topP).toBe(0.9);
|
||||
expect(result?.topK).toBe(40);
|
||||
expect(result?.maxContextTokens).toBe(50000);
|
||||
});
|
||||
|
||||
test('should strip OpenAI-specific fields when defaultParamsEndpoint is anthropic', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'anthropic/claude-opus-4.5',
|
||||
temperature: 0.7,
|
||||
max_tokens: 4096,
|
||||
top_p: 0.9,
|
||||
presence_penalty: 0.5,
|
||||
frequency_penalty: 0.3,
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
defaultParamsEndpoint: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.temperature).toBe(0.7);
|
||||
expect(result?.max_tokens).toBeUndefined();
|
||||
expect(result?.top_p).toBeUndefined();
|
||||
expect(result?.presence_penalty).toBeUndefined();
|
||||
expect(result?.frequency_penalty).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should preserve max_tokens when defaultParamsEndpoint is not set (OpenAI default)', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'gpt-4o',
|
||||
temperature: 0.7,
|
||||
max_tokens: 4096,
|
||||
top_p: 0.9,
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.max_tokens).toBe(4096);
|
||||
expect(result?.top_p).toBe(0.9);
|
||||
});
|
||||
|
||||
test('should preserve Google-specific fields when defaultParamsEndpoint is google', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'gemini-pro',
|
||||
temperature: 0.7,
|
||||
maxOutputTokens: 8192,
|
||||
topP: 0.9,
|
||||
topK: 40,
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
defaultParamsEndpoint: EModelEndpoint.google,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.maxOutputTokens).toBe(8192);
|
||||
expect(result?.topP).toBe(0.9);
|
||||
expect(result?.topK).toBe(40);
|
||||
});
|
||||
|
||||
test('should not strip fields from non-custom endpoints that already have a schema', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'gpt-4o',
|
||||
temperature: 0.7,
|
||||
max_tokens: 4096,
|
||||
top_p: 0.9,
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
conversation,
|
||||
defaultParamsEndpoint: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.max_tokens).toBe(4096);
|
||||
expect(result?.top_p).toBe(0.9);
|
||||
});
|
||||
|
||||
test('should not carry bedrock region to custom endpoint without defaultParamsEndpoint', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'gpt-4o',
|
||||
temperature: 0.7,
|
||||
region: 'us-east-1',
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.temperature).toBe(0.7);
|
||||
expect(result?.region).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should fall back to endpointType schema when defaultParamsEndpoint is invalid', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'gpt-4o',
|
||||
temperature: 0.7,
|
||||
max_tokens: 4096,
|
||||
maxOutputTokens: 8192,
|
||||
};
|
||||
|
||||
const result = parseConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
defaultParamsEndpoint: 'nonexistent_endpoint',
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.max_tokens).toBe(4096);
|
||||
expect(result?.maxOutputTokens).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseCompactConvo - defaultParamsEndpoint', () => {
|
||||
test('should strip maxOutputTokens for custom endpoint without defaultParamsEndpoint', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'anthropic/claude-opus-4.5',
|
||||
temperature: 0.7,
|
||||
maxOutputTokens: 8192,
|
||||
};
|
||||
|
||||
const result = parseCompactConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.temperature).toBe(0.7);
|
||||
expect(result?.maxOutputTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should preserve maxOutputTokens when defaultParamsEndpoint is anthropic', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'anthropic/claude-opus-4.5',
|
||||
temperature: 0.7,
|
||||
maxOutputTokens: 8192,
|
||||
topP: 0.9,
|
||||
maxContextTokens: 50000,
|
||||
};
|
||||
|
||||
const result = parseCompactConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
defaultParamsEndpoint: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.maxOutputTokens).toBe(8192);
|
||||
expect(result?.topP).toBe(0.9);
|
||||
expect(result?.maxContextTokens).toBe(50000);
|
||||
});
|
||||
|
||||
test('should strip iconURL even when defaultParamsEndpoint is set', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'anthropic/claude-opus-4.5',
|
||||
iconURL: 'https://malicious.com/track.png',
|
||||
maxOutputTokens: 8192,
|
||||
};
|
||||
|
||||
const result = parseCompactConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
defaultParamsEndpoint: EModelEndpoint.anthropic,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.['iconURL']).toBeUndefined();
|
||||
expect(result?.maxOutputTokens).toBe(8192);
|
||||
});
|
||||
|
||||
test('should fall back to endpointType when defaultParamsEndpoint is null', () => {
|
||||
const conversation: Partial<TConversation> = {
|
||||
model: 'gpt-4o',
|
||||
max_tokens: 4096,
|
||||
maxOutputTokens: 8192,
|
||||
};
|
||||
|
||||
const result = parseCompactConvo({
|
||||
endpoint: 'MyCustomEndpoint' as EModelEndpoint,
|
||||
endpointType: EModelEndpoint.custom,
|
||||
conversation,
|
||||
defaultParamsEndpoint: null,
|
||||
});
|
||||
|
||||
expect(result).not.toBeNull();
|
||||
expect(result?.max_tokens).toBe(4096);
|
||||
expect(result?.maxOutputTokens).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseTextParts', () => {
|
||||
test('should concatenate text parts', () => {
|
||||
const parts: TMessageContentParts[] = [
|
||||
|
|
|
|||
|
|
@ -283,7 +283,7 @@ class RequestExecutor {
|
|||
return this;
|
||||
}
|
||||
|
||||
async execute() {
|
||||
async execute(options?: { httpAgent?: unknown; httpsAgent?: unknown }) {
|
||||
const url = createURL(this.config.domain, this.path);
|
||||
const headers: Record<string, string> = {
|
||||
...this.authHeaders,
|
||||
|
|
@ -300,10 +300,15 @@ class RequestExecutor {
|
|||
*
|
||||
* By setting maxRedirects: 0, we prevent this attack vector.
|
||||
* The action will receive the redirect response (3xx) instead of following it.
|
||||
*
|
||||
* SECURITY: When httpAgent/httpsAgent are provided (SSRF-safe agents), they validate
|
||||
* the DNS-resolved IP at TCP connect time, preventing TOCTOU DNS rebinding attacks.
|
||||
*/
|
||||
const axios = _axios.create({
|
||||
maxRedirects: 0,
|
||||
validateStatus: (status) => status >= 200 && status < 400, // Accept 3xx but don't follow
|
||||
validateStatus: (status) => status >= 200 && status < 400,
|
||||
...(options?.httpAgent != null ? { httpAgent: options.httpAgent } : {}),
|
||||
...(options?.httpsAgent != null ? { httpsAgent: options.httpsAgent } : {}),
|
||||
});
|
||||
|
||||
// Initialize separate containers for query and body parameters.
|
||||
|
|
|
|||
|
|
@ -181,6 +181,11 @@ export const cancelMCPOAuth = (serverName: string) => {
|
|||
return `${BASE_URL}/api/mcp/oauth/cancel/${serverName}`;
|
||||
};
|
||||
|
||||
export const mcpOAuthBind = (serverName: string) => `${BASE_URL}/api/mcp/${serverName}/oauth/bind`;
|
||||
|
||||
export const actionOAuthBind = (actionId: string) =>
|
||||
`${BASE_URL}/api/actions/${actionId}/oauth/bind`;
|
||||
|
||||
export const config = () => `${BASE_URL}/api/config`;
|
||||
|
||||
export const prompts = () => `${BASE_URL}/api/prompts`;
|
||||
|
|
|
|||
|
|
@ -35,27 +35,34 @@ function parseOpusVersion(model: string): { major: number; minor: number } | nul
|
|||
return null;
|
||||
}
|
||||
|
||||
/** Extracts sonnet major version from both naming formats */
|
||||
function parseSonnetVersion(model: string): number | null {
|
||||
const nameFirst = model.match(/claude-sonnet[-.]?(\d+)/);
|
||||
/** Extracts sonnet major/minor version from both naming formats.
|
||||
* Uses single-digit minor capture to avoid matching date suffixes (e.g., -20250514). */
|
||||
function parseSonnetVersion(model: string): { major: number; minor: number } | null {
|
||||
const nameFirst = model.match(/claude-sonnet[-.]?(\d+)(?:[-.](\d)(?!\d))?/);
|
||||
if (nameFirst) {
|
||||
return parseInt(nameFirst[1], 10);
|
||||
return {
|
||||
major: parseInt(nameFirst[1], 10),
|
||||
minor: nameFirst[2] != null ? parseInt(nameFirst[2], 10) : 0,
|
||||
};
|
||||
}
|
||||
const numFirst = model.match(/claude-(\d+)(?:[-.]?\d+)?-sonnet/);
|
||||
const numFirst = model.match(/claude-(\d+)(?:[-.](\d)(?!\d))?-sonnet/);
|
||||
if (numFirst) {
|
||||
return parseInt(numFirst[1], 10);
|
||||
return {
|
||||
major: parseInt(numFirst[1], 10),
|
||||
minor: numFirst[2] != null ? parseInt(numFirst[2], 10) : 0,
|
||||
};
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/** Checks if a model supports adaptive thinking (Opus 4.6+, Sonnet 5+) */
|
||||
/** Checks if a model supports adaptive thinking (Opus 4.6+, Sonnet 4.6+) */
|
||||
export function supportsAdaptiveThinking(model: string): boolean {
|
||||
const opus = parseOpusVersion(model);
|
||||
if (opus && (opus.major > 4 || (opus.major === 4 && opus.minor >= 6))) {
|
||||
return true;
|
||||
}
|
||||
const sonnet = parseSonnetVersion(model);
|
||||
if (sonnet != null && sonnet >= 5) {
|
||||
if (sonnet != null && (sonnet.major > 4 || (sonnet.major === 4 && sonnet.minor >= 6))) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
@ -64,7 +71,7 @@ export function supportsAdaptiveThinking(model: string): boolean {
|
|||
/** Checks if a model qualifies for the context-1m beta header (Sonnet 4+, Opus 4.6+, Opus 5+) */
|
||||
export function supportsContext1m(model: string): boolean {
|
||||
const sonnet = parseSonnetVersion(model);
|
||||
if (sonnet != null && sonnet >= 4) {
|
||||
if (sonnet != null && sonnet.major >= 4) {
|
||||
return true;
|
||||
}
|
||||
const opus = parseOpusVersion(model);
|
||||
|
|
|
|||
|
|
@ -1133,6 +1133,7 @@ const sharedOpenAIModels = [
|
|||
];
|
||||
|
||||
const sharedAnthropicModels = [
|
||||
'claude-sonnet-4-6',
|
||||
'claude-opus-4-6',
|
||||
'claude-sonnet-4-5',
|
||||
'claude-sonnet-4-5-20250929',
|
||||
|
|
@ -1154,6 +1155,7 @@ const sharedAnthropicModels = [
|
|||
];
|
||||
|
||||
export const bedrockModels = [
|
||||
'anthropic.claude-sonnet-4-6',
|
||||
'anthropic.claude-opus-4-6-v1',
|
||||
'anthropic.claude-sonnet-4-5-20250929-v1:0',
|
||||
'anthropic.claude-haiku-4-5-20251001-v1:0',
|
||||
|
|
@ -1364,6 +1366,10 @@ export enum CacheKeys {
|
|||
* Key for the config store namespace.
|
||||
*/
|
||||
CONFIG_STORE = 'CONFIG_STORE',
|
||||
/**
|
||||
* Key for the tool cache namespace (plugins, MCP tools, tool definitions).
|
||||
*/
|
||||
TOOL_CACHE = 'TOOL_CACHE',
|
||||
/**
|
||||
* Key for the roles cache.
|
||||
*/
|
||||
|
|
@ -1754,6 +1760,8 @@ export enum Constants {
|
|||
mcp_all = 'sys__all__sys',
|
||||
/** Unique value to indicate clearing MCP servers from UI state. For frontend use only. */
|
||||
mcp_clear = 'sys__clear__sys',
|
||||
/** Key suffix for non-spec user default tool storage */
|
||||
spec_defaults_key = '__defaults__',
|
||||
/**
|
||||
* Unique value to indicate the MCP tool was added to an agent.
|
||||
* This helps inform the UI if the mcp server was previously added.
|
||||
|
|
@ -1904,3 +1912,14 @@ export function getEndpointField<
|
|||
}
|
||||
return config[property];
|
||||
}
|
||||
|
||||
/** Resolves the `defaultParamsEndpoint` for a given endpoint from its custom params config */
|
||||
export function getDefaultParamsEndpoint(
|
||||
endpointsConfig: TEndpointsConfig | undefined | null,
|
||||
endpoint: string | null | undefined,
|
||||
): string | undefined {
|
||||
if (!endpointsConfig || !endpoint) {
|
||||
return undefined;
|
||||
}
|
||||
return endpointsConfig[endpoint]?.customParams?.defaultParamsEndpoint;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -178,6 +178,14 @@ export const reinitializeMCPServer = (serverName: string) => {
|
|||
return request.post(endpoints.mcpReinitialize(serverName));
|
||||
};
|
||||
|
||||
export const bindMCPOAuth = (serverName: string): Promise<{ success: boolean }> => {
|
||||
return request.post(endpoints.mcpOAuthBind(serverName));
|
||||
};
|
||||
|
||||
export const bindActionOAuth = (actionId: string): Promise<{ success: boolean }> => {
|
||||
return request.post(endpoints.actionOAuthBind(actionId));
|
||||
};
|
||||
|
||||
export const getMCPConnectionStatus = (): Promise<q.MCPConnectionStatusResponse> => {
|
||||
return request.get(endpoints.mcpConnectionStatus());
|
||||
};
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ export type TModelSpec = {
|
|||
webSearch?: boolean;
|
||||
fileSearch?: boolean;
|
||||
executeCode?: boolean;
|
||||
artifacts?: string | boolean;
|
||||
mcpServers?: string[];
|
||||
};
|
||||
|
||||
|
|
@ -54,6 +55,7 @@ export const tModelSpecSchema = z.object({
|
|||
webSearch: z.boolean().optional(),
|
||||
fileSearch: z.boolean().optional(),
|
||||
executeCode: z.boolean().optional(),
|
||||
artifacts: z.union([z.string(), z.boolean()]).optional(),
|
||||
mcpServers: z.array(z.string()).optional(),
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -952,6 +952,9 @@ export const paramSettings: Record<string, SettingsConfiguration | undefined> =
|
|||
[`${EModelEndpoint.bedrock}-${BedrockProviders.Amazon}`]: bedrockGeneral,
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.DeepSeek}`]: bedrockGeneral,
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.Moonshot}`]: bedrockMoonshot,
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.MoonshotAI}`]: bedrockMoonshot,
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.OpenAI}`]: bedrockGeneral,
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.ZAI}`]: bedrockGeneral,
|
||||
[EModelEndpoint.google]: googleConfig,
|
||||
};
|
||||
|
||||
|
|
@ -1000,6 +1003,12 @@ export const presetSettings: Record<
|
|||
col1: bedrockMoonshotCol1,
|
||||
col2: bedrockMoonshotCol2,
|
||||
},
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.MoonshotAI}`]: {
|
||||
col1: bedrockMoonshotCol1,
|
||||
col2: bedrockMoonshotCol2,
|
||||
},
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.OpenAI}`]: bedrockGeneralColumns,
|
||||
[`${EModelEndpoint.bedrock}-${BedrockProviders.ZAI}`]: bedrockGeneralColumns,
|
||||
[EModelEndpoint.google]: {
|
||||
col1: googleCol1,
|
||||
col2: googleCol2,
|
||||
|
|
|
|||
|
|
@ -144,26 +144,25 @@ export const parseConvo = ({
|
|||
endpointType,
|
||||
conversation,
|
||||
possibleValues,
|
||||
defaultParamsEndpoint,
|
||||
}: {
|
||||
endpoint: EndpointSchemaKey;
|
||||
endpointType?: EndpointSchemaKey | null;
|
||||
conversation: Partial<s.TConversation | s.TPreset> | null;
|
||||
possibleValues?: TPossibleValues;
|
||||
// TODO: POC for default schema
|
||||
// defaultSchema?: Partial<EndpointSchema>,
|
||||
defaultParamsEndpoint?: string | null;
|
||||
}) => {
|
||||
let schema = endpointSchemas[endpoint] as EndpointSchema | undefined;
|
||||
|
||||
if (!schema && !endpointType) {
|
||||
throw new Error(`Unknown endpoint: ${endpoint}`);
|
||||
} else if (!schema && endpointType) {
|
||||
schema = endpointSchemas[endpointType];
|
||||
} else if (!schema) {
|
||||
const overrideSchema = defaultParamsEndpoint
|
||||
? endpointSchemas[defaultParamsEndpoint as EndpointSchemaKey]
|
||||
: undefined;
|
||||
schema = overrideSchema ?? (endpointType ? endpointSchemas[endpointType] : undefined);
|
||||
}
|
||||
|
||||
// if (defaultSchema && schemaCreators[endpoint]) {
|
||||
// schema = schemaCreators[endpoint](defaultSchema);
|
||||
// }
|
||||
|
||||
const convo = schema?.parse(conversation) as s.TConversation | undefined;
|
||||
const { models } = possibleValues ?? {};
|
||||
|
||||
|
|
@ -310,13 +309,13 @@ export const parseCompactConvo = ({
|
|||
endpointType,
|
||||
conversation,
|
||||
possibleValues,
|
||||
defaultParamsEndpoint,
|
||||
}: {
|
||||
endpoint?: EndpointSchemaKey;
|
||||
endpointType?: EndpointSchemaKey | null;
|
||||
conversation: Partial<s.TConversation | s.TPreset>;
|
||||
possibleValues?: TPossibleValues;
|
||||
// TODO: POC for default schema
|
||||
// defaultSchema?: Partial<EndpointSchema>,
|
||||
defaultParamsEndpoint?: string | null;
|
||||
}): Omit<s.TConversation, 'iconURL'> | null => {
|
||||
if (!endpoint) {
|
||||
throw new Error(`undefined endpoint: ${endpoint}`);
|
||||
|
|
@ -326,8 +325,11 @@ export const parseCompactConvo = ({
|
|||
|
||||
if (!schema && !endpointType) {
|
||||
throw new Error(`Unknown endpoint: ${endpoint}`);
|
||||
} else if (!schema && endpointType) {
|
||||
schema = compactEndpointSchemas[endpointType];
|
||||
} else if (!schema) {
|
||||
const overrideSchema = defaultParamsEndpoint
|
||||
? compactEndpointSchemas[defaultParamsEndpoint as EndpointSchemaKey]
|
||||
: undefined;
|
||||
schema = overrideSchema ?? (endpointType ? compactEndpointSchemas[endpointType] : undefined);
|
||||
}
|
||||
|
||||
if (!schema) {
|
||||
|
|
|
|||
|
|
@ -101,7 +101,10 @@ export enum BedrockProviders {
|
|||
Meta = 'meta',
|
||||
MistralAI = 'mistral',
|
||||
Moonshot = 'moonshot',
|
||||
MoonshotAI = 'moonshotai',
|
||||
OpenAI = 'openai',
|
||||
StabilityAI = 'stability',
|
||||
ZAI = 'zai',
|
||||
}
|
||||
|
||||
export const getModelKey = (endpoint: EModelEndpoint | string, model: string) => {
|
||||
|
|
|
|||
|
|
@ -99,6 +99,7 @@ export type TEphemeralAgent = {
|
|||
web_search?: boolean;
|
||||
file_search?: boolean;
|
||||
execute_code?: boolean;
|
||||
artifacts?: string;
|
||||
};
|
||||
|
||||
export type TPayload = Partial<TMessage> &
|
||||
|
|
|
|||
|
|
@ -1015,4 +1015,239 @@ describe('Meilisearch Mongoose plugin', () => {
|
|||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Missing _meiliIndex property handling in sync process', () => {
|
||||
test('syncWithMeili includes documents with missing _meiliIndex', async () => {
|
||||
const conversationModel = createConversationModel(mongoose) as SchemaWithMeiliMethods;
|
||||
await conversationModel.deleteMany({});
|
||||
mockAddDocumentsInBatches.mockClear();
|
||||
|
||||
// Insert documents with different _meiliIndex states
|
||||
await conversationModel.collection.insertMany([
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'Missing _meiliIndex',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
// _meiliIndex is not set (missing/undefined)
|
||||
},
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'Explicit false',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
_meiliIndex: false,
|
||||
},
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'Already indexed',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
},
|
||||
]);
|
||||
|
||||
// Run sync
|
||||
await conversationModel.syncWithMeili();
|
||||
|
||||
// Should have processed 2 documents (missing and false, but not true)
|
||||
expect(mockAddDocumentsInBatches).toHaveBeenCalled();
|
||||
|
||||
// Check that both documents without _meiliIndex=true are now indexed
|
||||
const indexedCount = await conversationModel.countDocuments({
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
});
|
||||
expect(indexedCount).toBe(3); // All 3 should now be indexed
|
||||
|
||||
// Verify documents with missing _meiliIndex were updated
|
||||
const docsWithMissingIndex = await conversationModel.countDocuments({
|
||||
expiredAt: null,
|
||||
title: 'Missing _meiliIndex',
|
||||
_meiliIndex: true,
|
||||
});
|
||||
expect(docsWithMissingIndex).toBe(1);
|
||||
});
|
||||
|
||||
test('getSyncProgress counts documents with missing _meiliIndex as not indexed', async () => {
|
||||
const messageModel = createMessageModel(mongoose) as SchemaWithMeiliMethods;
|
||||
await messageModel.deleteMany({});
|
||||
|
||||
// Insert documents with different _meiliIndex states
|
||||
await messageModel.collection.insertMany([
|
||||
{
|
||||
messageId: new mongoose.Types.ObjectId(),
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
isCreatedByUser: true,
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
},
|
||||
{
|
||||
messageId: new mongoose.Types.ObjectId(),
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
isCreatedByUser: true,
|
||||
expiredAt: null,
|
||||
_meiliIndex: false,
|
||||
},
|
||||
{
|
||||
messageId: new mongoose.Types.ObjectId(),
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
isCreatedByUser: true,
|
||||
expiredAt: null,
|
||||
// _meiliIndex is missing
|
||||
},
|
||||
]);
|
||||
|
||||
const progress = await messageModel.getSyncProgress();
|
||||
|
||||
// Total should be 3
|
||||
expect(progress.totalDocuments).toBe(3);
|
||||
// Only 1 is indexed (with _meiliIndex: true)
|
||||
expect(progress.totalProcessed).toBe(1);
|
||||
// Not complete since 2 documents are not indexed
|
||||
expect(progress.isComplete).toBe(false);
|
||||
});
|
||||
|
||||
test('query with _meiliIndex: { $ne: true } includes missing values', async () => {
|
||||
const conversationModel = createConversationModel(mongoose) as SchemaWithMeiliMethods;
|
||||
await conversationModel.deleteMany({});
|
||||
|
||||
// Insert documents with different _meiliIndex states
|
||||
await conversationModel.collection.insertMany([
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'Missing',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
// _meiliIndex is missing
|
||||
},
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'False',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
_meiliIndex: false,
|
||||
},
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'True',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
},
|
||||
]);
|
||||
|
||||
// Query for documents where _meiliIndex is not true (used in syncWithMeili)
|
||||
const unindexedDocs = await conversationModel.find({
|
||||
expiredAt: null,
|
||||
_meiliIndex: { $ne: true },
|
||||
});
|
||||
|
||||
// Should find 2 documents (missing and false, but not true)
|
||||
expect(unindexedDocs.length).toBe(2);
|
||||
const titles = unindexedDocs.map((doc) => doc.title).sort();
|
||||
expect(titles).toEqual(['False', 'Missing']);
|
||||
});
|
||||
|
||||
test('syncWithMeili processes all documents where _meiliIndex is not true', async () => {
|
||||
const messageModel = createMessageModel(mongoose) as SchemaWithMeiliMethods;
|
||||
await messageModel.deleteMany({});
|
||||
mockAddDocumentsInBatches.mockClear();
|
||||
|
||||
// Create a mix of documents with missing and false _meiliIndex
|
||||
await messageModel.collection.insertMany([
|
||||
{
|
||||
messageId: new mongoose.Types.ObjectId(),
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
isCreatedByUser: true,
|
||||
expiredAt: null,
|
||||
// _meiliIndex missing
|
||||
},
|
||||
{
|
||||
messageId: new mongoose.Types.ObjectId(),
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
isCreatedByUser: true,
|
||||
expiredAt: null,
|
||||
_meiliIndex: false,
|
||||
},
|
||||
{
|
||||
messageId: new mongoose.Types.ObjectId(),
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
isCreatedByUser: true,
|
||||
expiredAt: null,
|
||||
// _meiliIndex missing
|
||||
},
|
||||
]);
|
||||
|
||||
// Count documents that should be synced (where _meiliIndex: { $ne: true })
|
||||
const toSyncCount = await messageModel.countDocuments({
|
||||
expiredAt: null,
|
||||
_meiliIndex: { $ne: true },
|
||||
});
|
||||
expect(toSyncCount).toBe(3); // All 3 should be synced
|
||||
|
||||
await messageModel.syncWithMeili();
|
||||
|
||||
// All should now be indexed
|
||||
const indexedCount = await messageModel.countDocuments({
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
});
|
||||
expect(indexedCount).toBe(3);
|
||||
});
|
||||
|
||||
test('syncWithMeili treats missing _meiliIndex same as false', async () => {
|
||||
const conversationModel = createConversationModel(mongoose) as SchemaWithMeiliMethods;
|
||||
await conversationModel.deleteMany({});
|
||||
mockAddDocumentsInBatches.mockClear();
|
||||
|
||||
// Insert one document with missing _meiliIndex and one with false
|
||||
await conversationModel.collection.insertMany([
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'Missing',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
// _meiliIndex is missing
|
||||
},
|
||||
{
|
||||
conversationId: new mongoose.Types.ObjectId(),
|
||||
user: new mongoose.Types.ObjectId(),
|
||||
title: 'False',
|
||||
endpoint: EModelEndpoint.openAI,
|
||||
expiredAt: null,
|
||||
_meiliIndex: false,
|
||||
},
|
||||
]);
|
||||
|
||||
// Both should be picked up by the sync query
|
||||
const toSync = await conversationModel.find({
|
||||
expiredAt: null,
|
||||
_meiliIndex: { $ne: true },
|
||||
});
|
||||
expect(toSync.length).toBe(2);
|
||||
|
||||
await conversationModel.syncWithMeili();
|
||||
|
||||
// Both should be indexed after sync
|
||||
const afterSync = await conversationModel.find({
|
||||
expiredAt: null,
|
||||
_meiliIndex: true,
|
||||
});
|
||||
expect(afterSync.length).toBe(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -162,8 +162,8 @@ const createMeiliMongooseModel = ({
|
|||
|
||||
/**
|
||||
* Synchronizes data between the MongoDB collection and the MeiliSearch index by
|
||||
* incrementally indexing only documents where `expiredAt` is `null` and `_meiliIndex` is `false`
|
||||
* (i.e., non-expired documents that have not yet been indexed).
|
||||
* incrementally indexing only documents where `expiredAt` is `null` and `_meiliIndex` is not `true`
|
||||
* (i.e., non-expired documents that have not yet been indexed, including those with missing or null `_meiliIndex`).
|
||||
* */
|
||||
static async syncWithMeili(this: SchemaWithMeiliMethods): Promise<void> {
|
||||
const startTime = Date.now();
|
||||
|
|
@ -196,7 +196,7 @@ const createMeiliMongooseModel = ({
|
|||
while (hasMore) {
|
||||
const query: FilterQuery<unknown> = {
|
||||
expiredAt: null,
|
||||
_meiliIndex: false,
|
||||
_meiliIndex: { $ne: true },
|
||||
};
|
||||
|
||||
try {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue