mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-16 15:35:31 +01:00
Merge branch 'main' into feature/entra-id-azure-integration
This commit is contained in:
commit
af661b1df2
293 changed files with 20207 additions and 13884 deletions
|
|
@ -1,7 +1,13 @@
|
|||
export default {
|
||||
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
|
||||
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
|
||||
testPathIgnorePatterns: ['/node_modules/', '/dist/', '\\.dev\\.ts$'],
|
||||
testPathIgnorePatterns: [
|
||||
'/node_modules/',
|
||||
'/dist/',
|
||||
'\\.dev\\.ts$',
|
||||
'\\.helper\\.ts$',
|
||||
'\\.helper\\.d\\.ts$',
|
||||
],
|
||||
coverageReporters: ['text', 'cobertura'],
|
||||
testResultsProcessor: 'jest-junit',
|
||||
moduleNameMapper: {
|
||||
|
|
@ -18,4 +24,4 @@ export default {
|
|||
// },
|
||||
restoreMocks: true,
|
||||
testTimeout: 15000,
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -18,9 +18,11 @@
|
|||
"build:dev": "npm run clean && NODE_ENV=development rollup -c --bundleConfigAsCjs",
|
||||
"build:watch": "NODE_ENV=development rollup -c -w --bundleConfigAsCjs",
|
||||
"build:watch:prod": "rollup -c -w --bundleConfigAsCjs",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.integration\\.\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.integration\\.\"",
|
||||
"test:cache:integration": "jest --testPathPattern=\"src/cache/.*\\.integration\\.spec\\.ts$\" --coverage=false",
|
||||
"test": "jest --coverage --watch --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test:ci": "jest --coverage --ci --testPathIgnorePatterns=\"\\.*integration\\.|\\.*helper\\.\"",
|
||||
"test:cache-integration:core": "jest --testPathPattern=\"src/cache/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
"test:cache-integration:cluster": "jest --testPathPattern=\"src/cluster/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false --runInBand",
|
||||
"test:cache-integration:mcp": "jest --testPathPattern=\"src/mcp/.*\\.cache_integration\\.spec\\.ts$\" --coverage=false",
|
||||
"verify": "npm run test:ci",
|
||||
"b:clean": "bun run rimraf dist",
|
||||
"b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs",
|
||||
|
|
@ -58,6 +60,7 @@
|
|||
"@types/jsonwebtoken": "^9.0.0",
|
||||
"@types/multer": "^1.4.13",
|
||||
"@types/node": "^20.3.0",
|
||||
"@types/node-fetch": "^2.6.13",
|
||||
"@types/react": "^18.2.18",
|
||||
"@types/winston": "^2.4.4",
|
||||
"jest": "^29.5.0",
|
||||
|
|
@ -79,10 +82,10 @@
|
|||
"@azure/search-documents": "^12.0.0",
|
||||
"@azure/storage-blob": "^12.27.0",
|
||||
"@keyv/redis": "^4.3.3",
|
||||
"@langchain/core": "^0.3.62",
|
||||
"@librechat/agents": "^2.4.90",
|
||||
"@langchain/core": "^0.3.79",
|
||||
"@librechat/agents": "^3.0.17",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@modelcontextprotocol/sdk": "^1.17.1",
|
||||
"@modelcontextprotocol/sdk": "^1.21.0",
|
||||
"axios": "^1.12.1",
|
||||
"connect-redis": "^8.1.0",
|
||||
"diff": "^7.0.0",
|
||||
|
|
|
|||
47
packages/api/src/agents/chain.ts
Normal file
47
packages/api/src/agents/chain.ts
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
import { PromptTemplate } from '@langchain/core/prompts';
|
||||
import { BaseMessage, getBufferString } from '@langchain/core/messages';
|
||||
import type { GraphEdge } from '@librechat/agents';
|
||||
|
||||
const DEFAULT_PROMPT_TEMPLATE = `Based on the following conversation and analysis from previous agents, please provide your insights:\n\n{convo}\n\nPlease add your specific expertise and perspective to this discussion.`;
|
||||
|
||||
/**
|
||||
* Helper function to create sequential chain edges with buffer string prompts
|
||||
*
|
||||
* @deprecated Agent Chain helper
|
||||
* @param agentIds - Array of agent IDs in order of execution
|
||||
* @param promptTemplate - Optional prompt template string; defaults to a predefined template if not provided
|
||||
* @returns Array of edges configured for sequential chain with buffer prompts
|
||||
*/
|
||||
export async function createSequentialChainEdges(
|
||||
agentIds: string[],
|
||||
promptTemplate = DEFAULT_PROMPT_TEMPLATE,
|
||||
): Promise<GraphEdge[]> {
|
||||
const edges: GraphEdge[] = [];
|
||||
|
||||
for (let i = 0; i < agentIds.length - 1; i++) {
|
||||
const fromAgent = agentIds[i];
|
||||
const toAgent = agentIds[i + 1];
|
||||
|
||||
edges.push({
|
||||
from: fromAgent,
|
||||
to: toAgent,
|
||||
edgeType: 'direct',
|
||||
// Use a prompt function to create the buffer string from all previous results
|
||||
prompt: async (messages: BaseMessage[], startIndex: number) => {
|
||||
/** Only the messages from this run (after startIndex) are passed in */
|
||||
const runMessages = messages.slice(startIndex);
|
||||
const bufferString = getBufferString(runMessages);
|
||||
const template = PromptTemplate.fromTemplate(promptTemplate);
|
||||
const result = await template.invoke({
|
||||
convo: bufferString,
|
||||
});
|
||||
return result.value;
|
||||
},
|
||||
/** Critical: exclude previous results so only the prompt is passed */
|
||||
excludeResults: true,
|
||||
description: `Sequential chain from ${fromAgent} to ${toAgent}`,
|
||||
});
|
||||
}
|
||||
|
||||
return edges;
|
||||
}
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
export * from './chain';
|
||||
export * from './memory';
|
||||
export * from './migration';
|
||||
export * from './legacy';
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import type {
|
|||
} from '@librechat/agents';
|
||||
import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
|
||||
import type { ObjectId, MemoryMethods } from '@librechat/data-schemas';
|
||||
import type { BaseMessage } from '@langchain/core/messages';
|
||||
import type { BaseMessage, ToolMessage } from '@langchain/core/messages';
|
||||
import type { Response as ServerResponse } from 'express';
|
||||
import { Tokenizer } from '~/utils';
|
||||
|
||||
|
|
@ -466,7 +466,7 @@ async function handleMemoryArtifact({
|
|||
data: ToolEndData;
|
||||
metadata?: ToolEndMetadata;
|
||||
}) {
|
||||
const output = data?.output;
|
||||
const output = data?.output as ToolMessage | undefined;
|
||||
if (!output) {
|
||||
return null;
|
||||
}
|
||||
|
|
@ -509,7 +509,7 @@ export function createMemoryCallback({
|
|||
artifactPromises: Promise<Partial<TAttachment> | null>[];
|
||||
}): ToolEndCallback {
|
||||
return async (data: ToolEndData, metadata?: Record<string, unknown>) => {
|
||||
const output = data?.output;
|
||||
const output = data?.output as ToolMessage | undefined;
|
||||
const memoryArtifact = output?.artifact?.[Tools.memory] as MemoryArtifact;
|
||||
if (memoryArtifact == null) {
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
import { Run, Providers } from '@librechat/agents';
|
||||
import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider';
|
||||
import type {
|
||||
MultiAgentGraphConfig,
|
||||
OpenAIClientOptions,
|
||||
StandardGraphConfig,
|
||||
EventHandler,
|
||||
AgentInputs,
|
||||
GenericTool,
|
||||
GraphEvents,
|
||||
RunConfig,
|
||||
IState,
|
||||
} from '@librechat/agents';
|
||||
import type { Agent } from 'librechat-data-provider';
|
||||
import type * as t from '~/types';
|
||||
import { resolveHeaders } from '~/utils/env';
|
||||
|
||||
const customProviders = new Set([
|
||||
Providers.XAI,
|
||||
|
|
@ -40,13 +42,19 @@ export function getReasoningKey(
|
|||
return reasoningKey;
|
||||
}
|
||||
|
||||
type RunAgent = Omit<Agent, 'tools'> & {
|
||||
tools?: GenericTool[];
|
||||
maxContextTokens?: number;
|
||||
useLegacyContent?: boolean;
|
||||
toolContextMap?: Record<string, string>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a new Run instance with custom handlers and configuration.
|
||||
*
|
||||
* @param options - The options for creating the Run instance.
|
||||
* @param options.agent - The agent for this run.
|
||||
* @param options.agents - The agents for this run.
|
||||
* @param options.signal - The signal for this run.
|
||||
* @param options.req - The server request.
|
||||
* @param options.runId - Optional run ID; otherwise, a new run ID will be generated.
|
||||
* @param options.customHandlers - Custom event handlers.
|
||||
* @param options.streaming - Whether to use streaming.
|
||||
|
|
@ -55,61 +63,109 @@ export function getReasoningKey(
|
|||
*/
|
||||
export async function createRun({
|
||||
runId,
|
||||
agent,
|
||||
signal,
|
||||
agents,
|
||||
requestBody,
|
||||
tokenCounter,
|
||||
customHandlers,
|
||||
indexTokenCountMap,
|
||||
streaming = true,
|
||||
streamUsage = true,
|
||||
}: {
|
||||
agent: Omit<Agent, 'tools'> & { tools?: GenericTool[] };
|
||||
agents: RunAgent[];
|
||||
signal: AbortSignal;
|
||||
runId?: string;
|
||||
streaming?: boolean;
|
||||
streamUsage?: boolean;
|
||||
customHandlers?: Record<GraphEvents, EventHandler>;
|
||||
}): Promise<Run<IState>> {
|
||||
const provider =
|
||||
(providerEndpointMap[
|
||||
agent.provider as keyof typeof providerEndpointMap
|
||||
] as unknown as Providers) ?? agent.provider;
|
||||
requestBody?: t.RequestBody;
|
||||
} & Pick<RunConfig, 'tokenCounter' | 'customHandlers' | 'indexTokenCountMap'>): Promise<
|
||||
Run<IState>
|
||||
> {
|
||||
const agentInputs: AgentInputs[] = [];
|
||||
const buildAgentContext = (agent: RunAgent) => {
|
||||
const provider =
|
||||
(providerEndpointMap[
|
||||
agent.provider as keyof typeof providerEndpointMap
|
||||
] as unknown as Providers) ?? agent.provider;
|
||||
|
||||
const llmConfig: t.RunLLMConfig = Object.assign(
|
||||
{
|
||||
const llmConfig: t.RunLLMConfig = Object.assign(
|
||||
{
|
||||
provider,
|
||||
streaming,
|
||||
streamUsage,
|
||||
},
|
||||
agent.model_parameters,
|
||||
);
|
||||
|
||||
const systemMessage = Object.values(agent.toolContextMap ?? {})
|
||||
.join('\n')
|
||||
.trim();
|
||||
|
||||
const systemContent = [
|
||||
systemMessage,
|
||||
agent.instructions ?? '',
|
||||
agent.additional_instructions ?? '',
|
||||
]
|
||||
.join('\n')
|
||||
.trim();
|
||||
|
||||
/**
|
||||
* Resolve request-based headers for Custom Endpoints. Note: if this is added to
|
||||
* non-custom endpoints, needs consideration of varying provider header configs.
|
||||
* This is done at this step because the request body may contain dynamic values
|
||||
* that need to be resolved after agent initialization.
|
||||
*/
|
||||
if (llmConfig?.configuration?.defaultHeaders != null) {
|
||||
llmConfig.configuration.defaultHeaders = resolveHeaders({
|
||||
headers: llmConfig.configuration.defaultHeaders as Record<string, string>,
|
||||
body: requestBody,
|
||||
});
|
||||
}
|
||||
|
||||
/** Resolves issues with new OpenAI usage field */
|
||||
if (
|
||||
customProviders.has(agent.provider) ||
|
||||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
|
||||
) {
|
||||
llmConfig.streamUsage = false;
|
||||
llmConfig.usage = true;
|
||||
}
|
||||
|
||||
const reasoningKey = getReasoningKey(provider, llmConfig, agent.endpoint);
|
||||
const agentInput: AgentInputs = {
|
||||
provider,
|
||||
streaming,
|
||||
streamUsage,
|
||||
},
|
||||
agent.model_parameters,
|
||||
);
|
||||
|
||||
/** Resolves issues with new OpenAI usage field */
|
||||
if (
|
||||
customProviders.has(agent.provider) ||
|
||||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
|
||||
) {
|
||||
llmConfig.streamUsage = false;
|
||||
llmConfig.usage = true;
|
||||
}
|
||||
|
||||
const reasoningKey = getReasoningKey(provider, llmConfig, agent.endpoint);
|
||||
const graphConfig: StandardGraphConfig = {
|
||||
signal,
|
||||
llmConfig,
|
||||
reasoningKey,
|
||||
tools: agent.tools,
|
||||
instructions: agent.instructions,
|
||||
additional_instructions: agent.additional_instructions,
|
||||
// toolEnd: agent.end_after_tools,
|
||||
reasoningKey,
|
||||
agentId: agent.id,
|
||||
tools: agent.tools,
|
||||
clientOptions: llmConfig,
|
||||
instructions: systemContent,
|
||||
maxContextTokens: agent.maxContextTokens,
|
||||
useLegacyContent: agent.useLegacyContent ?? false,
|
||||
};
|
||||
agentInputs.push(agentInput);
|
||||
};
|
||||
|
||||
// TEMPORARY FOR TESTING
|
||||
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
|
||||
graphConfig.streamBuffer = 2000;
|
||||
for (const agent of agents) {
|
||||
buildAgentContext(agent);
|
||||
}
|
||||
|
||||
const graphConfig: RunConfig['graphConfig'] = {
|
||||
signal,
|
||||
agents: agentInputs,
|
||||
edges: agents[0].edges,
|
||||
};
|
||||
|
||||
if (agentInputs.length > 1 || ((graphConfig as MultiAgentGraphConfig).edges?.length ?? 0) > 0) {
|
||||
(graphConfig as unknown as MultiAgentGraphConfig).type = 'multi-agent';
|
||||
} else {
|
||||
(graphConfig as StandardGraphConfig).type = 'standard';
|
||||
}
|
||||
|
||||
return Run.create({
|
||||
runId,
|
||||
graphConfig,
|
||||
tokenCounter,
|
||||
customHandlers,
|
||||
indexTokenCountMap,
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,17 @@ export const agentSupportContactSchema = z
|
|||
})
|
||||
.optional();
|
||||
|
||||
/** Graph edge schema for agent handoffs */
|
||||
export const graphEdgeSchema = z.object({
|
||||
from: z.union([z.string(), z.array(z.string())]),
|
||||
to: z.union([z.string(), z.array(z.string())]),
|
||||
description: z.string().optional(),
|
||||
edgeType: z.enum(['handoff', 'direct']).optional(),
|
||||
prompt: z.union([z.string(), z.function()]).optional(),
|
||||
excludeResults: z.boolean().optional(),
|
||||
promptKey: z.string().optional(),
|
||||
});
|
||||
|
||||
/** Base agent schema with all common fields */
|
||||
export const agentBaseSchema = z.object({
|
||||
name: z.string().nullable().optional(),
|
||||
|
|
@ -48,7 +59,9 @@ export const agentBaseSchema = z.object({
|
|||
avatar: agentAvatarSchema.nullable().optional(),
|
||||
model_parameters: z.record(z.unknown()).optional(),
|
||||
tools: z.array(z.string()).optional(),
|
||||
/** @deprecated Use edges instead */
|
||||
agent_ids: z.array(z.string()).optional(),
|
||||
edges: z.array(graphEdgeSchema).optional(),
|
||||
end_after_tools: z.boolean().optional(),
|
||||
hide_sequential_outputs: z.boolean().optional(),
|
||||
artifacts: z.string().optional(),
|
||||
|
|
|
|||
|
|
@ -40,7 +40,6 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
|
||||
jest.mock('~/utils', () => ({
|
||||
isEnabled: jest.fn((value) => value === 'true'),
|
||||
normalizeEndpointName: jest.fn((name) => name),
|
||||
}));
|
||||
|
||||
describe('getTransactionsConfig', () => {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { EModelEndpoint, removeNullishValues } from 'librechat-data-provider';
|
||||
import {
|
||||
EModelEndpoint,
|
||||
removeNullishValues,
|
||||
normalizeEndpointName,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TCustomConfig, TEndpoint, TTransactionsConfig } from 'librechat-data-provider';
|
||||
import type { AppConfig } from '@librechat/data-schemas';
|
||||
import { isEnabled, normalizeEndpointName } from '~/utils';
|
||||
import { isEnabled } from '~/utils';
|
||||
|
||||
/**
|
||||
* Retrieves the balance configuration object
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
import type { Keyv } from 'keyv';
|
||||
|
||||
// Mock GLOBAL_PREFIX_SEPARATOR
|
||||
jest.mock('../../redisClients', () => {
|
||||
const originalModule = jest.requireActual('../../redisClients');
|
||||
// Mock GLOBAL_PREFIX_SEPARATOR from cacheConfig
|
||||
jest.mock('../../cacheConfig', () => {
|
||||
const originalModule = jest.requireActual('../../cacheConfig');
|
||||
return {
|
||||
...originalModule,
|
||||
GLOBAL_PREFIX_SEPARATOR: '>>',
|
||||
cacheConfig: {
|
||||
...originalModule.cacheConfig,
|
||||
GLOBAL_PREFIX_SEPARATOR: '>>',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
1
packages/api/src/cache/cacheConfig.ts
vendored
1
packages/api/src/cache/cacheConfig.ts
vendored
|
|
@ -65,6 +65,7 @@ const cacheConfig = {
|
|||
REDIS_PASSWORD: process.env.REDIS_PASSWORD,
|
||||
REDIS_CA: getRedisCA(),
|
||||
REDIS_KEY_PREFIX: process.env[REDIS_KEY_PREFIX_VAR ?? ''] || REDIS_KEY_PREFIX || '',
|
||||
GLOBAL_PREFIX_SEPARATOR: '::',
|
||||
REDIS_MAX_LISTENERS: math(process.env.REDIS_MAX_LISTENERS, 40),
|
||||
REDIS_PING_INTERVAL: math(process.env.REDIS_PING_INTERVAL, 0),
|
||||
/** Max delay between reconnection attempts in ms */
|
||||
|
|
|
|||
4
packages/api/src/cache/cacheFactory.ts
vendored
4
packages/api/src/cache/cacheFactory.ts
vendored
|
|
@ -14,7 +14,7 @@ import { logger } from '@librechat/data-schemas';
|
|||
import session, { MemoryStore } from 'express-session';
|
||||
import { RedisStore as ConnectRedis } from 'connect-redis';
|
||||
import type { SendCommandFn } from 'rate-limit-redis';
|
||||
import { keyvRedisClient, ioredisClient, GLOBAL_PREFIX_SEPARATOR } from './redisClients';
|
||||
import { keyvRedisClient, ioredisClient } from './redisClients';
|
||||
import { cacheConfig } from './cacheConfig';
|
||||
import { violationFile } from './keyvFiles';
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ export const standardCache = (namespace: string, ttl?: number, fallbackStore?: o
|
|||
const keyvRedis = new KeyvRedis(keyvRedisClient);
|
||||
const cache = new Keyv(keyvRedis, { namespace, ttl });
|
||||
keyvRedis.namespace = cacheConfig.REDIS_KEY_PREFIX;
|
||||
keyvRedis.keyPrefixSeparator = GLOBAL_PREFIX_SEPARATOR;
|
||||
keyvRedis.keyPrefixSeparator = cacheConfig.GLOBAL_PREFIX_SEPARATOR;
|
||||
|
||||
cache.on('error', (err) => {
|
||||
logger.error(`Cache error in namespace ${namespace}:`, err);
|
||||
|
|
|
|||
6
packages/api/src/cache/redisClients.ts
vendored
6
packages/api/src/cache/redisClients.ts
vendored
|
|
@ -5,8 +5,6 @@ import { createClient, createCluster } from '@keyv/redis';
|
|||
import type { RedisClientType, RedisClusterType } from '@redis/client';
|
||||
import { cacheConfig } from './cacheConfig';
|
||||
|
||||
const GLOBAL_PREFIX_SEPARATOR = '::';
|
||||
|
||||
const urls = cacheConfig.REDIS_URI?.split(',').map((uri) => new URL(uri)) || [];
|
||||
const username = urls?.[0]?.username || cacheConfig.REDIS_USERNAME;
|
||||
const password = urls?.[0]?.password || cacheConfig.REDIS_PASSWORD;
|
||||
|
|
@ -18,7 +16,7 @@ if (cacheConfig.USE_REDIS) {
|
|||
username: username,
|
||||
password: password,
|
||||
tls: ca ? { ca } : undefined,
|
||||
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${GLOBAL_PREFIX_SEPARATOR}`,
|
||||
keyPrefix: `${cacheConfig.REDIS_KEY_PREFIX}${cacheConfig.GLOBAL_PREFIX_SEPARATOR}`,
|
||||
maxListeners: cacheConfig.REDIS_MAX_LISTENERS,
|
||||
retryStrategy: (times: number) => {
|
||||
if (
|
||||
|
|
@ -192,4 +190,4 @@ if (cacheConfig.USE_REDIS) {
|
|||
});
|
||||
}
|
||||
|
||||
export { ioredisClient, keyvRedisClient, GLOBAL_PREFIX_SEPARATOR };
|
||||
export { ioredisClient, keyvRedisClient };
|
||||
|
|
|
|||
180
packages/api/src/cluster/LeaderElection.ts
Normal file
180
packages/api/src/cluster/LeaderElection.ts
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
import { keyvRedisClient } from '~/cache/redisClients';
|
||||
import { cacheConfig as cache } from '~/cache/cacheConfig';
|
||||
import { clusterConfig as cluster } from './config';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
|
||||
/**
|
||||
* Distributed leader election implementation using Redis for coordination across multiple server instances.
|
||||
*
|
||||
* Leadership election:
|
||||
* - During bootup, every server attempts to become the leader by calling isLeader()
|
||||
* - Uses atomic Redis SET NX (set if not exists) to ensure only ONE server can claim leadership
|
||||
* - The first server to successfully set the key becomes the leader; others become followers
|
||||
* - Works with any number of servers (1 to infinite) - single server always becomes leader
|
||||
*
|
||||
* Leadership maintenance:
|
||||
* - Leader holds a key in Redis with a 25-second lease duration
|
||||
* - Leader renews this lease every 10 seconds to maintain leadership
|
||||
* - If leader crashes, the lease eventually expires, and the key disappears
|
||||
* - On shutdown, leader deletes its key to allow immediate re-election
|
||||
* - Followers check for leadership and attempt to claim it when the key is empty
|
||||
*/
|
||||
export class LeaderElection {
|
||||
// We can't use Keyv namespace here because we need direct Redis access for atomic operations
|
||||
static readonly LEADER_KEY = `${cache.REDIS_KEY_PREFIX}${cache.GLOBAL_PREFIX_SEPARATOR}LeadingServerUUID`;
|
||||
private static _instance = new LeaderElection();
|
||||
|
||||
readonly UUID: string = crypto.randomUUID();
|
||||
private refreshTimer: NodeJS.Timeout | undefined = undefined;
|
||||
|
||||
// DO NOT create new instances of this class directly.
|
||||
// Use the exported isLeader() function which uses a singleton instance.
|
||||
constructor() {
|
||||
if (LeaderElection._instance) return LeaderElection._instance;
|
||||
|
||||
process.on('SIGTERM', () => this.resign());
|
||||
process.on('SIGINT', () => this.resign());
|
||||
LeaderElection._instance = this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if this instance is the current leader.
|
||||
* If no leader exists, waits upto 2 seconds (randomized to avoid thundering herd) then attempts self-election.
|
||||
* Always returns true in non-Redis mode (single-instance deployment).
|
||||
*/
|
||||
public async isLeader(): Promise<boolean> {
|
||||
if (!cache.USE_REDIS) return true;
|
||||
|
||||
try {
|
||||
const currentLeader = await LeaderElection.getLeaderUUID();
|
||||
// If we own the leadership lock, return true.
|
||||
// However, in case the leadership refresh retries have been exhausted, something has gone wrong.
|
||||
// This server is not considered the leader anymore, similar to a crash, to avoid split-brain scenario.
|
||||
if (currentLeader === this.UUID) return this.refreshTimer != null;
|
||||
if (currentLeader != null) return false; // someone holds leadership lock
|
||||
|
||||
const delay = Math.random() * 2000;
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
return await this.electSelf();
|
||||
} catch (error) {
|
||||
logger.error('Failed to check leadership status:', error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Steps down from leadership by stopping the refresh timer and releasing the leader key.
|
||||
* Atomically deletes the leader key (only if we still own it) so another server can become leader immediately.
|
||||
*/
|
||||
public async resign(): Promise<void> {
|
||||
if (!cache.USE_REDIS) return;
|
||||
|
||||
try {
|
||||
this.clearRefreshTimer();
|
||||
|
||||
// Lua script for atomic check-and-delete (only delete if we still own it)
|
||||
const script = `
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
redis.call("del", KEYS[1])
|
||||
end
|
||||
`;
|
||||
|
||||
await keyvRedisClient!.eval(script, {
|
||||
keys: [LeaderElection.LEADER_KEY],
|
||||
arguments: [this.UUID],
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Failed to release leadership lock:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the UUID of the current leader from Redis.
|
||||
* Returns null if no leader exists or in non-Redis mode.
|
||||
* Useful for testing and observability.
|
||||
*/
|
||||
public static async getLeaderUUID(): Promise<string | null> {
|
||||
if (!cache.USE_REDIS) return null;
|
||||
return await keyvRedisClient!.get(LeaderElection.LEADER_KEY);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the refresh timer to stop leadership maintenance.
|
||||
* Called when resigning or failing to refresh leadership.
|
||||
* Calling this directly to simulate a crash in testing.
|
||||
*/
|
||||
public clearRefreshTimer(): void {
|
||||
clearInterval(this.refreshTimer);
|
||||
this.refreshTimer = undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to claim leadership using atomic Redis SET NX (set if not exists).
|
||||
* If successful, starts a refresh timer to maintain leadership by extending the lease duration.
|
||||
* The NX flag ensures only one server can become leader even if multiple attempt simultaneously.
|
||||
*/
|
||||
private async electSelf(): Promise<boolean> {
|
||||
try {
|
||||
const result = await keyvRedisClient!.set(LeaderElection.LEADER_KEY, this.UUID, {
|
||||
NX: true,
|
||||
EX: cluster.LEADER_LEASE_DURATION,
|
||||
});
|
||||
|
||||
if (result !== 'OK') return false;
|
||||
|
||||
this.clearRefreshTimer();
|
||||
this.refreshTimer = setInterval(async () => {
|
||||
await this.renewLeadership();
|
||||
}, cluster.LEADER_RENEW_INTERVAL * 1000);
|
||||
this.refreshTimer.unref();
|
||||
|
||||
return true;
|
||||
} catch (error) {
|
||||
logger.error('Leader election failed:', error);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Renews leadership by extending the lease duration on the leader key.
|
||||
* Uses Lua script to atomically verify we still own the key before renewing (prevents race conditions).
|
||||
* If we've lost leadership (key was taken by another server), stops the refresh timer.
|
||||
* This is called every 10 seconds by the refresh timer.
|
||||
*/
|
||||
private async renewLeadership(attempts: number = 1): Promise<void> {
|
||||
try {
|
||||
// Lua script for atomic check-and-renew
|
||||
const script = `
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("expire", KEYS[1], ARGV[2])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
`;
|
||||
|
||||
const result = await keyvRedisClient!.eval(script, {
|
||||
keys: [LeaderElection.LEADER_KEY],
|
||||
arguments: [this.UUID, cluster.LEADER_LEASE_DURATION.toString()],
|
||||
});
|
||||
|
||||
if (result === 0) {
|
||||
logger.warn('Lost leadership, clearing refresh timer');
|
||||
this.clearRefreshTimer();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Failed to renew leadership (attempts No.${attempts}):`, error);
|
||||
if (attempts <= cluster.LEADER_RENEW_ATTEMPTS) {
|
||||
await new Promise((resolve) =>
|
||||
setTimeout(resolve, cluster.LEADER_RENEW_RETRY_DELAY * 1000),
|
||||
);
|
||||
await this.renewLeadership(attempts + 1);
|
||||
} else {
|
||||
logger.error('Exceeded maximum attempts to renew leadership.');
|
||||
this.clearRefreshTimer();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const defaultElection = new LeaderElection();
|
||||
export const isLeader = (): Promise<boolean> => defaultElection.isLeader();
|
||||
|
|
@ -0,0 +1,220 @@
|
|||
import { expect } from '@playwright/test';
|
||||
|
||||
describe('LeaderElection with Redis', () => {
|
||||
let LeaderElection: typeof import('../LeaderElection').LeaderElection;
|
||||
let instances: InstanceType<typeof import('../LeaderElection').LeaderElection>[] = [];
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let ioredisClient: Awaited<typeof import('~/cache/redisClients')>['ioredisClient'];
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis
|
||||
process.env.USE_REDIS = 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX = 'LeaderElection-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const leaderElectionModule = await import('../LeaderElection');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
ioredisClient = redisClients.ioredisClient;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) {
|
||||
throw new Error('Redis client is not initialized');
|
||||
}
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) {
|
||||
await keyvRedisClient.connect();
|
||||
}
|
||||
|
||||
// Increase max listeners to handle many instances in tests
|
||||
process.setMaxListeners(200);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Promise.all(instances.map((instance) => instance.resign()));
|
||||
instances = [];
|
||||
|
||||
// Clean up: clear the leader key directly from Redis
|
||||
if (keyvRedisClient) {
|
||||
await keyvRedisClient.del(LeaderElection.LEADER_KEY);
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Close both Redis clients to prevent hanging
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
if (ioredisClient?.status === 'ready') await ioredisClient.quit();
|
||||
});
|
||||
|
||||
describe('Test Case 1: Simulate shutdown of the leader', () => {
|
||||
it('should elect a new leader after the current leader resigns', async () => {
|
||||
// Create 100 instances
|
||||
instances = Array.from({ length: 100 }, () => new LeaderElection());
|
||||
|
||||
// Call isLeader on all instances and get leadership status
|
||||
const resultsWithInstances = await Promise.all(
|
||||
instances.map(async (instance) => ({
|
||||
instance,
|
||||
isLeader: await instance.isLeader(),
|
||||
})),
|
||||
);
|
||||
|
||||
// Find leader and followers
|
||||
const leaders = resultsWithInstances.filter((r) => r.isLeader);
|
||||
const followers = resultsWithInstances.filter((r) => !r.isLeader);
|
||||
const leader = leaders[0].instance;
|
||||
const nextLeader = followers[0].instance;
|
||||
|
||||
// Verify only one is leader
|
||||
expect(leaders.length).toBe(1);
|
||||
|
||||
// Verify getLeaderUUID matches the leader's UUID
|
||||
expect(await LeaderElection.getLeaderUUID()).toBe(leader.UUID);
|
||||
|
||||
// Leader resigns
|
||||
await leader.resign();
|
||||
|
||||
// Verify getLeaderUUID returns null after resignation
|
||||
expect(await LeaderElection.getLeaderUUID()).toBeNull();
|
||||
|
||||
// Next instance to call isLeader should become the new leader
|
||||
expect(await nextLeader.isLeader()).toBe(true);
|
||||
}, 30000); // 30 second timeout for 100 instances
|
||||
});
|
||||
|
||||
describe('Test Case 2: Simulate crash of the leader', () => {
|
||||
it('should allow re-election after leader crashes (lease expires)', async () => {
|
||||
// Mock config with short lease duration
|
||||
const clusterConfigModule = await import('../config');
|
||||
const originalConfig = { ...clusterConfigModule.clusterConfig };
|
||||
|
||||
// Override config values for this test
|
||||
Object.assign(clusterConfigModule.clusterConfig, {
|
||||
LEADER_LEASE_DURATION: 2,
|
||||
LEADER_RENEW_INTERVAL: 4,
|
||||
});
|
||||
|
||||
try {
|
||||
// Create 1 instance with mocked config
|
||||
const instance = new LeaderElection();
|
||||
instances.push(instance);
|
||||
|
||||
// Become leader
|
||||
expect(await instance.isLeader()).toBe(true);
|
||||
|
||||
// Verify leader UUID is set
|
||||
expect(await LeaderElection.getLeaderUUID()).toBe(instance.UUID);
|
||||
|
||||
// Simulate crash by clearing refresh timer
|
||||
instance.clearRefreshTimer();
|
||||
|
||||
// The instance no longer considers itself leader even though it still holds the key
|
||||
expect(await LeaderElection.getLeaderUUID()).toBe(instance.UUID);
|
||||
expect(await instance.isLeader()).toBe(false);
|
||||
|
||||
// Wait for lease to expire (3 seconds > 2 second lease)
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
|
||||
// Verify leader UUID is null after lease expiration
|
||||
expect(await LeaderElection.getLeaderUUID()).toBeNull();
|
||||
} finally {
|
||||
// Restore original config values
|
||||
Object.assign(clusterConfigModule.clusterConfig, originalConfig);
|
||||
}
|
||||
}, 15000); // 15 second timeout
|
||||
});
|
||||
|
||||
describe('Test Case 3: Stress testing', () => {
|
||||
it('should ensure only one instance becomes leader even when multiple instances call electSelf() at once', async () => {
|
||||
// Create 10 instances
|
||||
instances = Array.from({ length: 10 }, () => new LeaderElection());
|
||||
|
||||
// Call electSelf on all instances in parallel
|
||||
const results = await Promise.all(instances.map((instance) => instance['electSelf']()));
|
||||
|
||||
// Verify only one returned true
|
||||
const successCount = results.filter((success) => success).length;
|
||||
expect(successCount).toBe(1);
|
||||
|
||||
// Find the winning instance
|
||||
const winnerInstance = instances.find((_, index) => results[index]);
|
||||
|
||||
// Verify getLeaderUUID matches the winner's UUID
|
||||
expect(await LeaderElection.getLeaderUUID()).toBe(winnerInstance?.UUID);
|
||||
}, 15000); // 15 second timeout
|
||||
});
|
||||
});
|
||||
|
||||
describe('LeaderElection without Redis', () => {
|
||||
let LeaderElection: typeof import('../LeaderElection').LeaderElection;
|
||||
let instances: InstanceType<typeof import('../LeaderElection').LeaderElection>[] = [];
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for non-Redis mode
|
||||
process.env.USE_REDIS = 'false';
|
||||
|
||||
// Reset all modules to force re-evaluation with new env vars
|
||||
jest.resetModules();
|
||||
|
||||
// Import modules after setting env vars and resetting modules
|
||||
const leaderElectionModule = await import('../LeaderElection');
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Promise.all(instances.map((instance) => instance.resign()));
|
||||
instances = [];
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
// Restore environment variables
|
||||
process.env.USE_REDIS = 'true';
|
||||
|
||||
// Reset all modules to ensure next test runs get fresh imports
|
||||
jest.resetModules();
|
||||
});
|
||||
|
||||
it('should allow all instances to be leaders when USE_REDIS is false', async () => {
|
||||
// Create 10 instances
|
||||
instances = Array.from({ length: 10 }, () => new LeaderElection());
|
||||
|
||||
// Call isLeader on all instances
|
||||
const results = await Promise.all(instances.map((instance) => instance.isLeader()));
|
||||
|
||||
// Verify all instances report themselves as leaders
|
||||
expect(results.every((isLeader) => isLeader)).toBe(true);
|
||||
expect(results.filter((isLeader) => isLeader).length).toBe(10);
|
||||
});
|
||||
|
||||
it('should return null for getLeaderUUID when USE_REDIS is false', async () => {
|
||||
// Create a few instances
|
||||
instances = Array.from({ length: 3 }, () => new LeaderElection());
|
||||
|
||||
// Call isLeader on all instances to make them "leaders"
|
||||
await Promise.all(instances.map((instance) => instance.isLeader()));
|
||||
|
||||
// Verify getLeaderUUID returns null in non-Redis mode
|
||||
expect(await LeaderElection.getLeaderUUID()).toBeNull();
|
||||
});
|
||||
|
||||
it('should allow resign() to be called without throwing errors', async () => {
|
||||
// Create multiple instances
|
||||
instances = Array.from({ length: 5 }, () => new LeaderElection());
|
||||
|
||||
// Make them all leaders
|
||||
await Promise.all(instances.map((instance) => instance.isLeader()));
|
||||
|
||||
// Call resign on all instances - should not throw
|
||||
await expect(
|
||||
Promise.all(instances.map((instance) => instance.resign())),
|
||||
).resolves.not.toThrow();
|
||||
|
||||
// Verify they're still leaders after resigning (since there's no shared state)
|
||||
const results = await Promise.all(instances.map((instance) => instance.isLeader()));
|
||||
expect(results.every((isLeader) => isLeader)).toBe(true);
|
||||
});
|
||||
});
|
||||
14
packages/api/src/cluster/config.ts
Normal file
14
packages/api/src/cluster/config.ts
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
import { math } from '~/utils';
|
||||
|
||||
const clusterConfig = {
|
||||
/** Duration in seconds that the leader lease is valid before it expires */
|
||||
LEADER_LEASE_DURATION: math(process.env.LEADER_LEASE_DURATION, 25),
|
||||
/** Interval in seconds at which the leader renews its lease */
|
||||
LEADER_RENEW_INTERVAL: math(process.env.LEADER_RENEW_INTERVAL, 10),
|
||||
/** Maximum number of retry attempts when renewing the lease fails */
|
||||
LEADER_RENEW_ATTEMPTS: math(process.env.LEADER_RENEW_ATTEMPTS, 3),
|
||||
/** Delay in seconds between retry attempts when renewing the lease */
|
||||
LEADER_RENEW_RETRY_DELAY: math(process.env.LEADER_RENEW_RETRY_DELAY, 0.5),
|
||||
};
|
||||
|
||||
export { clusterConfig };
|
||||
1
packages/api/src/cluster/index.ts
Normal file
1
packages/api/src/cluster/index.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export { isLeader } from './LeaderElection';
|
||||
|
|
@ -4,6 +4,38 @@ import { anthropicSettings, removeNullishValues } from 'librechat-data-provider'
|
|||
import type { AnthropicLLMConfigResult, AnthropicConfigOptions } from '~/types/anthropic';
|
||||
import { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } from './helpers';
|
||||
|
||||
/** Known Anthropic parameters that map directly to the client config */
|
||||
export const knownAnthropicParams = new Set([
|
||||
'model',
|
||||
'temperature',
|
||||
'topP',
|
||||
'topK',
|
||||
'maxTokens',
|
||||
'maxOutputTokens',
|
||||
'stopSequences',
|
||||
'stop',
|
||||
'stream',
|
||||
'apiKey',
|
||||
'maxRetries',
|
||||
'timeout',
|
||||
'anthropicVersion',
|
||||
'anthropicApiUrl',
|
||||
'defaultHeaders',
|
||||
]);
|
||||
|
||||
/**
|
||||
* Applies default parameters to the target object only if the field is undefined
|
||||
* @param target - The target object to apply defaults to
|
||||
* @param defaults - Record of default parameter values
|
||||
*/
|
||||
function applyDefaultParams(target: Record<string, unknown>, defaults: Record<string, unknown>) {
|
||||
for (const [key, value] of Object.entries(defaults)) {
|
||||
if (target[key] === undefined) {
|
||||
target[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates configuration options for creating an Anthropic language model (LLM) instance.
|
||||
* @param apiKey - The API key for authentication with Anthropic.
|
||||
|
|
@ -39,6 +71,8 @@ function getLLMConfig(
|
|||
|
||||
const mergedOptions = Object.assign(defaultOptions, options.modelOptions);
|
||||
|
||||
let enableWebSearch = mergedOptions.web_search;
|
||||
|
||||
let requestOptions: AnthropicClientOptions & { stream?: boolean } = {
|
||||
apiKey,
|
||||
model: mergedOptions.model,
|
||||
|
|
@ -84,9 +118,64 @@ function getLLMConfig(
|
|||
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
|
||||
}
|
||||
|
||||
/** Handle defaultParams first - only process Anthropic-native params if undefined */
|
||||
if (options.defaultParams && typeof options.defaultParams === 'object') {
|
||||
for (const [key, value] of Object.entries(options.defaultParams)) {
|
||||
/** Handle web_search separately - don't add to config */
|
||||
if (key === 'web_search') {
|
||||
if (enableWebSearch === undefined && typeof value === 'boolean') {
|
||||
enableWebSearch = value;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (knownAnthropicParams.has(key)) {
|
||||
/** Route known Anthropic params to requestOptions only if undefined */
|
||||
applyDefaultParams(requestOptions as Record<string, unknown>, { [key]: value });
|
||||
}
|
||||
/** Leave other params for transform to handle - they might be OpenAI params */
|
||||
}
|
||||
}
|
||||
|
||||
/** Handle addParams - can override defaultParams */
|
||||
if (options.addParams && typeof options.addParams === 'object') {
|
||||
for (const [key, value] of Object.entries(options.addParams)) {
|
||||
/** Handle web_search separately - don't add to config */
|
||||
if (key === 'web_search') {
|
||||
if (typeof value === 'boolean') {
|
||||
enableWebSearch = value;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (knownAnthropicParams.has(key)) {
|
||||
/** Route known Anthropic params to requestOptions */
|
||||
(requestOptions as Record<string, unknown>)[key] = value;
|
||||
}
|
||||
/** Leave other params for transform to handle - they might be OpenAI params */
|
||||
}
|
||||
}
|
||||
|
||||
/** Handle dropParams - only drop from Anthropic config */
|
||||
if (options.dropParams && Array.isArray(options.dropParams)) {
|
||||
options.dropParams.forEach((param) => {
|
||||
if (param === 'web_search') {
|
||||
enableWebSearch = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (param in requestOptions) {
|
||||
delete requestOptions[param as keyof AnthropicClientOptions];
|
||||
}
|
||||
if (requestOptions.invocationKwargs && param in requestOptions.invocationKwargs) {
|
||||
delete (requestOptions.invocationKwargs as Record<string, unknown>)[param];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const tools = [];
|
||||
|
||||
if (mergedOptions.web_search) {
|
||||
if (enableWebSearch) {
|
||||
tools.push({
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import { EModelEndpoint, extractEnvVariable } from 'librechat-data-provider';
|
||||
import { EModelEndpoint, extractEnvVariable, normalizeEndpointName } from 'librechat-data-provider';
|
||||
import type { TCustomEndpoints, TEndpoint, TConfig } from 'librechat-data-provider';
|
||||
import type { TCustomEndpointsConfig } from '~/types/endpoints';
|
||||
import { isUserProvided, normalizeEndpointName } from '~/utils';
|
||||
import { isUserProvided } from '~/utils';
|
||||
|
||||
/**
|
||||
* Load config endpoints from the cached configuration object
|
||||
|
|
|
|||
|
|
@ -5,6 +5,46 @@ import type { GoogleAIToolType } from '@langchain/google-common';
|
|||
import type * as t from '~/types';
|
||||
import { isEnabled } from '~/utils';
|
||||
|
||||
/** Known Google/Vertex AI parameters that map directly to the client config */
|
||||
export const knownGoogleParams = new Set([
|
||||
'model',
|
||||
'modelName',
|
||||
'temperature',
|
||||
'maxOutputTokens',
|
||||
'maxReasoningTokens',
|
||||
'topP',
|
||||
'topK',
|
||||
'seed',
|
||||
'presencePenalty',
|
||||
'frequencyPenalty',
|
||||
'stopSequences',
|
||||
'stop',
|
||||
'logprobs',
|
||||
'topLogprobs',
|
||||
'safetySettings',
|
||||
'responseModalities',
|
||||
'convertSystemMessageToHumanContent',
|
||||
'speechConfig',
|
||||
'streamUsage',
|
||||
'apiKey',
|
||||
'baseUrl',
|
||||
'location',
|
||||
'authOptions',
|
||||
]);
|
||||
|
||||
/**
|
||||
* Applies default parameters to the target object only if the field is undefined
|
||||
* @param target - The target object to apply defaults to
|
||||
* @param defaults - Record of default parameter values
|
||||
*/
|
||||
function applyDefaultParams(target: Record<string, unknown>, defaults: Record<string, unknown>) {
|
||||
for (const [key, value] of Object.entries(defaults)) {
|
||||
if (target[key] === undefined) {
|
||||
target[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function getThresholdMapping(model: string) {
|
||||
const gemini1Pattern = /gemini-(1\.0|1\.5|pro$|1\.0-pro|1\.5-pro|1\.5-flash-001)/;
|
||||
const restrictedPattern = /(gemini-(1\.5-flash-8b|2\.0|exp)|learnlm)/;
|
||||
|
|
@ -112,6 +152,8 @@ export function getGoogleConfig(
|
|||
...modelOptions
|
||||
} = options.modelOptions || {};
|
||||
|
||||
let enableWebSearch = web_search;
|
||||
|
||||
const llmConfig: GoogleClientOptions | VertexAIClientOptions = removeNullishValues({
|
||||
...(modelOptions || {}),
|
||||
model: modelOptions?.model ?? '',
|
||||
|
|
@ -193,9 +235,61 @@ export function getGoogleConfig(
|
|||
};
|
||||
}
|
||||
|
||||
/** Handle defaultParams first - only process Google-native params if undefined */
|
||||
if (options.defaultParams && typeof options.defaultParams === 'object') {
|
||||
for (const [key, value] of Object.entries(options.defaultParams)) {
|
||||
/** Handle web_search separately - don't add to config */
|
||||
if (key === 'web_search') {
|
||||
if (enableWebSearch === undefined && typeof value === 'boolean') {
|
||||
enableWebSearch = value;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (knownGoogleParams.has(key)) {
|
||||
/** Route known Google params to llmConfig only if undefined */
|
||||
applyDefaultParams(llmConfig as Record<string, unknown>, { [key]: value });
|
||||
}
|
||||
/** Leave other params for transform to handle - they might be OpenAI params */
|
||||
}
|
||||
}
|
||||
|
||||
/** Handle addParams - can override defaultParams */
|
||||
if (options.addParams && typeof options.addParams === 'object') {
|
||||
for (const [key, value] of Object.entries(options.addParams)) {
|
||||
/** Handle web_search separately - don't add to config */
|
||||
if (key === 'web_search') {
|
||||
if (typeof value === 'boolean') {
|
||||
enableWebSearch = value;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (knownGoogleParams.has(key)) {
|
||||
/** Route known Google params to llmConfig */
|
||||
(llmConfig as Record<string, unknown>)[key] = value;
|
||||
}
|
||||
/** Leave other params for transform to handle - they might be OpenAI params */
|
||||
}
|
||||
}
|
||||
|
||||
/** Handle dropParams - only drop from Google config */
|
||||
if (options.dropParams && Array.isArray(options.dropParams)) {
|
||||
options.dropParams.forEach((param) => {
|
||||
if (param === 'web_search') {
|
||||
enableWebSearch = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (param in llmConfig) {
|
||||
delete (llmConfig as Record<string, unknown>)[param];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const tools: GoogleAIToolType[] = [];
|
||||
|
||||
if (web_search) {
|
||||
if (enableWebSearch) {
|
||||
tools.push({ googleSearch: {} });
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -548,4 +548,375 @@ describe('getOpenAIConfig - Anthropic Compatibility', () => {
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Web Search Support via addParams', () => {
|
||||
it('should enable web_search tool when web_search: true in addParams', () => {
|
||||
const apiKey = 'sk-web-search';
|
||||
const endpoint = 'Anthropic (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-latest',
|
||||
user: 'search-user',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
},
|
||||
addParams: {
|
||||
web_search: true,
|
||||
},
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([
|
||||
{
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
},
|
||||
]);
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'claude-3-5-sonnet-latest',
|
||||
stream: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should disable web_search tool when web_search: false in addParams', () => {
|
||||
const apiKey = 'sk-no-search';
|
||||
const endpoint = 'Anthropic (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
web_search: true, // This should be overridden by addParams
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
},
|
||||
addParams: {
|
||||
web_search: false,
|
||||
},
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should disable web_search when in dropParams', () => {
|
||||
const apiKey = 'sk-drop-search';
|
||||
const endpoint = 'Anthropic (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-latest',
|
||||
web_search: true,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
},
|
||||
dropParams: ['web_search'],
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle web_search with mixed Anthropic and OpenAI params in addParams', () => {
|
||||
const apiKey = 'sk-mixed';
|
||||
const endpoint = 'Anthropic (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
user: 'mixed-user',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
},
|
||||
addParams: {
|
||||
web_search: true,
|
||||
temperature: 0.7, // Anthropic native
|
||||
maxRetries: 3, // OpenAI param (known), should go to top level
|
||||
customParam: 'custom', // Unknown param, should go to modelKwargs
|
||||
},
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([
|
||||
{
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
},
|
||||
]);
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect(result.llmConfig.maxRetries).toBe(3); // Known OpenAI param at top level
|
||||
expect(result.llmConfig.modelKwargs).toMatchObject({
|
||||
customParam: 'custom', // Unknown param in modelKwargs
|
||||
metadata: { user_id: 'mixed-user' }, // From invocationKwargs
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle Anthropic native params in addParams without web_search', () => {
|
||||
const apiKey = 'sk-native';
|
||||
const endpoint = 'Anthropic (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
},
|
||||
addParams: {
|
||||
temperature: 0.9,
|
||||
topP: 0.95,
|
||||
maxTokens: 4096,
|
||||
},
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'claude-3-opus-20240229',
|
||||
temperature: 0.9,
|
||||
topP: 0.95,
|
||||
maxTokens: 4096,
|
||||
});
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
describe('defaultParams Support via customParams', () => {
|
||||
it('should apply defaultParams when fields are undefined', () => {
|
||||
const apiKey = 'sk-defaults';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.7 },
|
||||
{ key: 'topP', default: 0.9 },
|
||||
{ key: 'maxRetries', default: 5 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect(result.llmConfig.topP).toBe(0.9);
|
||||
expect(result.llmConfig.maxRetries).toBe(5);
|
||||
});
|
||||
|
||||
it('should not override existing modelOptions with defaultParams', () => {
|
||||
const apiKey = 'sk-override';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.5 },
|
||||
{ key: 'topP', default: 0.8 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.topP).toBe(0.8);
|
||||
});
|
||||
|
||||
it('should allow addParams to override defaultParams', () => {
|
||||
const apiKey = 'sk-add-override';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.5 },
|
||||
{ key: 'topP', default: 0.7 },
|
||||
],
|
||||
},
|
||||
addParams: {
|
||||
temperature: 0.8,
|
||||
topP: 0.95,
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.8);
|
||||
expect(result.llmConfig.topP).toBe(0.95);
|
||||
});
|
||||
|
||||
it('should handle defaultParams with web_search', () => {
|
||||
const apiKey = 'sk-web-default';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-latest',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [{ key: 'web_search', default: true }],
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.tools).toEqual([
|
||||
{
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should allow addParams to override defaultParams web_search', () => {
|
||||
const apiKey = 'sk-web-override';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [{ key: 'web_search', default: true }],
|
||||
},
|
||||
addParams: {
|
||||
web_search: false,
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle dropParams overriding defaultParams', () => {
|
||||
const apiKey = 'sk-drop';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.7 },
|
||||
{ key: 'topP', default: 0.9 },
|
||||
{ key: 'web_search', default: true },
|
||||
],
|
||||
},
|
||||
dropParams: ['topP', 'web_search'],
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect(result.llmConfig.topP).toBeUndefined();
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should preserve order: defaultParams < addParams < modelOptions', () => {
|
||||
const apiKey = 'sk-precedence';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.3 },
|
||||
{ key: 'topP', default: 0.5 },
|
||||
{ key: 'timeout', default: 60000 },
|
||||
],
|
||||
},
|
||||
addParams: {
|
||||
topP: 0.8,
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.topP).toBe(0.8);
|
||||
expect(result.llmConfig.timeout).toBe(60000);
|
||||
});
|
||||
|
||||
it('should handle Claude 3.7 with defaultParams and thinking disabled', () => {
|
||||
const apiKey = 'sk-37-defaults';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3.7-sonnet-20241022',
|
||||
thinking: false,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.7 },
|
||||
{ key: 'topP', default: 0.9 },
|
||||
{ key: 'topK', default: 50 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect(result.llmConfig.topP).toBe(0.9);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBe(50);
|
||||
});
|
||||
|
||||
it('should handle empty paramDefinitions', () => {
|
||||
const apiKey = 'sk-empty';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
temperature: 0.8,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [],
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.8);
|
||||
});
|
||||
|
||||
it('should handle missing paramDefinitions', () => {
|
||||
const apiKey = 'sk-missing';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus-20240229',
|
||||
temperature: 0.8,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.8);
|
||||
});
|
||||
|
||||
it('should handle mixed Anthropic params in defaultParams', () => {
|
||||
const apiKey = 'sk-mixed';
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.7 },
|
||||
{ key: 'topP', default: 0.9 },
|
||||
{ key: 'maxRetries', default: 3 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect(result.llmConfig.topP).toBe(0.9);
|
||||
expect(result.llmConfig.maxRetries).toBe(3);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ describe('getOpenAIConfig - Backward Compatibility', () => {
|
|||
configOptions: {},
|
||||
tools: [
|
||||
{
|
||||
type: 'web_search_preview',
|
||||
type: 'web_search',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
|
|
|||
387
packages/api/src/endpoints/openai/config.google.spec.ts
Normal file
387
packages/api/src/endpoints/openai/config.google.spec.ts
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
import { getOpenAIConfig } from './config';
|
||||
|
||||
describe('getOpenAIConfig - Google Compatibility', () => {
|
||||
describe('Google via Custom Endpoint', () => {
|
||||
describe('Web Search Support via addParams', () => {
|
||||
it('should enable googleSearch tool when web_search: true in addParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
addParams: {
|
||||
web_search: true,
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([{ googleSearch: {} }]);
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
});
|
||||
});
|
||||
|
||||
it('should disable googleSearch tool when web_search: false in addParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
web_search: true, // Should be overridden by addParams
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
addParams: {
|
||||
web_search: false,
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should disable googleSearch when in dropParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
web_search: true,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
dropParams: ['web_search'],
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle web_search with mixed Google and OpenAI params in addParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
addParams: {
|
||||
web_search: true,
|
||||
temperature: 0.8, // Shared param (both Google and OpenAI)
|
||||
topK: 40, // Google-only param, goes to modelKwargs
|
||||
frequencyPenalty: 0.5, // Known OpenAI param, goes to top level
|
||||
customUnknown: 'test', // Unknown param, goes to modelKwargs
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([{ googleSearch: {} }]);
|
||||
expect(result.llmConfig.temperature).toBe(0.8); // Shared param at top level
|
||||
expect(result.llmConfig.frequencyPenalty).toBe(0.5); // Known OpenAI param at top level
|
||||
expect(result.llmConfig.modelKwargs).toMatchObject({
|
||||
topK: 40, // Google-specific in modelKwargs
|
||||
customUnknown: 'test', // Unknown param in modelKwargs
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle Google native params in addParams without web_search', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
addParams: {
|
||||
temperature: 0.9, // Shared param (both Google and OpenAI)
|
||||
topP: 0.95, // Shared param (both Google and OpenAI)
|
||||
topK: 50, // Google-only, goes to modelKwargs
|
||||
maxOutputTokens: 8192, // Google-only, goes to modelKwargs
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
temperature: 0.9, // Shared params at top level
|
||||
topP: 0.95,
|
||||
});
|
||||
expect(result.llmConfig.modelKwargs).toMatchObject({
|
||||
topK: 50, // Google-specific in modelKwargs
|
||||
maxOutputTokens: 8192,
|
||||
});
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should drop Google native params with dropParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
temperature: 0.7,
|
||||
topK: 40,
|
||||
topP: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
dropParams: ['topK', 'topP'],
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect((result.llmConfig as Record<string, unknown>).topK).toBeUndefined();
|
||||
expect(result.llmConfig.topP).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle both addParams and dropParams for Google', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const endpoint = 'Gemini (Custom)';
|
||||
const options = {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
topK: 30, // Will be dropped
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
addParams: {
|
||||
web_search: true,
|
||||
temperature: 0.8, // Shared param
|
||||
maxOutputTokens: 4096, // Google-only param
|
||||
},
|
||||
dropParams: ['topK'],
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
};
|
||||
|
||||
const result = getOpenAIConfig(apiKey, options, endpoint);
|
||||
|
||||
expect(result.tools).toEqual([{ googleSearch: {} }]);
|
||||
expect(result.llmConfig).toMatchObject({
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
temperature: 0.8,
|
||||
});
|
||||
expect(result.llmConfig.modelKwargs).toMatchObject({
|
||||
maxOutputTokens: 4096, // Google-specific in modelKwargs
|
||||
});
|
||||
expect((result.llmConfig as Record<string, unknown>).topK).toBeUndefined();
|
||||
// Verify topK is not in modelKwargs either
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('defaultParams Support via customParams', () => {
|
||||
it('should apply defaultParams when fields are undefined', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.6 },
|
||||
{ key: 'topK', default: 40 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.6);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBe(40);
|
||||
});
|
||||
|
||||
it('should not override existing modelOptions with defaultParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.5 },
|
||||
{ key: 'topK', default: 40 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBe(40);
|
||||
});
|
||||
|
||||
it('should allow addParams to override defaultParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.5 },
|
||||
{ key: 'topK', default: 30 },
|
||||
],
|
||||
},
|
||||
addParams: {
|
||||
temperature: 0.8,
|
||||
topK: 50,
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.8);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBe(50);
|
||||
});
|
||||
|
||||
it('should handle defaultParams with web_search', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [{ key: 'web_search', default: true }],
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.tools).toEqual([{ googleSearch: {} }]);
|
||||
});
|
||||
|
||||
it('should allow addParams to override defaultParams web_search', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [{ key: 'web_search', default: true }],
|
||||
},
|
||||
addParams: {
|
||||
web_search: false,
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle dropParams overriding defaultParams', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.7 },
|
||||
{ key: 'topK', default: 40 },
|
||||
{ key: 'web_search', default: true },
|
||||
],
|
||||
},
|
||||
dropParams: ['topK', 'web_search'],
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBeUndefined();
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should preserve order: defaultParams < addParams < modelOptions', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.3 },
|
||||
{ key: 'topP', default: 0.5 },
|
||||
{ key: 'topK', default: 20 },
|
||||
],
|
||||
},
|
||||
addParams: {
|
||||
topP: 0.8,
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.topP).toBe(0.8);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBe(20);
|
||||
});
|
||||
|
||||
it('should handle empty paramDefinitions', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
temperature: 0.8,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [],
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.8);
|
||||
});
|
||||
|
||||
it('should handle missing paramDefinitions', () => {
|
||||
const apiKey = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
|
||||
const result = getOpenAIConfig(apiKey, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
temperature: 0.8,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.8);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -230,7 +230,7 @@ describe('getOpenAIConfig', () => {
|
|||
const result = getOpenAIConfig(mockApiKey, { modelOptions });
|
||||
|
||||
expect(result.llmConfig.useResponsesApi).toBe(true);
|
||||
expect(result.tools).toEqual([{ type: 'web_search_preview' }]);
|
||||
expect(result.tools).toEqual([{ type: 'web_search' }]);
|
||||
});
|
||||
|
||||
it('should handle web_search from addParams overriding modelOptions', () => {
|
||||
|
|
@ -247,7 +247,7 @@ describe('getOpenAIConfig', () => {
|
|||
const result = getOpenAIConfig(mockApiKey, { modelOptions, addParams });
|
||||
|
||||
expect(result.llmConfig.useResponsesApi).toBe(true);
|
||||
expect(result.tools).toEqual([{ type: 'web_search_preview' }]);
|
||||
expect(result.tools).toEqual([{ type: 'web_search' }]);
|
||||
// web_search should not be in modelKwargs or llmConfig
|
||||
expect((result.llmConfig as Record<string, unknown>).web_search).toBeUndefined();
|
||||
expect(result.llmConfig.modelKwargs).toEqual({ customParam: 'value' });
|
||||
|
|
@ -299,7 +299,7 @@ describe('getOpenAIConfig', () => {
|
|||
|
||||
// Should keep the original web_search from modelOptions since addParams value is not boolean
|
||||
expect(result.llmConfig.useResponsesApi).toBe(true);
|
||||
expect(result.tools).toEqual([{ type: 'web_search_preview' }]);
|
||||
expect(result.tools).toEqual([{ type: 'web_search' }]);
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
// web_search should not be added to modelKwargs
|
||||
expect(result.llmConfig.modelKwargs).toBeUndefined();
|
||||
|
|
@ -335,7 +335,7 @@ describe('getOpenAIConfig', () => {
|
|||
|
||||
// web_search should trigger the tool but not appear in config
|
||||
expect(result.llmConfig.useResponsesApi).toBe(true);
|
||||
expect(result.tools).toEqual([{ type: 'web_search_preview' }]);
|
||||
expect(result.tools).toEqual([{ type: 'web_search' }]);
|
||||
expect((result.llmConfig as Record<string, unknown>).web_search).toBeUndefined();
|
||||
expect(result.llmConfig.temperature).toBe(0.5);
|
||||
expect(result.llmConfig.modelKwargs).toEqual({ customParam1: 'value1' });
|
||||
|
|
@ -1164,7 +1164,7 @@ describe('getOpenAIConfig', () => {
|
|||
text: { verbosity: Verbosity.medium },
|
||||
customParam: 'custom-value',
|
||||
});
|
||||
expect(result.tools).toEqual([{ type: 'web_search_preview' }]);
|
||||
expect(result.tools).toEqual([{ type: 'web_search' }]);
|
||||
expect(result.configOptions).toMatchObject({
|
||||
baseURL: 'https://api.custom.com',
|
||||
defaultHeaders: { 'X-Custom': 'value' },
|
||||
|
|
@ -1651,6 +1651,211 @@ describe('getOpenAIConfig', () => {
|
|||
expect(result.llmConfig.modelKwargs).toEqual(largeModelKwargs);
|
||||
});
|
||||
});
|
||||
|
||||
describe('defaultParams Support via customParams', () => {
|
||||
it('should apply defaultParams when fields are undefined', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-4o',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'azureOpenAI',
|
||||
paramDefinitions: [
|
||||
{ key: 'useResponsesApi', default: true },
|
||||
{ key: 'temperature', default: 0.5 },
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.useResponsesApi).toBe(true);
|
||||
expect(result.llmConfig.temperature).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should not override existing modelOptions with defaultParams', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-5',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'azureOpenAI',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.5 },
|
||||
{ key: 'maxTokens', default: 1000 },
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.modelKwargs?.max_completion_tokens).toBe(1000);
|
||||
});
|
||||
|
||||
it('should allow addParams to override defaultParams', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-4o',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'azureOpenAI',
|
||||
paramDefinitions: [
|
||||
{ key: 'useResponsesApi', default: true },
|
||||
{ key: 'temperature', default: 0.5 },
|
||||
],
|
||||
},
|
||||
addParams: {
|
||||
useResponsesApi: false,
|
||||
temperature: 0.8,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.useResponsesApi).toBe(false);
|
||||
expect(result.llmConfig.temperature).toBe(0.8);
|
||||
});
|
||||
|
||||
it('should handle defaultParams with unknown parameters', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-4o',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'azureOpenAI',
|
||||
paramDefinitions: [
|
||||
{ key: 'customParam1', default: 'defaultValue' },
|
||||
{ key: 'customParam2', default: 123 },
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.modelKwargs).toMatchObject({
|
||||
customParam1: 'defaultValue',
|
||||
customParam2: 123,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle defaultParams with web_search', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-4o',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'openAI',
|
||||
paramDefinitions: [{ key: 'web_search', default: true }],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.useResponsesApi).toBe(true);
|
||||
expect(result.tools).toEqual([{ type: 'web_search' }]);
|
||||
});
|
||||
|
||||
it('should allow addParams to override defaultParams web_search', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-4o',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'openAI',
|
||||
paramDefinitions: [{ key: 'web_search', default: true }],
|
||||
},
|
||||
addParams: {
|
||||
web_search: false,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should apply defaultParams for Anthropic via customParams', () => {
|
||||
const result = getOpenAIConfig('test-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'anthropic',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.7 },
|
||||
{ key: 'topK', default: 50 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://api.anthropic.com',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.7);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBe(50);
|
||||
});
|
||||
|
||||
it('should apply defaultParams for Google via customParams', () => {
|
||||
const credentials = JSON.stringify({ GOOGLE_API_KEY: 'test-google-key' });
|
||||
const result = getOpenAIConfig(credentials, {
|
||||
modelOptions: {
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'google',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.6 },
|
||||
{ key: 'topK', default: 40 },
|
||||
],
|
||||
},
|
||||
reverseProxyUrl: 'https://generativelanguage.googleapis.com/v1beta/openai',
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.6);
|
||||
expect(result.llmConfig.modelKwargs?.topK).toBe(40);
|
||||
});
|
||||
|
||||
it('should handle empty paramDefinitions', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-4o',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'azureOpenAI',
|
||||
paramDefinitions: [],
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
});
|
||||
|
||||
it('should handle missing paramDefinitions', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-4o',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'azureOpenAI',
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
});
|
||||
|
||||
it('should preserve order: defaultParams < addParams < modelOptions', () => {
|
||||
const result = getOpenAIConfig(mockApiKey, {
|
||||
modelOptions: {
|
||||
model: 'gpt-5',
|
||||
temperature: 0.9,
|
||||
},
|
||||
customParams: {
|
||||
defaultParamsEndpoint: 'openAI',
|
||||
paramDefinitions: [
|
||||
{ key: 'temperature', default: 0.3 },
|
||||
{ key: 'topP', default: 0.5 },
|
||||
{ key: 'maxTokens', default: 500 },
|
||||
],
|
||||
},
|
||||
addParams: {
|
||||
topP: 0.8,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig.temperature).toBe(0.9);
|
||||
expect(result.llmConfig.topP).toBe(0.8);
|
||||
expect(result.llmConfig.modelKwargs?.max_completion_tokens).toBe(500);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Entra ID Authentication', () => {
|
||||
|
|
|
|||
|
|
@ -3,10 +3,11 @@ import { Providers } from '@librechat/agents';
|
|||
import { KnownEndpoints, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type * as t from '~/types';
|
||||
import { getLLMConfig as getAnthropicLLMConfig } from '~/endpoints/anthropic/llm';
|
||||
import { getOpenAILLMConfig, extractDefaultParams } from './llm';
|
||||
import { getGoogleConfig } from '~/endpoints/google/llm';
|
||||
import { transformToOpenAIConfig } from './transform';
|
||||
import { constructAzureURL } from '~/utils/azure';
|
||||
import { createFetch } from '~/utils/generators';
|
||||
import { getOpenAILLMConfig } from './llm';
|
||||
|
||||
type Fetch = (input: string | URL | Request, init?: RequestInit) => Promise<Response>;
|
||||
|
||||
|
|
@ -33,17 +34,24 @@ export function getOpenAIConfig(
|
|||
reverseProxyUrl: baseURL,
|
||||
} = options;
|
||||
|
||||
/** Extract default params from customParams.paramDefinitions */
|
||||
const defaultParams = extractDefaultParams(options.customParams?.paramDefinitions);
|
||||
|
||||
let llmConfig: t.OAIClientOptions;
|
||||
let tools: t.LLMConfigResult['tools'];
|
||||
const isAnthropic = options.customParams?.defaultParamsEndpoint === EModelEndpoint.anthropic;
|
||||
const isGoogle = options.customParams?.defaultParamsEndpoint === EModelEndpoint.google;
|
||||
|
||||
const useOpenRouter =
|
||||
!isAnthropic &&
|
||||
!isGoogle &&
|
||||
((baseURL && baseURL.includes(KnownEndpoints.openrouter)) ||
|
||||
(endpoint != null && endpoint.toLowerCase().includes(KnownEndpoints.openrouter)));
|
||||
const isVercel =
|
||||
(baseURL && baseURL.includes('ai-gateway.vercel.sh')) ||
|
||||
(endpoint != null && endpoint.toLowerCase().includes(KnownEndpoints.vercel));
|
||||
!isAnthropic &&
|
||||
!isGoogle &&
|
||||
((baseURL && baseURL.includes('ai-gateway.vercel.sh')) ||
|
||||
(endpoint != null && endpoint.toLowerCase().includes(KnownEndpoints.vercel)));
|
||||
|
||||
let azure = options.azure;
|
||||
let headers = options.headers;
|
||||
|
|
@ -51,7 +59,12 @@ export function getOpenAIConfig(
|
|||
const anthropicResult = getAnthropicLLMConfig(apiKey, {
|
||||
modelOptions,
|
||||
proxy: options.proxy,
|
||||
reverseProxyUrl: baseURL,
|
||||
addParams,
|
||||
dropParams,
|
||||
defaultParams,
|
||||
});
|
||||
/** Transform handles addParams/dropParams - it knows about OpenAI params */
|
||||
const transformed = transformToOpenAIConfig({
|
||||
addParams,
|
||||
dropParams,
|
||||
|
|
@ -63,6 +76,24 @@ export function getOpenAIConfig(
|
|||
if (transformed.configOptions?.defaultHeaders) {
|
||||
headers = Object.assign(headers ?? {}, transformed.configOptions?.defaultHeaders);
|
||||
}
|
||||
} else if (isGoogle) {
|
||||
const googleResult = getGoogleConfig(apiKey, {
|
||||
modelOptions,
|
||||
reverseProxyUrl: baseURL ?? undefined,
|
||||
authHeader: true,
|
||||
addParams,
|
||||
dropParams,
|
||||
defaultParams,
|
||||
});
|
||||
/** Transform handles addParams/dropParams - it knows about OpenAI params */
|
||||
const transformed = transformToOpenAIConfig({
|
||||
addParams,
|
||||
dropParams,
|
||||
llmConfig: googleResult.llmConfig,
|
||||
fromEndpoint: EModelEndpoint.google,
|
||||
});
|
||||
llmConfig = transformed.llmConfig;
|
||||
tools = googleResult.tools;
|
||||
} else {
|
||||
const openaiResult = getOpenAILLMConfig({
|
||||
azure,
|
||||
|
|
@ -72,6 +103,7 @@ export function getOpenAIConfig(
|
|||
streaming,
|
||||
addParams,
|
||||
dropParams,
|
||||
defaultParams,
|
||||
modelOptions,
|
||||
useOpenRouter,
|
||||
});
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
import { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } from 'librechat-data-provider';
|
||||
import type {
|
||||
InitializeOpenAIOptionsParams,
|
||||
OpenAIOptionsResult,
|
||||
OpenAIConfigOptions,
|
||||
LLMConfigResult,
|
||||
UserKeyValues,
|
||||
} from '~/types';
|
||||
import { createHandleLLMNewToken } from '~/utils/generators';
|
||||
import { getAzureCredentials, getEntraIdAccessToken, shouldUseEntraId } from '~/utils/azure';
|
||||
import { isUserProvided } from '~/utils/common';
|
||||
import { resolveHeaders } from '~/utils/env';
|
||||
|
|
@ -27,7 +26,7 @@ export const initializeOpenAI = async ({
|
|||
overrideEndpoint,
|
||||
getUserKeyValues,
|
||||
checkUserKeyExpiry,
|
||||
}: InitializeOpenAIOptionsParams): Promise<OpenAIOptionsResult> => {
|
||||
}: InitializeOpenAIOptionsParams): Promise<LLMConfigResult> => {
|
||||
const { PROXY, OPENAI_API_KEY, AZURE_API_KEY, OPENAI_REVERSE_PROXY, AZURE_OPENAI_BASEURL } =
|
||||
process.env;
|
||||
|
||||
|
|
@ -178,17 +177,8 @@ export const initializeOpenAI = async ({
|
|||
}
|
||||
|
||||
if (streamRate) {
|
||||
options.llmConfig.callbacks = [
|
||||
{
|
||||
handleLLMNewToken: createHandleLLMNewToken(streamRate),
|
||||
},
|
||||
];
|
||||
options.llmConfig._lc_stream_delay = streamRate;
|
||||
}
|
||||
|
||||
const result: OpenAIOptionsResult = {
|
||||
...options,
|
||||
streamRate,
|
||||
};
|
||||
|
||||
return result;
|
||||
return options;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import { EModelEndpoint, removeNullishValues } from 'librechat-data-provider';
|
||||
import type { BindToolsInput } from '@langchain/core/language_models/chat_models';
|
||||
import type { SettingDefinition } from 'librechat-data-provider';
|
||||
import type { AzureOpenAIInput } from '@langchain/openai';
|
||||
import type { OpenAI } from 'openai';
|
||||
import type * as t from '~/types';
|
||||
|
|
@ -75,6 +76,44 @@ function hasReasoningParams({
|
|||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts default parameters from customParams.paramDefinitions
|
||||
* @param paramDefinitions - Array of parameter definitions with key and default values
|
||||
* @returns Record of default parameters
|
||||
*/
|
||||
export function extractDefaultParams(
|
||||
paramDefinitions?: Partial<SettingDefinition>[],
|
||||
): Record<string, unknown> | undefined {
|
||||
if (!paramDefinitions || !Array.isArray(paramDefinitions)) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const defaults: Record<string, unknown> = {};
|
||||
for (let i = 0; i < paramDefinitions.length; i++) {
|
||||
const param = paramDefinitions[i];
|
||||
if (param.key !== undefined && param.default !== undefined) {
|
||||
defaults[param.key] = param.default;
|
||||
}
|
||||
}
|
||||
return defaults;
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies default parameters to the target object only if the field is undefined
|
||||
* @param target - The target object to apply defaults to
|
||||
* @param defaults - Record of default parameter values
|
||||
*/
|
||||
export function applyDefaultParams(
|
||||
target: Record<string, unknown>,
|
||||
defaults: Record<string, unknown>,
|
||||
) {
|
||||
for (const [key, value] of Object.entries(defaults)) {
|
||||
if (target[key] === undefined) {
|
||||
target[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function getOpenAILLMConfig({
|
||||
azure,
|
||||
apiKey,
|
||||
|
|
@ -83,6 +122,7 @@ export function getOpenAILLMConfig({
|
|||
streaming,
|
||||
addParams,
|
||||
dropParams,
|
||||
defaultParams,
|
||||
useOpenRouter,
|
||||
modelOptions: _modelOptions,
|
||||
}: {
|
||||
|
|
@ -93,6 +133,7 @@ export function getOpenAILLMConfig({
|
|||
modelOptions: Partial<t.OpenAIParameters>;
|
||||
addParams?: Record<string, unknown>;
|
||||
dropParams?: string[];
|
||||
defaultParams?: Record<string, unknown>;
|
||||
useOpenRouter?: boolean;
|
||||
azure?: false | t.AzureOptions;
|
||||
}): Pick<t.LLMConfigResult, 'llmConfig' | 'tools'> & {
|
||||
|
|
@ -133,6 +174,30 @@ export function getOpenAILLMConfig({
|
|||
|
||||
let enableWebSearch = web_search;
|
||||
|
||||
/** Apply defaultParams first - only if fields are undefined */
|
||||
if (defaultParams && typeof defaultParams === 'object') {
|
||||
for (const [key, value] of Object.entries(defaultParams)) {
|
||||
/** Handle web_search separately - don't add to config */
|
||||
if (key === 'web_search') {
|
||||
if (enableWebSearch === undefined && typeof value === 'boolean') {
|
||||
enableWebSearch = value;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (knownOpenAIParams.has(key)) {
|
||||
applyDefaultParams(llmConfig as Record<string, unknown>, { [key]: value });
|
||||
} else {
|
||||
/** Apply to modelKwargs if not a known param */
|
||||
if (modelKwargs[key] === undefined) {
|
||||
modelKwargs[key] = value;
|
||||
hasModelKwargs = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Apply addParams - can override defaultParams */
|
||||
if (addParams && typeof addParams === 'object') {
|
||||
for (const [key, value] of Object.entries(addParams)) {
|
||||
/** Handle web_search directly here instead of adding to modelKwargs or llmConfig */
|
||||
|
|
@ -190,7 +255,7 @@ export function getOpenAILLMConfig({
|
|||
} else if (enableWebSearch) {
|
||||
/** Standard OpenAI web search uses tools API */
|
||||
llmConfig.useResponsesApi = true;
|
||||
tools.push({ type: 'web_search_preview' });
|
||||
tools.push({ type: 'web_search' });
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import type * as t from '~/types';
|
|||
import { knownOpenAIParams } from './llm';
|
||||
|
||||
const anthropicExcludeParams = new Set(['anthropicApiUrl']);
|
||||
const googleExcludeParams = new Set(['safetySettings', 'location', 'baseUrl', 'customHeaders']);
|
||||
|
||||
/**
|
||||
* Transforms a Non-OpenAI LLM config to an OpenAI-conformant config.
|
||||
|
|
@ -31,7 +32,14 @@ export function transformToOpenAIConfig({
|
|||
let hasModelKwargs = false;
|
||||
|
||||
const isAnthropic = fromEndpoint === EModelEndpoint.anthropic;
|
||||
const excludeParams = isAnthropic ? anthropicExcludeParams : new Set();
|
||||
const isGoogle = fromEndpoint === EModelEndpoint.google;
|
||||
|
||||
let excludeParams = new Set<string>();
|
||||
if (isAnthropic) {
|
||||
excludeParams = anthropicExcludeParams;
|
||||
} else if (isGoogle) {
|
||||
excludeParams = googleExcludeParams;
|
||||
}
|
||||
|
||||
for (const [key, value] of Object.entries(llmConfig)) {
|
||||
if (value === undefined || value === null) {
|
||||
|
|
@ -49,6 +57,19 @@ export function transformToOpenAIConfig({
|
|||
modelKwargs = Object.assign({}, modelKwargs, value as Record<string, unknown>);
|
||||
hasModelKwargs = true;
|
||||
continue;
|
||||
} else if (isGoogle && key === 'authOptions') {
|
||||
// Handle Google authOptions
|
||||
modelKwargs = Object.assign({}, modelKwargs, value as Record<string, unknown>);
|
||||
hasModelKwargs = true;
|
||||
continue;
|
||||
} else if (
|
||||
isGoogle &&
|
||||
(key === 'thinkingConfig' || key === 'thinkingBudget' || key === 'includeThoughts')
|
||||
) {
|
||||
// Handle Google thinking configuration
|
||||
modelKwargs = Object.assign({}, modelKwargs, { [key]: value });
|
||||
hasModelKwargs = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (knownOpenAIParams.has(key)) {
|
||||
|
|
@ -61,6 +82,11 @@ export function transformToOpenAIConfig({
|
|||
|
||||
if (addParams && typeof addParams === 'object') {
|
||||
for (const [key, value] of Object.entries(addParams)) {
|
||||
/** Skip web_search - it's handled separately as a tool */
|
||||
if (key === 'web_search') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (knownOpenAIParams.has(key)) {
|
||||
(openAIConfig as Record<string, unknown>)[key] = value;
|
||||
} else {
|
||||
|
|
@ -76,16 +102,23 @@ export function transformToOpenAIConfig({
|
|||
|
||||
if (dropParams && Array.isArray(dropParams)) {
|
||||
dropParams.forEach((param) => {
|
||||
/** Skip web_search - handled separately */
|
||||
if (param === 'web_search') {
|
||||
return;
|
||||
}
|
||||
|
||||
if (param in openAIConfig) {
|
||||
delete openAIConfig[param as keyof t.OAIClientOptions];
|
||||
}
|
||||
if (openAIConfig.modelKwargs && param in openAIConfig.modelKwargs) {
|
||||
delete openAIConfig.modelKwargs[param];
|
||||
if (Object.keys(openAIConfig.modelKwargs).length === 0) {
|
||||
delete openAIConfig.modelKwargs;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
/** Clean up empty modelKwargs after dropParams processing */
|
||||
if (openAIConfig.modelKwargs && Object.keys(openAIConfig.modelKwargs).length === 0) {
|
||||
delete openAIConfig.modelKwargs;
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { FileSources, mergeFileConfig } from 'librechat-data-provider';
|
||||
import type { fileConfigSchema } from 'librechat-data-provider';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { z } from 'zod';
|
||||
import type { ServerRequest } from '~/types';
|
||||
import { processTextWithTokenLimit } from '~/utils/text';
|
||||
|
||||
/**
|
||||
|
|
@ -20,10 +19,7 @@ export async function extractFileContext({
|
|||
tokenCountFn,
|
||||
}: {
|
||||
attachments: IMongoFile[];
|
||||
req?: {
|
||||
body?: { fileTokenLimit?: number };
|
||||
config?: { fileConfig?: z.infer<typeof fileConfigSchema> };
|
||||
};
|
||||
req?: ServerRequest;
|
||||
tokenCountFn: (text: string) => number;
|
||||
}): Promise<string | undefined> {
|
||||
if (!attachments || attachments.length === 0) {
|
||||
|
|
|
|||
|
|
@ -1,25 +1,27 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { isDocumentSupportedProvider } from 'librechat-data-provider';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { Request } from 'express';
|
||||
import type { StrategyFunctions, AudioResult } from '~/types/files';
|
||||
import type { ServerRequest, StrategyFunctions, AudioResult } from '~/types';
|
||||
import { getFileStream, getConfiguredFileSizeLimit } from './utils';
|
||||
import { validateAudio } from '~/files/validation';
|
||||
import { getFileStream } from './utils';
|
||||
|
||||
/**
|
||||
* Encodes and formats audio files for different providers
|
||||
* @param req - The request object
|
||||
* @param files - Array of audio files
|
||||
* @param provider - The provider to format for (currently only google is supported)
|
||||
* @param params - Object containing provider and optional endpoint
|
||||
* @param params.provider - The provider to format for (currently only google is supported)
|
||||
* @param params.endpoint - Optional endpoint name for file config lookup
|
||||
* @param getStrategyFunctions - Function to get strategy functions
|
||||
* @returns Promise that resolves to audio and file metadata
|
||||
*/
|
||||
export async function encodeAndFormatAudios(
|
||||
req: Request,
|
||||
req: ServerRequest,
|
||||
files: IMongoFile[],
|
||||
provider: Providers,
|
||||
params: { provider: Providers; endpoint?: string },
|
||||
getStrategyFunctions: (source: string) => StrategyFunctions,
|
||||
): Promise<AudioResult> {
|
||||
const { provider, endpoint } = params;
|
||||
if (!files?.length) {
|
||||
return { audios: [], files: [] };
|
||||
}
|
||||
|
|
@ -53,7 +55,19 @@ export async function encodeAndFormatAudios(
|
|||
}
|
||||
|
||||
const audioBuffer = Buffer.from(content, 'base64');
|
||||
const validation = await validateAudio(audioBuffer, audioBuffer.length, provider);
|
||||
|
||||
/** Extract configured file size limit from fileConfig for this endpoint */
|
||||
const configuredFileSizeLimit = getConfiguredFileSizeLimit(req, {
|
||||
provider,
|
||||
endpoint,
|
||||
});
|
||||
|
||||
const validation = await validateAudio(
|
||||
audioBuffer,
|
||||
audioBuffer.length,
|
||||
provider,
|
||||
configuredFileSizeLimit,
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Audio validation failed: ${validation.error}`);
|
||||
|
|
|
|||
531
packages/api/src/files/encode/document.spec.ts
Normal file
531
packages/api/src/files/encode/document.spec.ts
Normal file
|
|
@ -0,0 +1,531 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { mbToBytes } from 'librechat-data-provider';
|
||||
import type { AppConfig, IMongoFile } from '@librechat/data-schemas';
|
||||
import type { ServerRequest } from '~/types';
|
||||
import { encodeAndFormatDocuments } from './document';
|
||||
|
||||
/** Mock the validation module */
|
||||
jest.mock('~/files/validation', () => ({
|
||||
validatePdf: jest.fn(),
|
||||
}));
|
||||
|
||||
/** Mock the utils module */
|
||||
jest.mock('./utils', () => ({
|
||||
getFileStream: jest.fn(),
|
||||
getConfiguredFileSizeLimit: jest.fn(),
|
||||
}));
|
||||
|
||||
import { validatePdf } from '~/files/validation';
|
||||
import { getFileStream, getConfiguredFileSizeLimit } from './utils';
|
||||
import { Types } from 'mongoose';
|
||||
|
||||
const mockedValidatePdf = validatePdf as jest.MockedFunction<typeof validatePdf>;
|
||||
const mockedGetFileStream = getFileStream as jest.MockedFunction<typeof getFileStream>;
|
||||
const mockedGetConfiguredFileSizeLimit = getConfiguredFileSizeLimit as jest.MockedFunction<
|
||||
typeof getConfiguredFileSizeLimit
|
||||
>;
|
||||
|
||||
describe('encodeAndFormatDocuments - fileConfig integration', () => {
|
||||
const mockStrategyFunctions = jest.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
/** Default mock implementation for getConfiguredFileSizeLimit */
|
||||
mockedGetConfiguredFileSizeLimit.mockImplementation((req, params) => {
|
||||
if (!req.config?.fileConfig) {
|
||||
return undefined;
|
||||
}
|
||||
const { provider, endpoint } = params;
|
||||
const lookupKey = endpoint ?? provider;
|
||||
const fileConfig = req.config.fileConfig;
|
||||
const endpoints = fileConfig.endpoints;
|
||||
if (endpoints?.[lookupKey]) {
|
||||
const limit = endpoints[lookupKey].fileSizeLimit;
|
||||
return limit !== undefined ? mbToBytes(limit) : undefined;
|
||||
}
|
||||
if (endpoints?.default) {
|
||||
const limit = endpoints.default.fileSizeLimit;
|
||||
return limit !== undefined ? mbToBytes(limit) : undefined;
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
});
|
||||
|
||||
/** Helper to create a mock request with file config */
|
||||
const createMockRequest = (fileSizeLimit?: number): Partial<AppConfig> => ({
|
||||
config:
|
||||
fileSizeLimit !== undefined
|
||||
? {
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
[Providers.OPENAI]: {
|
||||
fileSizeLimit,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
: undefined,
|
||||
});
|
||||
|
||||
/** Helper to create a mock PDF file */
|
||||
const createMockFile = (sizeInMB: number): IMongoFile =>
|
||||
({
|
||||
_id: new Types.ObjectId(),
|
||||
user: new Types.ObjectId(),
|
||||
file_id: new Types.ObjectId().toString(),
|
||||
filename: 'test.pdf',
|
||||
type: 'application/pdf',
|
||||
bytes: Math.floor(sizeInMB * 1024 * 1024),
|
||||
object: 'file',
|
||||
usage: 0,
|
||||
source: 'test',
|
||||
filepath: '/test/path.pdf',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
}) as unknown as IMongoFile;
|
||||
|
||||
describe('Configuration extraction and validation', () => {
|
||||
it('should pass configured file size limit to validatePdf for OpenAI', async () => {
|
||||
const configuredLimit = mbToBytes(15);
|
||||
const req = createMockRequest(15) as ServerRequest;
|
||||
const file = createMockFile(10);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
configuredLimit,
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass undefined when no fileConfig is provided', async () => {
|
||||
const req = {} as ServerRequest;
|
||||
const file = createMockFile(10);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass undefined when fileConfig.endpoints is not defined', async () => {
|
||||
const req = {
|
||||
config: {
|
||||
fileConfig: {},
|
||||
},
|
||||
} as ServerRequest;
|
||||
const file = createMockFile(10);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
/** When fileConfig has no endpoints, getConfiguredFileSizeLimit returns undefined */
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use endpoint-specific config for Anthropic', async () => {
|
||||
const configuredLimit = mbToBytes(20);
|
||||
const req = {
|
||||
config: {
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
[Providers.ANTHROPIC]: {
|
||||
fileSizeLimit: 20,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as ServerRequest;
|
||||
const file = createMockFile(15);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.ANTHROPIC },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.ANTHROPIC,
|
||||
configuredLimit,
|
||||
);
|
||||
});
|
||||
|
||||
it('should use endpoint-specific config for Google', async () => {
|
||||
const configuredLimit = mbToBytes(25);
|
||||
const req = {
|
||||
config: {
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
[Providers.GOOGLE]: {
|
||||
fileSizeLimit: 25,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as ServerRequest;
|
||||
const file = createMockFile(18);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.GOOGLE },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.GOOGLE,
|
||||
configuredLimit,
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass undefined when provider-specific config not found and no default', async () => {
|
||||
const req = {
|
||||
config: {
|
||||
fileConfig: {
|
||||
endpoints: {
|
||||
/** Only configure a different provider, not OpenAI */
|
||||
[Providers.ANTHROPIC]: {
|
||||
fileSizeLimit: 25,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
} as unknown as ServerRequest;
|
||||
const file = createMockFile(20);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
/** When provider-specific config not found and no default, returns undefined */
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Validation failure handling', () => {
|
||||
it('should throw error when validation fails', async () => {
|
||||
const req = createMockRequest(10) as ServerRequest;
|
||||
const file = createMockFile(12);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({
|
||||
isValid: false,
|
||||
error: 'PDF file size (12MB) exceeds the 10MB limit',
|
||||
});
|
||||
|
||||
await expect(
|
||||
encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
),
|
||||
).rejects.toThrow('PDF validation failed: PDF file size (12MB) exceeds the 10MB limit');
|
||||
});
|
||||
|
||||
it('should not call validatePdf for non-PDF files', async () => {
|
||||
const req = createMockRequest(10) as ServerRequest;
|
||||
const file: IMongoFile = {
|
||||
...createMockFile(5),
|
||||
type: 'image/jpeg',
|
||||
filename: 'test.jpg',
|
||||
};
|
||||
|
||||
const mockContent = Buffer.from('test-image-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(mockedValidatePdf).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Bug reproduction scenarios', () => {
|
||||
it('should respect user-configured lower limit (stricter than provider)', async () => {
|
||||
/**
|
||||
* Scenario: User sets openAI.fileSizeLimit = 5MB (stricter than 10MB provider limit)
|
||||
* Uploads 7MB PDF
|
||||
* Expected: Validation called with 5MB limit
|
||||
*/
|
||||
const req = createMockRequest(5) as ServerRequest;
|
||||
const file = createMockFile(7);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({
|
||||
isValid: false,
|
||||
error: 'PDF file size (7MB) exceeds the 5MB limit',
|
||||
});
|
||||
|
||||
await expect(
|
||||
encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
),
|
||||
).rejects.toThrow('PDF validation failed');
|
||||
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
mbToBytes(5),
|
||||
);
|
||||
});
|
||||
|
||||
it('should respect user-configured higher limit (allows API changes)', async () => {
|
||||
/**
|
||||
* Scenario: User sets openAI.fileSizeLimit = 50MB (higher than 10MB provider default)
|
||||
* Uploads 15MB PDF
|
||||
* Expected: Validation called with 50MB limit, allowing files between 10-50MB
|
||||
* This allows users to take advantage of API limit increases
|
||||
*/
|
||||
const req = createMockRequest(50) as ServerRequest;
|
||||
const file = createMockFile(15);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(mockedValidatePdf).toHaveBeenCalledWith(
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
mbToBytes(50),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple files with different sizes', async () => {
|
||||
const req = createMockRequest(10) as ServerRequest;
|
||||
const file1 = createMockFile(5);
|
||||
const file2 = createMockFile(8);
|
||||
|
||||
const mockContent1 = Buffer.from('pdf-content-1').toString('base64');
|
||||
const mockContent2 = Buffer.from('pdf-content-2').toString('base64');
|
||||
|
||||
mockedGetFileStream
|
||||
.mockResolvedValueOnce({
|
||||
file: file1,
|
||||
content: mockContent1,
|
||||
metadata: file1,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
file: file2,
|
||||
content: mockContent2,
|
||||
metadata: file2,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file1, file2],
|
||||
{ provider: Providers.OPENAI },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(mockedValidatePdf).toHaveBeenCalledTimes(2);
|
||||
expect(mockedValidatePdf).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
mbToBytes(10),
|
||||
);
|
||||
expect(mockedValidatePdf).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.any(Buffer),
|
||||
expect.any(Number),
|
||||
Providers.OPENAI,
|
||||
mbToBytes(10),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Document formatting after validation', () => {
|
||||
it('should format Anthropic document with valid PDF', async () => {
|
||||
const req = createMockRequest(30) as ServerRequest;
|
||||
const file = createMockFile(20);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
const result = await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.ANTHROPIC },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(result.documents).toHaveLength(1);
|
||||
expect(result.documents[0]).toMatchObject({
|
||||
type: 'document',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'application/pdf',
|
||||
data: mockContent,
|
||||
},
|
||||
citations: { enabled: true },
|
||||
});
|
||||
});
|
||||
|
||||
it('should format OpenAI document with responses API', async () => {
|
||||
const req = createMockRequest(15) as ServerRequest;
|
||||
const file = createMockFile(10);
|
||||
|
||||
const mockContent = Buffer.from('test-pdf-content').toString('base64');
|
||||
mockedGetFileStream.mockResolvedValue({
|
||||
file,
|
||||
content: mockContent,
|
||||
metadata: file,
|
||||
});
|
||||
|
||||
mockedValidatePdf.mockResolvedValue({ isValid: true });
|
||||
|
||||
const result = await encodeAndFormatDocuments(
|
||||
req,
|
||||
[file],
|
||||
{ provider: Providers.OPENAI, useResponsesApi: true },
|
||||
mockStrategyFunctions,
|
||||
);
|
||||
|
||||
expect(result.documents).toHaveLength(1);
|
||||
expect(result.documents[0]).toMatchObject({
|
||||
type: 'input_file',
|
||||
filename: 'test.pdf',
|
||||
file_data: `data:application/pdf;base64,${mockContent}`,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,25 +1,33 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { isOpenAILikeProvider, isDocumentSupportedProvider } from 'librechat-data-provider';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { Request } from 'express';
|
||||
import type { StrategyFunctions, DocumentResult, AnthropicDocumentBlock } from '~/types/files';
|
||||
import type {
|
||||
AnthropicDocumentBlock,
|
||||
StrategyFunctions,
|
||||
DocumentResult,
|
||||
ServerRequest,
|
||||
} from '~/types';
|
||||
import { getFileStream, getConfiguredFileSizeLimit } from './utils';
|
||||
import { validatePdf } from '~/files/validation';
|
||||
import { getFileStream } from './utils';
|
||||
|
||||
/**
|
||||
* Processes and encodes document files for various providers
|
||||
* @param req - Express request object
|
||||
* @param files - Array of file objects to process
|
||||
* @param provider - The provider name
|
||||
* @param params - Object containing provider, endpoint, and other options
|
||||
* @param params.provider - The provider name
|
||||
* @param params.endpoint - Optional endpoint name for file config lookup
|
||||
* @param params.useResponsesApi - Whether to use responses API format
|
||||
* @param getStrategyFunctions - Function to get strategy functions
|
||||
* @returns Promise that resolves to documents and file metadata
|
||||
*/
|
||||
export async function encodeAndFormatDocuments(
|
||||
req: Request,
|
||||
req: ServerRequest,
|
||||
files: IMongoFile[],
|
||||
{ provider, useResponsesApi }: { provider: Providers; useResponsesApi?: boolean },
|
||||
params: { provider: Providers; endpoint?: string; useResponsesApi?: boolean },
|
||||
getStrategyFunctions: (source: string) => StrategyFunctions,
|
||||
): Promise<DocumentResult> {
|
||||
const { provider, endpoint, useResponsesApi } = params;
|
||||
if (!files?.length) {
|
||||
return { documents: [], files: [] };
|
||||
}
|
||||
|
|
@ -62,7 +70,19 @@ export async function encodeAndFormatDocuments(
|
|||
|
||||
if (file.type === 'application/pdf' && isDocumentSupportedProvider(provider)) {
|
||||
const pdfBuffer = Buffer.from(content, 'base64');
|
||||
const validation = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
/** Extract configured file size limit from fileConfig for this endpoint */
|
||||
const configuredFileSizeLimit = getConfiguredFileSizeLimit(req, {
|
||||
provider,
|
||||
endpoint,
|
||||
});
|
||||
|
||||
const validation = await validatePdf(
|
||||
pdfBuffer,
|
||||
pdfBuffer.length,
|
||||
provider,
|
||||
configuredFileSizeLimit,
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`PDF validation failed: ${validation.error}`);
|
||||
|
|
|
|||
|
|
@ -1,8 +1,35 @@
|
|||
import getStream from 'get-stream';
|
||||
import { FileSources } from 'librechat-data-provider';
|
||||
import { Providers } from '@librechat/agents';
|
||||
import { FileSources, mergeFileConfig, getEndpointFileConfig } from 'librechat-data-provider';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { Request } from 'express';
|
||||
import type { StrategyFunctions, ProcessedFile } from '~/types/files';
|
||||
import type { ServerRequest, StrategyFunctions, ProcessedFile } from '~/types';
|
||||
|
||||
/**
|
||||
* Extracts the configured file size limit for a specific provider from fileConfig
|
||||
* @param req - The server request object containing config
|
||||
* @param params - Object containing provider and optional endpoint
|
||||
* @param params.provider - The provider to get the limit for
|
||||
* @param params.endpoint - Optional endpoint name for lookup
|
||||
* @returns The configured file size limit in bytes, or undefined if not configured
|
||||
*/
|
||||
export const getConfiguredFileSizeLimit = (
|
||||
req: ServerRequest,
|
||||
params: {
|
||||
provider: Providers;
|
||||
endpoint?: string;
|
||||
},
|
||||
): number | undefined => {
|
||||
if (!req.config?.fileConfig) {
|
||||
return undefined;
|
||||
}
|
||||
const { provider, endpoint } = params;
|
||||
const fileConfig = mergeFileConfig(req.config.fileConfig);
|
||||
const endpointConfig = getEndpointFileConfig({
|
||||
fileConfig,
|
||||
endpoint: endpoint ?? provider,
|
||||
});
|
||||
return endpointConfig?.fileSizeLimit;
|
||||
};
|
||||
|
||||
/**
|
||||
* Processes a file by downloading and encoding it to base64
|
||||
|
|
@ -13,7 +40,7 @@ import type { StrategyFunctions, ProcessedFile } from '~/types/files';
|
|||
* @returns Processed file with content and metadata, or null if filepath missing
|
||||
*/
|
||||
export async function getFileStream(
|
||||
req: Request,
|
||||
req: ServerRequest,
|
||||
file: IMongoFile,
|
||||
encodingMethods: Record<string, StrategyFunctions>,
|
||||
getStrategyFunctions: (source: string) => StrategyFunctions,
|
||||
|
|
|
|||
|
|
@ -1,25 +1,27 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { isDocumentSupportedProvider } from 'librechat-data-provider';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { Request } from 'express';
|
||||
import type { StrategyFunctions, VideoResult } from '~/types/files';
|
||||
import type { ServerRequest, StrategyFunctions, VideoResult } from '~/types';
|
||||
import { getFileStream, getConfiguredFileSizeLimit } from './utils';
|
||||
import { validateVideo } from '~/files/validation';
|
||||
import { getFileStream } from './utils';
|
||||
|
||||
/**
|
||||
* Encodes and formats video files for different providers
|
||||
* @param req - The request object
|
||||
* @param files - Array of video files
|
||||
* @param provider - The provider to format for
|
||||
* @param params - Object containing provider and optional endpoint
|
||||
* @param params.provider - The provider to format for
|
||||
* @param params.endpoint - Optional endpoint name for file config lookup
|
||||
* @param getStrategyFunctions - Function to get strategy functions
|
||||
* @returns Promise that resolves to videos and file metadata
|
||||
*/
|
||||
export async function encodeAndFormatVideos(
|
||||
req: Request,
|
||||
req: ServerRequest,
|
||||
files: IMongoFile[],
|
||||
provider: Providers,
|
||||
params: { provider: Providers; endpoint?: string },
|
||||
getStrategyFunctions: (source: string) => StrategyFunctions,
|
||||
): Promise<VideoResult> {
|
||||
const { provider, endpoint } = params;
|
||||
if (!files?.length) {
|
||||
return { videos: [], files: [] };
|
||||
}
|
||||
|
|
@ -53,7 +55,19 @@ export async function encodeAndFormatVideos(
|
|||
}
|
||||
|
||||
const videoBuffer = Buffer.from(content, 'base64');
|
||||
const validation = await validateVideo(videoBuffer, videoBuffer.length, provider);
|
||||
|
||||
/** Extract configured file size limit from fileConfig for this endpoint */
|
||||
const configuredFileSizeLimit = getConfiguredFileSizeLimit(req, {
|
||||
provider,
|
||||
endpoint,
|
||||
});
|
||||
|
||||
const validation = await validateVideo(
|
||||
videoBuffer,
|
||||
videoBuffer.length,
|
||||
provider,
|
||||
configuredFileSizeLimit,
|
||||
);
|
||||
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Video validation failed: ${validation.error}`);
|
||||
|
|
|
|||
1313
packages/api/src/files/filter.spec.ts
Normal file
1313
packages/api/src/files/filter.spec.ts
Normal file
File diff suppressed because it is too large
Load diff
96
packages/api/src/files/filter.ts
Normal file
96
packages/api/src/files/filter.ts
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import { getEndpointFileConfig, mergeFileConfig, fileConfig } from 'librechat-data-provider';
|
||||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { ServerRequest } from '~/types';
|
||||
|
||||
/**
|
||||
* Checks if a MIME type is supported by the endpoint configuration
|
||||
* @param mimeType - The MIME type to check
|
||||
* @param supportedMimeTypes - Array of RegExp patterns to match against
|
||||
* @returns True if the MIME type matches any pattern
|
||||
*/
|
||||
function isMimeTypeSupported(mimeType: string, supportedMimeTypes?: RegExp[]): boolean {
|
||||
if (!supportedMimeTypes || supportedMimeTypes.length === 0) {
|
||||
return true;
|
||||
}
|
||||
return fileConfig.checkType(mimeType, supportedMimeTypes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Filters out files based on endpoint configuration including:
|
||||
* - Disabled status
|
||||
* - File size limits
|
||||
* - MIME type restrictions
|
||||
* - Total size limits
|
||||
* @param req - The server request object containing config
|
||||
* @param params - Object containing files, endpoint, and endpointType
|
||||
* @param params.files - Array of processed file documents from MongoDB
|
||||
* @param params.endpoint - The endpoint name to check configuration for
|
||||
* @param params.endpointType - The endpoint type to check configuration for
|
||||
* @returns Filtered array of files
|
||||
*/
|
||||
export function filterFilesByEndpointConfig(
|
||||
req: ServerRequest,
|
||||
params: {
|
||||
files: IMongoFile[] | undefined;
|
||||
endpoint?: string | null;
|
||||
endpointType?: string | null;
|
||||
},
|
||||
): IMongoFile[] {
|
||||
const { files, endpoint, endpointType } = params;
|
||||
|
||||
if (!files || files.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const mergedFileConfig = mergeFileConfig(req.config?.fileConfig);
|
||||
const endpointFileConfig = getEndpointFileConfig({
|
||||
fileConfig: mergedFileConfig,
|
||||
endpoint,
|
||||
endpointType,
|
||||
});
|
||||
|
||||
/**
|
||||
* If endpoint has files explicitly disabled, filter out all files
|
||||
* Only filter if disabled is explicitly set to true
|
||||
*/
|
||||
if (endpointFileConfig?.disabled === true) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const { fileSizeLimit, supportedMimeTypes, totalSizeLimit } = endpointFileConfig;
|
||||
|
||||
/** Filter files based on individual file size and MIME type */
|
||||
let filteredFiles = files;
|
||||
|
||||
/** Filter by individual file size limit */
|
||||
if (fileSizeLimit !== undefined && fileSizeLimit > 0) {
|
||||
filteredFiles = filteredFiles.filter((file) => {
|
||||
return file.bytes <= fileSizeLimit;
|
||||
});
|
||||
}
|
||||
|
||||
/** Filter by MIME type */
|
||||
if (supportedMimeTypes && supportedMimeTypes.length > 0) {
|
||||
filteredFiles = filteredFiles.filter((file) => {
|
||||
return isMimeTypeSupported(file.type, supportedMimeTypes);
|
||||
});
|
||||
}
|
||||
|
||||
/** Filter by total size limit - keep files until total exceeds limit */
|
||||
if (totalSizeLimit !== undefined && totalSizeLimit > 0) {
|
||||
let totalSize = 0;
|
||||
const withinTotalLimit: IMongoFile[] = [];
|
||||
|
||||
for (let i = 0; i < filteredFiles.length; i++) {
|
||||
const file = filteredFiles[i];
|
||||
if (totalSize + file.bytes <= totalSizeLimit) {
|
||||
withinTotalLimit.push(file);
|
||||
totalSize += file.bytes;
|
||||
}
|
||||
}
|
||||
|
||||
filteredFiles = withinTotalLimit;
|
||||
}
|
||||
|
||||
return filteredFiles;
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
export * from './audio';
|
||||
export * from './context';
|
||||
export * from './encode';
|
||||
export * from './filter';
|
||||
export * from './mistral/crud';
|
||||
export * from './ocr';
|
||||
export * from './parse';
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import FormData from 'form-data';
|
|||
import { createReadStream } from 'fs';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { FileSources } from 'librechat-data-provider';
|
||||
import type { Request as ServerRequest } from 'express';
|
||||
import type { ServerRequest } from '~/types';
|
||||
import { logAxiosError, readFileAsString } from '~/utils';
|
||||
import { generateShortLivedToken } from '~/crypto/jwt';
|
||||
|
||||
|
|
@ -20,9 +20,7 @@ export async function parseText({
|
|||
file,
|
||||
file_id,
|
||||
}: {
|
||||
req: Pick<ServerRequest, 'user'> & {
|
||||
user?: { id: string };
|
||||
};
|
||||
req: ServerRequest;
|
||||
file: Express.Multer.File;
|
||||
file_id: string;
|
||||
}): Promise<{ text: string; bytes: number; source: string }> {
|
||||
|
|
|
|||
558
packages/api/src/files/validation.spec.ts
Normal file
558
packages/api/src/files/validation.spec.ts
Normal file
|
|
@ -0,0 +1,558 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { mbToBytes } from 'librechat-data-provider';
|
||||
import { validatePdf, validateVideo, validateAudio } from './validation';
|
||||
|
||||
describe('PDF Validation with fileConfig.endpoints.*.fileSizeLimit', () => {
|
||||
/** Helper to create a PDF buffer with valid header */
|
||||
const createMockPdfBuffer = (sizeInMB: number): Buffer => {
|
||||
const bytes = Math.floor(sizeInMB * 1024 * 1024);
|
||||
const buffer = Buffer.alloc(bytes);
|
||||
buffer.write('%PDF-1.4\n', 0);
|
||||
return buffer;
|
||||
};
|
||||
|
||||
describe('validatePdf - OpenAI provider', () => {
|
||||
const provider = Providers.OPENAI;
|
||||
|
||||
it('should accept PDF within provider limit when no config provided', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(8);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject PDF exceeding provider limit when no config provided', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(12);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('12MB');
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
|
||||
it('should use configured limit when it is lower than provider limit', async () => {
|
||||
const configuredLimit = 5 * 1024 * 1024; // 5MB
|
||||
const pdfBuffer = createMockPdfBuffer(7); // Between configured and provider limit
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('7MB');
|
||||
expect(result.error).toContain('5MB');
|
||||
});
|
||||
|
||||
it('should allow configured limit higher than provider default', async () => {
|
||||
const configuredLimit = 50 * 1024 * 1024; // 50MB (higher than 10MB provider default)
|
||||
const pdfBuffer = createMockPdfBuffer(12); // Between provider default and configured limit
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should accept PDF within both configured and provider limits', async () => {
|
||||
const configuredLimit = 50 * 1024 * 1024; // 50MB
|
||||
const pdfBuffer = createMockPdfBuffer(8);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should accept PDF within lower configured limit', async () => {
|
||||
const configuredLimit = 5 * 1024 * 1024; // 5MB
|
||||
const pdfBuffer = createMockPdfBuffer(4);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle exact limit size correctly', async () => {
|
||||
const configuredLimit = 10 * 1024 * 1024; // Exactly 10MB
|
||||
const pdfBuffer = Buffer.alloc(10 * 1024 * 1024);
|
||||
pdfBuffer.write('%PDF-1.4\n', 0);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validatePdf - Anthropic provider', () => {
|
||||
const provider = Providers.ANTHROPIC;
|
||||
|
||||
it('should accept PDF within provider limit when no config provided', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(20);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject PDF exceeding provider limit when no config provided', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(35);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('35MB');
|
||||
expect(result.error).toContain('32MB');
|
||||
});
|
||||
|
||||
it('should use configured limit when it is lower than provider limit', async () => {
|
||||
const configuredLimit = mbToBytes(15); // 15MB
|
||||
const pdfBuffer = createMockPdfBuffer(20); // Between configured and provider limit
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('20MB');
|
||||
expect(result.error).toContain('15MB');
|
||||
});
|
||||
|
||||
it('should allow configured limit higher than provider default', async () => {
|
||||
const configuredLimit = mbToBytes(50); // 50MB (higher than 32MB provider default)
|
||||
const pdfBuffer = createMockPdfBuffer(35); // Between provider default and configured limit
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject encrypted PDFs regardless of size', async () => {
|
||||
const pdfBuffer = Buffer.alloc(1024);
|
||||
pdfBuffer.write('%PDF-1.4\n', 0);
|
||||
pdfBuffer.write('/Encrypt ', 100);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('encrypted');
|
||||
});
|
||||
|
||||
it('should reject PDFs with invalid header', async () => {
|
||||
const pdfBuffer = Buffer.alloc(1024);
|
||||
pdfBuffer.write('INVALID', 0);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('PDF header');
|
||||
});
|
||||
|
||||
it('should reject PDFs that are too small', async () => {
|
||||
const pdfBuffer = Buffer.alloc(3);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('too small');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validatePdf - Google provider', () => {
|
||||
const provider = Providers.GOOGLE;
|
||||
|
||||
it('should accept PDF within provider limit when no config provided', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(15);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject PDF exceeding provider limit when no config provided', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(25);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('25MB');
|
||||
expect(result.error).toContain('20MB');
|
||||
});
|
||||
|
||||
it('should use configured limit when it is lower than provider limit', async () => {
|
||||
const configuredLimit = 10 * 1024 * 1024; // 10MB
|
||||
const pdfBuffer = createMockPdfBuffer(15); // Between configured and provider limit
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('15MB');
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
|
||||
it('should allow configured limit higher than provider default', async () => {
|
||||
const configuredLimit = 50 * 1024 * 1024; // 50MB (higher than 20MB provider default)
|
||||
const pdfBuffer = createMockPdfBuffer(25); // Between provider default and configured limit
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('validatePdf - VertexAI provider', () => {
|
||||
const provider = Providers.VERTEXAI;
|
||||
|
||||
it('should accept PDF within provider limit', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(15);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
|
||||
it('should respect configured limit', async () => {
|
||||
const configuredLimit = 10 * 1024 * 1024;
|
||||
const pdfBuffer = createMockPdfBuffer(15);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validatePdf - Azure OpenAI provider', () => {
|
||||
const provider = Providers.AZURE;
|
||||
|
||||
it('should accept PDF within OpenAI-like provider limit', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(8);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
|
||||
it('should respect configured limit for Azure', async () => {
|
||||
const configuredLimit = 5 * 1024 * 1024;
|
||||
const pdfBuffer = createMockPdfBuffer(7);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider, configuredLimit);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validatePdf - Unsupported providers', () => {
|
||||
it('should return valid for providers without specific validation', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(100); // Very large file
|
||||
const provider = 'unsupported' as Providers;
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle zero-configured limit', async () => {
|
||||
const configuredLimit = 0;
|
||||
const pdfBuffer = createMockPdfBuffer(1);
|
||||
const result = await validatePdf(
|
||||
pdfBuffer,
|
||||
pdfBuffer.length,
|
||||
Providers.OPENAI,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('0MB');
|
||||
});
|
||||
|
||||
it('should handle very small PDF files', async () => {
|
||||
const pdfBuffer = Buffer.alloc(100);
|
||||
pdfBuffer.write('%PDF-1.4\n', 0);
|
||||
const result = await validatePdf(
|
||||
pdfBuffer,
|
||||
pdfBuffer.length,
|
||||
Providers.OPENAI,
|
||||
10 * 1024 * 1024,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle configured limit equal to provider limit', async () => {
|
||||
const configuredLimit = 10 * 1024 * 1024; // Same as OpenAI provider limit
|
||||
const pdfBuffer = createMockPdfBuffer(12);
|
||||
const result = await validatePdf(
|
||||
pdfBuffer,
|
||||
pdfBuffer.length,
|
||||
Providers.OPENAI,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
|
||||
it('should use provider limit when configured limit is undefined', async () => {
|
||||
const pdfBuffer = createMockPdfBuffer(12);
|
||||
const result = await validatePdf(pdfBuffer, pdfBuffer.length, Providers.OPENAI, undefined);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Bug reproduction - Original issue', () => {
|
||||
it('should reproduce the original bug scenario from issue description', async () => {
|
||||
/**
|
||||
* Original bug: User configures openAI.fileSizeLimit = 50MB in librechat.yaml
|
||||
* Uploads a 15MB PDF to OpenAI endpoint
|
||||
* Expected: Should be accepted (within 50MB config)
|
||||
* Actual (before fix): Rejected with "exceeds 10MB limit"
|
||||
*/
|
||||
const configuredLimit = mbToBytes(50); // User configured 50MB
|
||||
const pdfBuffer = createMockPdfBuffer(15); // User uploads 15MB file
|
||||
|
||||
const result = await validatePdf(
|
||||
pdfBuffer,
|
||||
pdfBuffer.length,
|
||||
Providers.OPENAI,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
/**
|
||||
* After fix: Should be accepted because configured limit (50MB) overrides
|
||||
* provider default (10MB), allowing for API changes
|
||||
*/
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should allow user to set stricter limits than provider', async () => {
|
||||
/**
|
||||
* Use case: User wants to enforce stricter limits than provider allows
|
||||
* User configures openAI.fileSizeLimit = 5MB
|
||||
* Uploads a 7MB PDF to OpenAI endpoint
|
||||
* Expected: Should be rejected (exceeds 5MB configured limit)
|
||||
*/
|
||||
const configuredLimit = mbToBytes(5); // User configured 5MB
|
||||
const pdfBuffer = createMockPdfBuffer(7); // User uploads 7MB file
|
||||
|
||||
const result = await validatePdf(
|
||||
pdfBuffer,
|
||||
pdfBuffer.length,
|
||||
Providers.OPENAI,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('7MB');
|
||||
expect(result.error).toContain('5MB');
|
||||
});
|
||||
|
||||
it('should allow upload within stricter user-configured limit', async () => {
|
||||
/**
|
||||
* User configures openAI.fileSizeLimit = 5MB
|
||||
* Uploads a 4MB PDF
|
||||
* Expected: Should be accepted
|
||||
*/
|
||||
const configuredLimit = mbToBytes(5);
|
||||
const pdfBuffer = createMockPdfBuffer(4);
|
||||
|
||||
const result = await validatePdf(
|
||||
pdfBuffer,
|
||||
pdfBuffer.length,
|
||||
Providers.OPENAI,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Video and Audio Validation with fileConfig', () => {
|
||||
/** Helper to create a mock video/audio buffer */
|
||||
const createMockMediaBuffer = (sizeInMB: number): Buffer => {
|
||||
const bytes = Math.floor(sizeInMB * 1024 * 1024);
|
||||
return Buffer.alloc(bytes);
|
||||
};
|
||||
|
||||
describe('validateVideo - Google provider', () => {
|
||||
const provider = Providers.GOOGLE;
|
||||
|
||||
it('should accept video within provider limit when no config provided', async () => {
|
||||
const videoBuffer = createMockMediaBuffer(15);
|
||||
const result = await validateVideo(videoBuffer, videoBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject video exceeding provider limit when no config provided', async () => {
|
||||
const videoBuffer = createMockMediaBuffer(25);
|
||||
const result = await validateVideo(videoBuffer, videoBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('25MB');
|
||||
expect(result.error).toContain('20MB');
|
||||
});
|
||||
|
||||
it('should use configured limit when it is lower than provider limit', async () => {
|
||||
const configuredLimit = mbToBytes(10); // 10MB
|
||||
const videoBuffer = createMockMediaBuffer(15); // Between configured and provider limit
|
||||
const result = await validateVideo(
|
||||
videoBuffer,
|
||||
videoBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('15MB');
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
|
||||
it('should allow configured limit higher than provider default', async () => {
|
||||
const configuredLimit = mbToBytes(50); // 50MB (higher than 20MB provider default)
|
||||
const videoBuffer = createMockMediaBuffer(25); // Between provider default and configured limit
|
||||
const result = await validateVideo(
|
||||
videoBuffer,
|
||||
videoBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should accept video within lower configured limit', async () => {
|
||||
const configuredLimit = mbToBytes(8);
|
||||
const videoBuffer = createMockMediaBuffer(7);
|
||||
const result = await validateVideo(
|
||||
videoBuffer,
|
||||
videoBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject videos that are too small', async () => {
|
||||
const videoBuffer = Buffer.alloc(5);
|
||||
const result = await validateVideo(videoBuffer, videoBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('too small');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateAudio - Google provider', () => {
|
||||
const provider = Providers.GOOGLE;
|
||||
|
||||
it('should accept audio within provider limit when no config provided', async () => {
|
||||
const audioBuffer = createMockMediaBuffer(15);
|
||||
const result = await validateAudio(audioBuffer, audioBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject audio exceeding provider limit when no config provided', async () => {
|
||||
const audioBuffer = createMockMediaBuffer(25);
|
||||
const result = await validateAudio(audioBuffer, audioBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('25MB');
|
||||
expect(result.error).toContain('20MB');
|
||||
});
|
||||
|
||||
it('should use configured limit when it is lower than provider limit', async () => {
|
||||
const configuredLimit = mbToBytes(10); // 10MB
|
||||
const audioBuffer = createMockMediaBuffer(15); // Between configured and provider limit
|
||||
const result = await validateAudio(
|
||||
audioBuffer,
|
||||
audioBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('15MB');
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
|
||||
it('should allow configured limit higher than provider default', async () => {
|
||||
const configuredLimit = mbToBytes(50); // 50MB (higher than 20MB provider default)
|
||||
const audioBuffer = createMockMediaBuffer(25); // Between provider default and configured limit
|
||||
const result = await validateAudio(
|
||||
audioBuffer,
|
||||
audioBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should accept audio within lower configured limit', async () => {
|
||||
const configuredLimit = mbToBytes(8);
|
||||
const audioBuffer = createMockMediaBuffer(7);
|
||||
const result = await validateAudio(
|
||||
audioBuffer,
|
||||
audioBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
expect(result.error).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should reject audio files that are too small', async () => {
|
||||
const audioBuffer = Buffer.alloc(5);
|
||||
const result = await validateAudio(audioBuffer, audioBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('too small');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateVideo and validateAudio - VertexAI provider', () => {
|
||||
const provider = Providers.VERTEXAI;
|
||||
|
||||
it('should respect configured video limit for VertexAI', async () => {
|
||||
const configuredLimit = mbToBytes(10);
|
||||
const videoBuffer = createMockMediaBuffer(15);
|
||||
const result = await validateVideo(
|
||||
videoBuffer,
|
||||
videoBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
|
||||
it('should respect configured audio limit for VertexAI', async () => {
|
||||
const configuredLimit = mbToBytes(10);
|
||||
const audioBuffer = createMockMediaBuffer(15);
|
||||
const result = await validateAudio(
|
||||
audioBuffer,
|
||||
audioBuffer.length,
|
||||
provider,
|
||||
configuredLimit,
|
||||
);
|
||||
|
||||
expect(result.isValid).toBe(false);
|
||||
expect(result.error).toContain('10MB');
|
||||
});
|
||||
});
|
||||
|
||||
describe('validateVideo and validateAudio - Unsupported providers', () => {
|
||||
it('should return valid for video from unsupported provider', async () => {
|
||||
const videoBuffer = createMockMediaBuffer(100);
|
||||
const provider = Providers.OPENAI;
|
||||
const result = await validateVideo(videoBuffer, videoBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
|
||||
it('should return valid for audio from unsupported provider', async () => {
|
||||
const audioBuffer = createMockMediaBuffer(100);
|
||||
const provider = Providers.OPENAI;
|
||||
const result = await validateAudio(audioBuffer, audioBuffer.length, provider);
|
||||
|
||||
expect(result.isValid).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -16,21 +16,27 @@ export interface AudioValidationResult {
|
|||
error?: string;
|
||||
}
|
||||
|
||||
export interface ImageValidationResult {
|
||||
isValid: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export async function validatePdf(
|
||||
pdfBuffer: Buffer,
|
||||
fileSize: number,
|
||||
provider: Providers,
|
||||
configuredFileSizeLimit?: number,
|
||||
): Promise<PDFValidationResult> {
|
||||
if (provider === Providers.ANTHROPIC) {
|
||||
return validateAnthropicPdf(pdfBuffer, fileSize);
|
||||
return validateAnthropicPdf(pdfBuffer, fileSize, configuredFileSizeLimit);
|
||||
}
|
||||
|
||||
if (isOpenAILikeProvider(provider)) {
|
||||
return validateOpenAIPdf(fileSize);
|
||||
return validateOpenAIPdf(fileSize, configuredFileSizeLimit);
|
||||
}
|
||||
|
||||
if (provider === Providers.GOOGLE || provider === Providers.VERTEXAI) {
|
||||
return validateGooglePdf(fileSize);
|
||||
return validateGooglePdf(fileSize, configuredFileSizeLimit);
|
||||
}
|
||||
|
||||
return { isValid: true };
|
||||
|
|
@ -40,17 +46,23 @@ export async function validatePdf(
|
|||
* Validates if a PDF meets Anthropic's requirements
|
||||
* @param pdfBuffer - The PDF file as a buffer
|
||||
* @param fileSize - The file size in bytes
|
||||
* @param configuredFileSizeLimit - Optional configured file size limit from fileConfig (in bytes)
|
||||
* @returns Promise that resolves to validation result
|
||||
*/
|
||||
async function validateAnthropicPdf(
|
||||
pdfBuffer: Buffer,
|
||||
fileSize: number,
|
||||
configuredFileSizeLimit?: number,
|
||||
): Promise<PDFValidationResult> {
|
||||
try {
|
||||
if (fileSize > mbToBytes(32)) {
|
||||
const providerLimit = mbToBytes(32);
|
||||
const effectiveLimit = configuredFileSizeLimit ?? providerLimit;
|
||||
|
||||
if (fileSize > effectiveLimit) {
|
||||
const limitMB = Math.round(effectiveLimit / (1024 * 1024));
|
||||
return {
|
||||
isValid: false,
|
||||
error: `PDF file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds Anthropic's 32MB limit`,
|
||||
error: `PDF file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds the ${limitMB}MB limit`,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -101,22 +113,48 @@ async function validateAnthropicPdf(
|
|||
}
|
||||
}
|
||||
|
||||
async function validateOpenAIPdf(fileSize: number): Promise<PDFValidationResult> {
|
||||
if (fileSize > 10 * 1024 * 1024) {
|
||||
/**
|
||||
* Validates if a PDF meets OpenAI's requirements
|
||||
* @param fileSize - The file size in bytes
|
||||
* @param configuredFileSizeLimit - Optional configured file size limit from fileConfig (in bytes)
|
||||
* @returns Promise that resolves to validation result
|
||||
*/
|
||||
async function validateOpenAIPdf(
|
||||
fileSize: number,
|
||||
configuredFileSizeLimit?: number,
|
||||
): Promise<PDFValidationResult> {
|
||||
const providerLimit = mbToBytes(10);
|
||||
const effectiveLimit = configuredFileSizeLimit ?? providerLimit;
|
||||
|
||||
if (fileSize > effectiveLimit) {
|
||||
const limitMB = Math.round(effectiveLimit / (1024 * 1024));
|
||||
return {
|
||||
isValid: false,
|
||||
error: "PDF file size exceeds OpenAI's 10MB limit",
|
||||
error: `PDF file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds the ${limitMB}MB limit`,
|
||||
};
|
||||
}
|
||||
|
||||
return { isValid: true };
|
||||
}
|
||||
|
||||
async function validateGooglePdf(fileSize: number): Promise<PDFValidationResult> {
|
||||
if (fileSize > 20 * 1024 * 1024) {
|
||||
/**
|
||||
* Validates if a PDF meets Google's requirements
|
||||
* @param fileSize - The file size in bytes
|
||||
* @param configuredFileSizeLimit - Optional configured file size limit from fileConfig (in bytes)
|
||||
* @returns Promise that resolves to validation result
|
||||
*/
|
||||
async function validateGooglePdf(
|
||||
fileSize: number,
|
||||
configuredFileSizeLimit?: number,
|
||||
): Promise<PDFValidationResult> {
|
||||
const providerLimit = mbToBytes(20);
|
||||
const effectiveLimit = configuredFileSizeLimit ?? providerLimit;
|
||||
|
||||
if (fileSize > effectiveLimit) {
|
||||
const limitMB = Math.round(effectiveLimit / (1024 * 1024));
|
||||
return {
|
||||
isValid: false,
|
||||
error: "PDF file size exceeds Google's 20MB limit",
|
||||
error: `PDF file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds the ${limitMB}MB limit`,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -128,18 +166,24 @@ async function validateGooglePdf(fileSize: number): Promise<PDFValidationResult>
|
|||
* @param videoBuffer - The video file as a buffer
|
||||
* @param fileSize - The file size in bytes
|
||||
* @param provider - The provider to validate for
|
||||
* @param configuredFileSizeLimit - Optional configured file size limit from fileConfig (in bytes)
|
||||
* @returns Promise that resolves to validation result
|
||||
*/
|
||||
export async function validateVideo(
|
||||
videoBuffer: Buffer,
|
||||
fileSize: number,
|
||||
provider: Providers,
|
||||
configuredFileSizeLimit?: number,
|
||||
): Promise<VideoValidationResult> {
|
||||
if (provider === Providers.GOOGLE || provider === Providers.VERTEXAI) {
|
||||
if (fileSize > 20 * 1024 * 1024) {
|
||||
const providerLimit = mbToBytes(20);
|
||||
const effectiveLimit = configuredFileSizeLimit ?? providerLimit;
|
||||
|
||||
if (fileSize > effectiveLimit) {
|
||||
const limitMB = Math.round(effectiveLimit / (1024 * 1024));
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Video file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds Google's 20MB limit`,
|
||||
error: `Video file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds the ${limitMB}MB limit`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
@ -159,18 +203,24 @@ export async function validateVideo(
|
|||
* @param audioBuffer - The audio file as a buffer
|
||||
* @param fileSize - The file size in bytes
|
||||
* @param provider - The provider to validate for
|
||||
* @param configuredFileSizeLimit - Optional configured file size limit from fileConfig (in bytes)
|
||||
* @returns Promise that resolves to validation result
|
||||
*/
|
||||
export async function validateAudio(
|
||||
audioBuffer: Buffer,
|
||||
fileSize: number,
|
||||
provider: Providers,
|
||||
configuredFileSizeLimit?: number,
|
||||
): Promise<AudioValidationResult> {
|
||||
if (provider === Providers.GOOGLE || provider === Providers.VERTEXAI) {
|
||||
if (fileSize > 20 * 1024 * 1024) {
|
||||
const providerLimit = mbToBytes(20);
|
||||
const effectiveLimit = configuredFileSizeLimit ?? providerLimit;
|
||||
|
||||
if (fileSize > effectiveLimit) {
|
||||
const limitMB = Math.round(effectiveLimit / (1024 * 1024));
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Audio file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds Google's 20MB limit`,
|
||||
error: `Audio file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds the ${limitMB}MB limit`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
@ -184,3 +234,53 @@ export async function validateAudio(
|
|||
|
||||
return { isValid: true };
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates image files for different providers
|
||||
* @param imageBuffer - The image file as a buffer
|
||||
* @param fileSize - The file size in bytes
|
||||
* @param provider - The provider to validate for
|
||||
* @param configuredFileSizeLimit - Optional configured file size limit from fileConfig (in bytes)
|
||||
* @returns Promise that resolves to validation result
|
||||
*/
|
||||
export async function validateImage(
|
||||
imageBuffer: Buffer,
|
||||
fileSize: number,
|
||||
provider: Providers | string,
|
||||
configuredFileSizeLimit?: number,
|
||||
): Promise<ImageValidationResult> {
|
||||
if (provider === Providers.GOOGLE || provider === Providers.VERTEXAI) {
|
||||
const providerLimit = mbToBytes(20);
|
||||
const effectiveLimit = configuredFileSizeLimit ?? providerLimit;
|
||||
|
||||
if (fileSize > effectiveLimit) {
|
||||
const limitMB = Math.round(effectiveLimit / (1024 * 1024));
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Image file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds the ${limitMB}MB limit`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (provider === Providers.ANTHROPIC) {
|
||||
const providerLimit = mbToBytes(5);
|
||||
const effectiveLimit = configuredFileSizeLimit ?? providerLimit;
|
||||
|
||||
if (fileSize > effectiveLimit) {
|
||||
const limitMB = Math.round(effectiveLimit / (1024 * 1024));
|
||||
return {
|
||||
isValid: false,
|
||||
error: `Image file size (${Math.round(fileSize / (1024 * 1024))}MB) exceeds the ${limitMB}MB limit`,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (!imageBuffer || imageBuffer.length < 10) {
|
||||
return {
|
||||
isValid: false,
|
||||
error: 'Invalid image file: too small or corrupted',
|
||||
};
|
||||
}
|
||||
|
||||
return { isValid: true };
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import { Keyv } from 'keyv';
|
||||
import { FlowStateManager } from './manager';
|
||||
import type { FlowState } from './types';
|
||||
import { FlowState } from './types';
|
||||
|
||||
/** Mock class without extending Keyv */
|
||||
class MockKeyv {
|
||||
|
|
@ -181,4 +181,214 @@ describe('FlowStateManager', () => {
|
|||
expect(result).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFlowStale', () => {
|
||||
const flowId = 'test-flow-stale';
|
||||
const type = 'test-type';
|
||||
const flowKey = `${type}:${flowId}`;
|
||||
|
||||
it('returns not stale for non-existent flow', async () => {
|
||||
const result = await flowManager.isFlowStale(flowId, type);
|
||||
|
||||
expect(result).toEqual({
|
||||
isStale: false,
|
||||
age: 0,
|
||||
});
|
||||
});
|
||||
|
||||
it('returns not stale for PENDING flow regardless of age', async () => {
|
||||
const oldTimestamp = Date.now() - 10 * 60 * 1000; // 10 minutes ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'PENDING',
|
||||
metadata: {},
|
||||
createdAt: oldTimestamp,
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
expect(result).toEqual({
|
||||
isStale: false,
|
||||
age: 0,
|
||||
status: 'PENDING',
|
||||
});
|
||||
});
|
||||
|
||||
it('returns not stale for recently COMPLETED flow', async () => {
|
||||
const recentTimestamp = Date.now() - 30 * 1000; // 30 seconds ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'COMPLETED',
|
||||
metadata: {},
|
||||
createdAt: Date.now() - 60 * 1000,
|
||||
completedAt: recentTimestamp,
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
expect(result.isStale).toBe(false);
|
||||
expect(result.status).toBe('COMPLETED');
|
||||
expect(result.age).toBeGreaterThan(0);
|
||||
expect(result.age).toBeLessThan(60 * 1000);
|
||||
});
|
||||
|
||||
it('returns stale for old COMPLETED flow', async () => {
|
||||
const oldTimestamp = Date.now() - 5 * 60 * 1000; // 5 minutes ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'COMPLETED',
|
||||
metadata: {},
|
||||
createdAt: Date.now() - 10 * 60 * 1000,
|
||||
completedAt: oldTimestamp,
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
expect(result.isStale).toBe(true);
|
||||
expect(result.status).toBe('COMPLETED');
|
||||
expect(result.age).toBeGreaterThan(2 * 60 * 1000);
|
||||
});
|
||||
|
||||
it('returns not stale for recently FAILED flow', async () => {
|
||||
const recentTimestamp = Date.now() - 30 * 1000; // 30 seconds ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'FAILED',
|
||||
metadata: {},
|
||||
createdAt: Date.now() - 60 * 1000,
|
||||
failedAt: recentTimestamp,
|
||||
error: 'Test error',
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
expect(result.isStale).toBe(false);
|
||||
expect(result.status).toBe('FAILED');
|
||||
expect(result.age).toBeGreaterThan(0);
|
||||
expect(result.age).toBeLessThan(60 * 1000);
|
||||
});
|
||||
|
||||
it('returns stale for old FAILED flow', async () => {
|
||||
const oldTimestamp = Date.now() - 5 * 60 * 1000; // 5 minutes ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'FAILED',
|
||||
metadata: {},
|
||||
createdAt: Date.now() - 10 * 60 * 1000,
|
||||
failedAt: oldTimestamp,
|
||||
error: 'Test error',
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
expect(result.isStale).toBe(true);
|
||||
expect(result.status).toBe('FAILED');
|
||||
expect(result.age).toBeGreaterThan(2 * 60 * 1000);
|
||||
});
|
||||
|
||||
it('uses custom stale threshold', async () => {
|
||||
const timestamp = Date.now() - 90 * 1000; // 90 seconds ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'COMPLETED',
|
||||
metadata: {},
|
||||
createdAt: Date.now() - 2 * 60 * 1000,
|
||||
completedAt: timestamp,
|
||||
});
|
||||
|
||||
// 90 seconds old, threshold 60 seconds = stale
|
||||
const result1 = await flowManager.isFlowStale(flowId, type, 60 * 1000);
|
||||
expect(result1.isStale).toBe(true);
|
||||
|
||||
// 90 seconds old, threshold 120 seconds = not stale
|
||||
const result2 = await flowManager.isFlowStale(flowId, type, 120 * 1000);
|
||||
expect(result2.isStale).toBe(false);
|
||||
});
|
||||
|
||||
it('uses default threshold of 2 minutes when not specified', async () => {
|
||||
const timestamp = Date.now() - 3 * 60 * 1000; // 3 minutes ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'COMPLETED',
|
||||
metadata: {},
|
||||
createdAt: Date.now() - 5 * 60 * 1000,
|
||||
completedAt: timestamp,
|
||||
});
|
||||
|
||||
// Should use default 2 minute threshold
|
||||
const result = await flowManager.isFlowStale(flowId, type);
|
||||
|
||||
expect(result.isStale).toBe(true);
|
||||
expect(result.age).toBeGreaterThan(2 * 60 * 1000);
|
||||
});
|
||||
|
||||
it('falls back to createdAt when completedAt/failedAt are not present', async () => {
|
||||
const createdTimestamp = Date.now() - 5 * 60 * 1000; // 5 minutes ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'COMPLETED',
|
||||
metadata: {},
|
||||
createdAt: createdTimestamp,
|
||||
// No completedAt or failedAt
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
expect(result.isStale).toBe(true);
|
||||
expect(result.status).toBe('COMPLETED');
|
||||
expect(result.age).toBeGreaterThan(2 * 60 * 1000);
|
||||
});
|
||||
|
||||
it('handles flow with no timestamps', async () => {
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'COMPLETED',
|
||||
metadata: {},
|
||||
// No timestamps at all
|
||||
} as FlowState<string>);
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
expect(result.isStale).toBe(false);
|
||||
expect(result.age).toBe(0);
|
||||
expect(result.status).toBe('COMPLETED');
|
||||
});
|
||||
|
||||
it('prefers completedAt over createdAt for age calculation', async () => {
|
||||
const createdTimestamp = Date.now() - 10 * 60 * 1000; // 10 minutes ago
|
||||
const completedTimestamp = Date.now() - 30 * 1000; // 30 seconds ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'COMPLETED',
|
||||
metadata: {},
|
||||
createdAt: createdTimestamp,
|
||||
completedAt: completedTimestamp,
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
// Should use completedAt (30s) not createdAt (10m)
|
||||
expect(result.isStale).toBe(false);
|
||||
expect(result.age).toBeLessThan(60 * 1000);
|
||||
});
|
||||
|
||||
it('prefers failedAt over createdAt for age calculation', async () => {
|
||||
const createdTimestamp = Date.now() - 10 * 60 * 1000; // 10 minutes ago
|
||||
const failedTimestamp = Date.now() - 30 * 1000; // 30 seconds ago
|
||||
await store.set(flowKey, {
|
||||
type,
|
||||
status: 'FAILED',
|
||||
metadata: {},
|
||||
createdAt: createdTimestamp,
|
||||
failedAt: failedTimestamp,
|
||||
error: 'Test error',
|
||||
});
|
||||
|
||||
const result = await flowManager.isFlowStale(flowId, type, 2 * 60 * 1000);
|
||||
|
||||
// Should use failedAt (30s) not createdAt (10m)
|
||||
expect(result.isStale).toBe(false);
|
||||
expect(result.age).toBeLessThan(60 * 1000);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -151,9 +151,25 @@ export class FlowStateManager<T = unknown> {
|
|||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
||||
if (!flowState) {
|
||||
logger.warn('[FlowStateManager] Cannot complete flow - flow state not found', {
|
||||
flowId,
|
||||
type,
|
||||
});
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Prevent duplicate completion */
|
||||
if (flowState.status === 'COMPLETED') {
|
||||
logger.debug(
|
||||
'[FlowStateManager] Flow already completed, skipping to prevent duplicate completion',
|
||||
{
|
||||
flowId,
|
||||
type,
|
||||
},
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
const updatedState: FlowState<T> = {
|
||||
...flowState,
|
||||
status: 'COMPLETED',
|
||||
|
|
@ -162,9 +178,55 @@ export class FlowStateManager<T = unknown> {
|
|||
};
|
||||
|
||||
await this.keyv.set(flowKey, updatedState, this.ttl);
|
||||
|
||||
logger.debug('[FlowStateManager] Flow completed successfully', {
|
||||
flowId,
|
||||
type,
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a flow is stale based on its age and status
|
||||
* @param flowId - The flow identifier
|
||||
* @param type - The flow type
|
||||
* @param staleThresholdMs - Age in milliseconds after which a non-pending flow is considered stale (default: 2 minutes)
|
||||
* @returns Object with isStale boolean and age in milliseconds
|
||||
*/
|
||||
async isFlowStale(
|
||||
flowId: string,
|
||||
type: string,
|
||||
staleThresholdMs: number = 2 * 60 * 1000,
|
||||
): Promise<{ isStale: boolean; age: number; status?: string }> {
|
||||
const flowKey = this.getFlowKey(flowId, type);
|
||||
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
|
||||
|
||||
if (!flowState) {
|
||||
return { isStale: false, age: 0 };
|
||||
}
|
||||
|
||||
if (flowState.status === 'PENDING') {
|
||||
return { isStale: false, age: 0, status: flowState.status };
|
||||
}
|
||||
|
||||
const completedAt = flowState.completedAt || flowState.failedAt;
|
||||
const createdAt = flowState.createdAt;
|
||||
|
||||
let flowAge = 0;
|
||||
if (completedAt) {
|
||||
flowAge = Date.now() - completedAt;
|
||||
} else if (createdAt) {
|
||||
flowAge = Date.now() - createdAt;
|
||||
}
|
||||
|
||||
return {
|
||||
isStale: flowAge > staleThresholdMs,
|
||||
age: flowAge,
|
||||
status: flowState.status,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Marks a flow as failed
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -1,340 +0,0 @@
|
|||
import { ContentTypes } from 'librechat-data-provider';
|
||||
import { HumanMessage, AIMessage, SystemMessage } from '@langchain/core/messages';
|
||||
import { formatContentStrings } from './content';
|
||||
|
||||
describe('formatContentStrings', () => {
|
||||
describe('Human messages', () => {
|
||||
it('should convert human message with all text blocks to string', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'World' },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('Hello\nWorld');
|
||||
});
|
||||
|
||||
it('should not convert human message with mixed content types (text + image)', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, text: 'what do you see' },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: '_SOME_BASE64_DATA=',
|
||||
detail: 'auto',
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toEqual([
|
||||
{ type: ContentTypes.TEXT, text: 'what do you see' },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: '_SOME_BASE64_DATA=',
|
||||
detail: 'auto',
|
||||
},
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should leave string content unchanged', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: 'Hello World',
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('Hello World');
|
||||
});
|
||||
|
||||
it('should handle empty text blocks', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: '' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'World' },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('Hello\n\nWorld');
|
||||
});
|
||||
|
||||
it('should handle null/undefined text values', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: null },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: undefined },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'World' },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('Hello\n\n\nWorld');
|
||||
});
|
||||
});
|
||||
|
||||
describe('AI messages', () => {
|
||||
it('should convert AI message with all text blocks to string', () => {
|
||||
const messages = [
|
||||
new AIMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Hello' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'World' },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('Hello\nWorld');
|
||||
expect(result[0].getType()).toBe('ai');
|
||||
});
|
||||
|
||||
it('should not convert AI message with mixed content types', () => {
|
||||
const messages = [
|
||||
new AIMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Here is an image' },
|
||||
{ type: ContentTypes.TOOL_CALL, tool_call: { name: 'generate_image' } },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toEqual([
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Here is an image' },
|
||||
{ type: ContentTypes.TOOL_CALL, tool_call: { name: 'generate_image' } },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('System messages', () => {
|
||||
it('should convert System message with all text blocks to string', () => {
|
||||
const messages = [
|
||||
new SystemMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'System' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Message' },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('System\nMessage');
|
||||
expect(result[0].getType()).toBe('system');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Mixed message types', () => {
|
||||
it('should process all valid message types in mixed array', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Human' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Message' },
|
||||
],
|
||||
}),
|
||||
new AIMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'AI' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Response' },
|
||||
],
|
||||
}),
|
||||
new SystemMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'System' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Prompt' },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(3);
|
||||
// All messages should be converted
|
||||
expect(result[0].content).toBe('Human\nMessage');
|
||||
expect(result[0].getType()).toBe('human');
|
||||
|
||||
expect(result[1].content).toBe('AI\nResponse');
|
||||
expect(result[1].getType()).toBe('ai');
|
||||
|
||||
expect(result[2].content).toBe('System\nPrompt');
|
||||
expect(result[2].getType()).toBe('system');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge cases', () => {
|
||||
it('should handle empty array', () => {
|
||||
const result = formatContentStrings([]);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle messages with non-array content', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: 'This is a string content',
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('This is a string content');
|
||||
});
|
||||
|
||||
it('should trim the final concatenated string', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: ' Hello ' },
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: ' World ' },
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].content).toBe('Hello \n World');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Real-world scenarios', () => {
|
||||
it('should handle the exact scenario from the issue', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'hi there',
|
||||
},
|
||||
],
|
||||
}),
|
||||
new AIMessage({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Hi Danny! How can I help you today?',
|
||||
},
|
||||
],
|
||||
}),
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'what do you see',
|
||||
},
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: '_SOME_BASE64_DATA=',
|
||||
detail: 'auto',
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(3);
|
||||
|
||||
// First human message (all text) should be converted
|
||||
expect(result[0].content).toBe('hi there');
|
||||
expect(result[0].getType()).toBe('human');
|
||||
|
||||
// AI message (all text) should now also be converted
|
||||
expect(result[1].content).toBe('Hi Danny! How can I help you today?');
|
||||
expect(result[1].getType()).toBe('ai');
|
||||
|
||||
// Third message (mixed content) should remain unchanged
|
||||
expect(result[2].content).toEqual([
|
||||
{
|
||||
type: 'text',
|
||||
text: 'what do you see',
|
||||
},
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: '_SOME_BASE64_DATA=',
|
||||
detail: 'auto',
|
||||
},
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle messages with tool calls', () => {
|
||||
const messages = [
|
||||
new HumanMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Please use the calculator' },
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: { name: 'calculator', args: '{"a": 1, "b": 2}' },
|
||||
},
|
||||
],
|
||||
}),
|
||||
new AIMessage({
|
||||
content: [
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'I will calculate that for you' },
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: { name: 'calculator', args: '{"a": 1, "b": 2}' },
|
||||
},
|
||||
],
|
||||
}),
|
||||
];
|
||||
|
||||
const result = formatContentStrings(messages);
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
// Should not convert because not all blocks are text
|
||||
expect(result[0].content).toEqual([
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'Please use the calculator' },
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: { name: 'calculator', args: '{"a": 1, "b": 2}' },
|
||||
},
|
||||
]);
|
||||
expect(result[1].content).toEqual([
|
||||
{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: 'I will calculate that for you' },
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: { name: 'calculator', args: '{"a": 1, "b": 2}' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
import { ContentTypes } from 'librechat-data-provider';
|
||||
import type { BaseMessage } from '@langchain/core/messages';
|
||||
|
||||
/**
|
||||
* Formats an array of messages for LangChain, making sure all content fields are strings
|
||||
* @param {Array<HumanMessage | AIMessage | SystemMessage | ToolMessage>} payload - The array of messages to format.
|
||||
* @returns {Array<HumanMessage | AIMessage | SystemMessage | ToolMessage>} - The array of formatted LangChain messages, including ToolMessages for tool calls.
|
||||
*/
|
||||
export const formatContentStrings = (payload: Array<BaseMessage>): Array<BaseMessage> => {
|
||||
// Create a new array to store the processed messages
|
||||
const result: Array<BaseMessage> = [];
|
||||
|
||||
for (const message of payload) {
|
||||
const messageType = message.getType();
|
||||
const isValidMessage =
|
||||
messageType === 'human' || messageType === 'ai' || messageType === 'system';
|
||||
|
||||
if (!isValidMessage) {
|
||||
result.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If content is already a string, add as-is
|
||||
if (typeof message.content === 'string') {
|
||||
result.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
// If content is not an array, add as-is
|
||||
if (!Array.isArray(message.content)) {
|
||||
result.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if all content blocks are text type
|
||||
const allTextBlocks = message.content.every((block) => block.type === ContentTypes.TEXT);
|
||||
|
||||
// Only convert to string if all blocks are text type
|
||||
if (!allTextBlocks) {
|
||||
result.push(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Reduce text types to a single string
|
||||
const content = message.content.reduce((acc, curr) => {
|
||||
if (curr.type === ContentTypes.TEXT) {
|
||||
return `${acc}${curr[ContentTypes.TEXT] || ''}\n`;
|
||||
}
|
||||
return acc;
|
||||
}, '');
|
||||
|
||||
message.content = content.trim();
|
||||
result.push(message);
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
|
@ -1 +0,0 @@
|
|||
export * from './content';
|
||||
|
|
@ -3,13 +3,13 @@ export * from './cdn';
|
|||
/* Auth */
|
||||
export * from './auth';
|
||||
/* MCP */
|
||||
export * from './mcp/registry/MCPServersRegistry';
|
||||
export * from './mcp/MCPManager';
|
||||
export * from './mcp/connection';
|
||||
export * from './mcp/oauth';
|
||||
export * from './mcp/auth';
|
||||
export * from './mcp/zod';
|
||||
/* Utilities */
|
||||
export * from './format';
|
||||
export * from './mcp/utils';
|
||||
export * from './utils';
|
||||
export * from './db/utils';
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import type { FlowMetadata } from '~/flow/types';
|
|||
import type * as t from './types';
|
||||
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils';
|
||||
|
||||
|
|
@ -231,14 +232,11 @@ export class MCPConnectionFactory {
|
|||
/** Attempts to establish connection with timeout handling */
|
||||
protected async attemptToConnect(connection: MCPConnection): Promise<void> {
|
||||
const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
|
||||
const connectionTimeout = new Promise<void>((_, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||
connectTimeout,
|
||||
),
|
||||
await withTimeout(
|
||||
this.connectTo(connection),
|
||||
connectTimeout,
|
||||
`Connection timeout after ${connectTimeout}ms`,
|
||||
);
|
||||
const connectionAttempt = this.connectTo(connection);
|
||||
await Promise.race([connectionAttempt, connectionTimeout]);
|
||||
|
||||
if (await connection.isConnected()) return;
|
||||
logger.error(`${this.logPrefix} Failed to establish connection.`);
|
||||
|
|
@ -331,6 +329,7 @@ export class MCPConnectionFactory {
|
|||
|
||||
/** Check if there's already an ongoing OAuth flow for this flowId */
|
||||
const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');
|
||||
|
||||
if (existingFlow && existingFlow.status === 'PENDING') {
|
||||
logger.debug(
|
||||
`${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`,
|
||||
|
|
@ -351,6 +350,38 @@ export class MCPConnectionFactory {
|
|||
return { tokens, clientInfo };
|
||||
}
|
||||
|
||||
// Clean up old completed/failed flows, but only if they're actually stale
|
||||
// This prevents race conditions where we delete a flow that's still being processed
|
||||
if (existingFlow && existingFlow.status !== 'PENDING') {
|
||||
const STALE_FLOW_THRESHOLD = 2 * 60 * 1000; // 2 minutes
|
||||
const { isStale, age, status } = await this.flowManager.isFlowStale(
|
||||
flowId,
|
||||
'mcp_oauth',
|
||||
STALE_FLOW_THRESHOLD,
|
||||
);
|
||||
|
||||
if (isStale) {
|
||||
try {
|
||||
await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
|
||||
logger.debug(
|
||||
`${this.logPrefix} Cleared stale ${status} OAuth flow (age: ${Math.round(age / 1000)}s)`,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.warn(`${this.logPrefix} Failed to clear stale OAuth flow`, error);
|
||||
}
|
||||
} else {
|
||||
logger.debug(
|
||||
`${this.logPrefix} Skipping cleanup of recent ${status} flow (age: ${Math.round(age / 1000)}s, threshold: ${STALE_FLOW_THRESHOLD / 1000}s)`,
|
||||
);
|
||||
// If flow is recent but not pending, something might be wrong
|
||||
if (status === 'FAILED') {
|
||||
logger.warn(
|
||||
`${this.logPrefix} Recent OAuth flow failed, will retry after ${Math.round((STALE_FLOW_THRESHOLD - age) / 1000)}s`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
|
||||
const {
|
||||
authorizationUrl,
|
||||
|
|
|
|||
|
|
@ -5,11 +5,14 @@ import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.j
|
|||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { TUser } from 'librechat-data-provider';
|
||||
import type { MCPOAuthTokens } from '~/mcp/oauth';
|
||||
import type { MCPOAuthTokens } from './oauth';
|
||||
import type { RequestBody } from '~/types';
|
||||
import type * as t from './types';
|
||||
import { UserConnectionManager } from '~/mcp/UserConnectionManager';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { UserConnectionManager } from './UserConnectionManager';
|
||||
import { ConnectionsRepository } from './ConnectionsRepository';
|
||||
import { MCPServerInspector } from './registry/MCPServerInspector';
|
||||
import { MCPServersInitializer } from './registry/MCPServersInitializer';
|
||||
import { mcpServersRegistry as registry } from './registry/MCPServersRegistry';
|
||||
import { formatToolContent } from './parsers';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils/env';
|
||||
|
|
@ -24,8 +27,8 @@ export class MCPManager extends UserConnectionManager {
|
|||
/** Creates and initializes the singleton MCPManager instance */
|
||||
public static async createInstance(configs: t.MCPServers): Promise<MCPManager> {
|
||||
if (MCPManager.instance) throw new Error('MCPManager has already been initialized.');
|
||||
MCPManager.instance = new MCPManager(configs);
|
||||
await MCPManager.instance.initialize();
|
||||
MCPManager.instance = new MCPManager();
|
||||
await MCPManager.instance.initialize(configs);
|
||||
return MCPManager.instance;
|
||||
}
|
||||
|
||||
|
|
@ -36,9 +39,10 @@ export class MCPManager extends UserConnectionManager {
|
|||
}
|
||||
|
||||
/** Initializes the MCPManager by setting up server registry and app connections */
|
||||
public async initialize() {
|
||||
await this.serversRegistry.initialize();
|
||||
this.appConnections = new ConnectionsRepository(this.serversRegistry.appServerConfigs);
|
||||
public async initialize(configs: t.MCPServers) {
|
||||
await MCPServersInitializer.initialize(configs);
|
||||
const appConfigs = await registry.sharedAppServers.getAll();
|
||||
this.appConnections = new ConnectionsRepository(appConfigs);
|
||||
}
|
||||
|
||||
/** Retrieves an app-level or user-specific connection based on provided arguments */
|
||||
|
|
@ -62,36 +66,18 @@ export class MCPManager extends UserConnectionManager {
|
|||
}
|
||||
}
|
||||
|
||||
/** Get servers that require OAuth */
|
||||
public getOAuthServers(): Set<string> {
|
||||
return this.serversRegistry.oauthServers;
|
||||
}
|
||||
|
||||
/** Get all servers */
|
||||
public getAllServers(): t.MCPServers {
|
||||
return this.serversRegistry.rawConfigs;
|
||||
}
|
||||
|
||||
/** Returns all available tool functions from app-level connections */
|
||||
public getAppToolFunctions(): t.LCAvailableTools {
|
||||
return this.serversRegistry.toolFunctions;
|
||||
public async getAppToolFunctions(): Promise<t.LCAvailableTools> {
|
||||
const toolFunctions: t.LCAvailableTools = {};
|
||||
const configs = await registry.getAllServerConfigs();
|
||||
for (const config of Object.values(configs)) {
|
||||
if (config.toolFunctions != null) {
|
||||
Object.assign(toolFunctions, config.toolFunctions);
|
||||
}
|
||||
}
|
||||
return toolFunctions;
|
||||
}
|
||||
|
||||
/** Returns all available tool functions from all connections available to user */
|
||||
public async getAllToolFunctions(userId: string): Promise<t.LCAvailableTools | null> {
|
||||
const allToolFunctions: t.LCAvailableTools = this.getAppToolFunctions();
|
||||
const userConnections = this.getUserConnections(userId);
|
||||
if (!userConnections || userConnections.size === 0) {
|
||||
return allToolFunctions;
|
||||
}
|
||||
|
||||
for (const [serverName, connection] of userConnections.entries()) {
|
||||
const toolFunctions = await this.serversRegistry.getToolFunctions(serverName, connection);
|
||||
Object.assign(allToolFunctions, toolFunctions);
|
||||
}
|
||||
|
||||
return allToolFunctions;
|
||||
}
|
||||
/** Returns all available tool functions from all connections available to user */
|
||||
public async getServerToolFunctions(
|
||||
userId: string,
|
||||
|
|
@ -99,7 +85,7 @@ export class MCPManager extends UserConnectionManager {
|
|||
): Promise<t.LCAvailableTools | null> {
|
||||
try {
|
||||
if (this.appConnections?.has(serverName)) {
|
||||
return this.serversRegistry.getToolFunctions(
|
||||
return MCPServerInspector.getToolFunctions(
|
||||
serverName,
|
||||
await this.appConnections.get(serverName),
|
||||
);
|
||||
|
|
@ -113,7 +99,7 @@ export class MCPManager extends UserConnectionManager {
|
|||
return null;
|
||||
}
|
||||
|
||||
return this.serversRegistry.getToolFunctions(serverName, userConnections.get(serverName)!);
|
||||
return MCPServerInspector.getToolFunctions(serverName, userConnections.get(serverName)!);
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`[getServerToolFunctions] Error getting tool functions for server ${serverName}`,
|
||||
|
|
@ -128,8 +114,14 @@ export class MCPManager extends UserConnectionManager {
|
|||
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
|
||||
* @returns Object mapping server names to their instructions
|
||||
*/
|
||||
public getInstructions(serverNames?: string[]): Record<string, string> {
|
||||
const instructions = this.serversRegistry.serverInstructions;
|
||||
private async getInstructions(serverNames?: string[]): Promise<Record<string, string>> {
|
||||
const instructions: Record<string, string> = {};
|
||||
const configs = await registry.getAllServerConfigs();
|
||||
for (const [serverName, config] of Object.entries(configs)) {
|
||||
if (config.serverInstructions != null) {
|
||||
instructions[serverName] = config.serverInstructions as string;
|
||||
}
|
||||
}
|
||||
if (!serverNames) return instructions;
|
||||
return pick(instructions, serverNames);
|
||||
}
|
||||
|
|
@ -139,9 +131,9 @@ export class MCPManager extends UserConnectionManager {
|
|||
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
|
||||
* @returns Formatted instructions string ready for context injection
|
||||
*/
|
||||
public formatInstructionsForContext(serverNames?: string[]): string {
|
||||
public async formatInstructionsForContext(serverNames?: string[]): Promise<string> {
|
||||
/** Instructions for specified servers or all stored instructions */
|
||||
const instructionsToInclude = this.getInstructions(serverNames);
|
||||
const instructionsToInclude = await this.getInstructions(serverNames);
|
||||
|
||||
if (Object.keys(instructionsToInclude).length === 0) {
|
||||
return '';
|
||||
|
|
@ -225,7 +217,7 @@ Please follow these instructions when using tools from the respective MCP server
|
|||
);
|
||||
}
|
||||
|
||||
const rawConfig = this.getRawConfig(serverName) as t.MCPOptions;
|
||||
const rawConfig = (await registry.getServerConfig(serverName, userId)) as t.MCPOptions;
|
||||
const currentOptions = processMCPEnv({
|
||||
user,
|
||||
options: rawConfig,
|
||||
|
|
|
|||
|
|
@ -1,230 +0,0 @@
|
|||
import mapValues from 'lodash/mapValues';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import type { JsonSchemaType } from '@librechat/data-schemas';
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import { processMCPEnv, isEnabled } from '~/utils';
|
||||
|
||||
const DEFAULT_MCP_INIT_TIMEOUT_MS = 30_000;
|
||||
|
||||
function getMCPInitTimeout(): number {
|
||||
return process.env.MCP_INIT_TIMEOUT_MS != null
|
||||
? parseInt(process.env.MCP_INIT_TIMEOUT_MS)
|
||||
: DEFAULT_MCP_INIT_TIMEOUT_MS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages MCP server configurations and metadata discovery.
|
||||
* Fetches server capabilities, OAuth requirements, and tool definitions for registry.
|
||||
* Determines which servers are for app-level connections.
|
||||
* Has its own connections repository. All connections are disconnected after initialization.
|
||||
*/
|
||||
export class MCPServersRegistry {
|
||||
private initialized: boolean = false;
|
||||
private connections: ConnectionsRepository;
|
||||
private initTimeoutMs: number;
|
||||
|
||||
public readonly rawConfigs: t.MCPServers;
|
||||
public readonly parsedConfigs: Record<string, t.ParsedServerConfig>;
|
||||
|
||||
public oauthServers: Set<string> = new Set();
|
||||
public serverInstructions: Record<string, string> = {};
|
||||
public toolFunctions: t.LCAvailableTools = {};
|
||||
public appServerConfigs: t.MCPServers = {};
|
||||
|
||||
constructor(configs: t.MCPServers) {
|
||||
this.rawConfigs = configs;
|
||||
this.parsedConfigs = mapValues(configs, (con) => processMCPEnv({ options: con }));
|
||||
this.connections = new ConnectionsRepository(configs);
|
||||
this.initTimeoutMs = getMCPInitTimeout();
|
||||
}
|
||||
|
||||
/** Initializes all startup-enabled servers by gathering their metadata asynchronously */
|
||||
public async initialize(): Promise<void> {
|
||||
if (this.initialized) return;
|
||||
this.initialized = true;
|
||||
|
||||
const serverNames = Object.keys(this.parsedConfigs);
|
||||
|
||||
await Promise.allSettled(
|
||||
serverNames.map((serverName) => this.initializeServerWithTimeout(serverName)),
|
||||
);
|
||||
}
|
||||
|
||||
/** Wraps server initialization with a timeout to prevent hanging */
|
||||
private async initializeServerWithTimeout(serverName: string): Promise<void> {
|
||||
let timeoutId: NodeJS.Timeout | null = null;
|
||||
|
||||
try {
|
||||
await Promise.race([
|
||||
this.initializeServer(serverName),
|
||||
new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(() => {
|
||||
reject(new Error('Server initialization timed out'));
|
||||
}, this.initTimeoutMs);
|
||||
}),
|
||||
]);
|
||||
} catch (error) {
|
||||
logger.warn(`${this.prefix(serverName)} Server initialization failed:`, error);
|
||||
throw error;
|
||||
} finally {
|
||||
if (timeoutId != null) {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Initializes a single server with all its metadata and adds it to appropriate collections */
|
||||
private async initializeServer(serverName: string): Promise<void> {
|
||||
const start = Date.now();
|
||||
|
||||
const config = this.parsedConfigs[serverName];
|
||||
|
||||
// 1. Detect OAuth requirements if not already specified
|
||||
try {
|
||||
await this.fetchOAuthRequirement(serverName);
|
||||
|
||||
if (config.startup !== false && !config.requiresOAuth) {
|
||||
await Promise.allSettled([
|
||||
this.fetchServerInstructions(serverName).catch((error) =>
|
||||
logger.warn(`${this.prefix(serverName)} Failed to fetch server instructions:`, error),
|
||||
),
|
||||
this.fetchServerCapabilities(serverName).catch((error) =>
|
||||
logger.warn(`${this.prefix(serverName)} Failed to fetch server capabilities:`, error),
|
||||
),
|
||||
]);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`${this.prefix(serverName)} Failed to initialize server:`, error);
|
||||
}
|
||||
|
||||
// 2. Fetch tool functions for this server if a connection was established
|
||||
const getToolFunctions = async (): Promise<t.LCAvailableTools | null> => {
|
||||
try {
|
||||
const loadedConns = await this.connections.getLoaded();
|
||||
const conn = loadedConns.get(serverName);
|
||||
if (conn == null) {
|
||||
return null;
|
||||
}
|
||||
return this.getToolFunctions(serverName, conn);
|
||||
} catch (error) {
|
||||
logger.warn(`${this.prefix(serverName)} Error fetching tool functions:`, error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
const toolFunctions = await getToolFunctions();
|
||||
|
||||
// 3. Disconnect this server's connection if it was established (fire-and-forget)
|
||||
void this.connections.disconnect(serverName);
|
||||
|
||||
// 4. Side effects
|
||||
// 4.1 Add to OAuth servers if needed
|
||||
if (config.requiresOAuth) {
|
||||
this.oauthServers.add(serverName);
|
||||
}
|
||||
// 4.2 Add server instructions if available
|
||||
if (config.serverInstructions != null) {
|
||||
this.serverInstructions[serverName] = config.serverInstructions as string;
|
||||
}
|
||||
// 4.3 Add to app server configs if eligible (startup enabled, non-OAuth servers)
|
||||
if (config.startup !== false && config.requiresOAuth === false) {
|
||||
this.appServerConfigs[serverName] = this.rawConfigs[serverName];
|
||||
}
|
||||
// 4.4 Add tool functions if available
|
||||
if (toolFunctions != null) {
|
||||
Object.assign(this.toolFunctions, toolFunctions);
|
||||
}
|
||||
|
||||
const duration = Date.now() - start;
|
||||
this.logUpdatedConfig(serverName, duration);
|
||||
}
|
||||
|
||||
/** Converts server tools to LibreChat-compatible tool functions format */
|
||||
public async getToolFunctions(
|
||||
serverName: string,
|
||||
conn: MCPConnection,
|
||||
): Promise<t.LCAvailableTools> {
|
||||
const { tools }: t.MCPToolListResponse = await conn.client.listTools();
|
||||
|
||||
const toolFunctions: t.LCAvailableTools = {};
|
||||
tools.forEach((tool) => {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
toolFunctions[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema as JsonSchemaType,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
return toolFunctions;
|
||||
}
|
||||
|
||||
/** Determines if server requires OAuth if not already specified in the config */
|
||||
private async fetchOAuthRequirement(serverName: string): Promise<boolean> {
|
||||
const config = this.parsedConfigs[serverName];
|
||||
if (config.requiresOAuth != null) return config.requiresOAuth;
|
||||
if (config.url == null) return (config.requiresOAuth = false);
|
||||
if (config.startup === false) return (config.requiresOAuth = false);
|
||||
|
||||
const result = await detectOAuthRequirement(config.url);
|
||||
config.requiresOAuth = result.requiresOAuth;
|
||||
config.oauthMetadata = result.metadata;
|
||||
return config.requiresOAuth;
|
||||
}
|
||||
|
||||
/** Retrieves server instructions from MCP server if enabled in the config */
|
||||
private async fetchServerInstructions(serverName: string): Promise<void> {
|
||||
const config = this.parsedConfigs[serverName];
|
||||
if (!config.serverInstructions) return;
|
||||
|
||||
// If it's a string that's not "true", it's a custom instruction
|
||||
if (typeof config.serverInstructions === 'string' && !isEnabled(config.serverInstructions)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Fetch from server if true (boolean) or "true" (string)
|
||||
const conn = await this.connections.get(serverName);
|
||||
config.serverInstructions = conn.client.getInstructions();
|
||||
if (!config.serverInstructions) {
|
||||
logger.warn(`${this.prefix(serverName)} No server instructions available`);
|
||||
}
|
||||
}
|
||||
|
||||
/** Fetches server capabilities and available tools list */
|
||||
private async fetchServerCapabilities(serverName: string): Promise<void> {
|
||||
const config = this.parsedConfigs[serverName];
|
||||
const conn = await this.connections.get(serverName);
|
||||
const capabilities = conn.client.getServerCapabilities();
|
||||
if (!capabilities) return;
|
||||
config.capabilities = JSON.stringify(capabilities);
|
||||
if (!capabilities.tools) return;
|
||||
const tools = await conn.client.listTools();
|
||||
config.tools = tools.tools.map((tool) => tool.name).join(', ');
|
||||
}
|
||||
|
||||
// Logs server configuration summary after initialization
|
||||
private logUpdatedConfig(serverName: string, initDuration: number): void {
|
||||
const prefix = this.prefix(serverName);
|
||||
const config = this.parsedConfigs[serverName];
|
||||
logger.info(`${prefix} -------------------------------------------------┐`);
|
||||
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
|
||||
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
|
||||
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
|
||||
logger.info(`${prefix} Tools: ${config.tools}`);
|
||||
logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`);
|
||||
logger.info(`${prefix} Initialized in: ${initDuration}ms`);
|
||||
logger.info(`${prefix} -------------------------------------------------┘`);
|
||||
}
|
||||
|
||||
// Returns formatted log prefix for server messages
|
||||
private prefix(serverName: string): string {
|
||||
return `[MCP][${serverName}]`;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import { ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||
import { mcpServersRegistry as serversRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { MCPConnection } from './connection';
|
||||
import type * as t from './types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
|
|
@ -14,7 +14,6 @@ import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
|||
* https://github.com/danny-avila/LibreChat/discussions/8790
|
||||
*/
|
||||
export abstract class UserConnectionManager {
|
||||
protected readonly serversRegistry: MCPServersRegistry;
|
||||
// Connections shared by all users.
|
||||
public appConnections: ConnectionsRepository | null = null;
|
||||
// Connections per userId -> serverName -> connection
|
||||
|
|
@ -23,15 +22,6 @@ export abstract class UserConnectionManager {
|
|||
protected userLastActivity: Map<string, number> = new Map();
|
||||
protected readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
|
||||
|
||||
constructor(serverConfigs: t.MCPServers) {
|
||||
this.serversRegistry = new MCPServersRegistry(serverConfigs);
|
||||
}
|
||||
|
||||
/** fetches am MCP Server config from the registry */
|
||||
public getRawConfig(serverName: string): t.MCPOptions | undefined {
|
||||
return this.serversRegistry.rawConfigs[serverName];
|
||||
}
|
||||
|
||||
/** Updates the last activity timestamp for a user */
|
||||
protected updateUserLastActivity(userId: string): void {
|
||||
const now = Date.now();
|
||||
|
|
@ -106,7 +96,7 @@ export abstract class UserConnectionManager {
|
|||
logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
|
||||
}
|
||||
|
||||
const config = this.serversRegistry.parsedConfigs[serverName];
|
||||
const config = await serversRegistry.getServerConfig(serverName, userId);
|
||||
if (!config) {
|
||||
throw new McpError(
|
||||
ErrorCode.InvalidRequest,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||
import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
|
||||
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPConnection } from '../connection';
|
||||
|
||||
|
|
@ -15,7 +17,24 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/MCPServersRegistry');
|
||||
jest.mock('~/mcp/registry/MCPServersRegistry', () => ({
|
||||
mcpServersRegistry: {
|
||||
sharedAppServers: {
|
||||
getAll: jest.fn(),
|
||||
},
|
||||
getServerConfig: jest.fn(),
|
||||
getAllServerConfigs: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServersInitializer', () => ({
|
||||
MCPServersInitializer: {
|
||||
initialize: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('~/mcp/registry/MCPServerInspector');
|
||||
jest.mock('~/mcp/ConnectionsRepository');
|
||||
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
|
|
@ -28,20 +47,12 @@ describe('MCPManager', () => {
|
|||
// Reset MCPManager singleton state
|
||||
(MCPManager as unknown as { instance: null }).instance = null;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
function mockRegistry(
|
||||
registryConfig: Partial<MCPServersRegistry>,
|
||||
): jest.MockedClass<typeof MCPServersRegistry> {
|
||||
const mock = {
|
||||
initialize: jest.fn().mockResolvedValue(undefined),
|
||||
getToolFunctions: jest.fn().mockResolvedValue(null),
|
||||
...registryConfig,
|
||||
};
|
||||
return (MCPServersRegistry as jest.MockedClass<typeof MCPServersRegistry>).mockImplementation(
|
||||
() => mock as unknown as MCPServersRegistry,
|
||||
);
|
||||
}
|
||||
// Set up default mock implementations
|
||||
(MCPServersInitializer.initialize as jest.Mock).mockResolvedValue(undefined);
|
||||
(mcpServersRegistry.sharedAppServers.getAll as jest.Mock).mockResolvedValue({});
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({});
|
||||
});
|
||||
|
||||
function mockAppConnections(
|
||||
appConnectionsConfig: Partial<ConnectionsRepository>,
|
||||
|
|
@ -66,12 +77,229 @@ describe('MCPManager', () => {
|
|||
};
|
||||
}
|
||||
|
||||
describe('getAppToolFunctions', () => {
|
||||
it('should return empty object when no servers have tool functions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: { type: 'stdio', command: 'test', args: [] },
|
||||
server2: { type: 'stdio', command: 'test2', args: [] },
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.getAppToolFunctions();
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('should collect tool functions from multiple servers', async () => {
|
||||
const toolFunctions1 = {
|
||||
tool1_mcp_server1: {
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: 'tool1_mcp_server1',
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object' as const },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const toolFunctions2 = {
|
||||
tool2_mcp_server2: {
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: 'tool2_mcp_server2',
|
||||
description: 'Tool 2',
|
||||
parameters: { type: 'object' as const },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: {
|
||||
type: 'stdio',
|
||||
command: 'test',
|
||||
args: [],
|
||||
toolFunctions: toolFunctions1,
|
||||
},
|
||||
server2: {
|
||||
type: 'stdio',
|
||||
command: 'test2',
|
||||
args: [],
|
||||
toolFunctions: toolFunctions2,
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.getAppToolFunctions();
|
||||
|
||||
expect(result).toEqual({
|
||||
...toolFunctions1,
|
||||
...toolFunctions2,
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle servers with null or undefined toolFunctions', async () => {
|
||||
const toolFunctions1 = {
|
||||
tool1_mcp_server1: {
|
||||
type: 'function' as const,
|
||||
function: {
|
||||
name: 'tool1_mcp_server1',
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object' as const },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: {
|
||||
type: 'stdio',
|
||||
command: 'test',
|
||||
args: [],
|
||||
toolFunctions: toolFunctions1,
|
||||
},
|
||||
server2: {
|
||||
type: 'stdio',
|
||||
command: 'test2',
|
||||
args: [],
|
||||
toolFunctions: null,
|
||||
},
|
||||
server3: {
|
||||
type: 'stdio',
|
||||
command: 'test3',
|
||||
args: [],
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.getAppToolFunctions();
|
||||
|
||||
expect(result).toEqual(toolFunctions1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatInstructionsForContext', () => {
|
||||
it('should return empty string when no servers have instructions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
server1: { type: 'stdio', command: 'test', args: [] },
|
||||
server2: { type: 'stdio', command: 'test2', args: [] },
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext();
|
||||
|
||||
expect(result).toBe('');
|
||||
});
|
||||
|
||||
it('should format instructions from multiple servers', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
serverInstructions: 'Only read/write files in allowed directories',
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext();
|
||||
|
||||
expect(result).toContain('# MCP Server Instructions');
|
||||
expect(result).toContain('## github MCP Server Instructions');
|
||||
expect(result).toContain('Use GitHub API with care');
|
||||
expect(result).toContain('## files MCP Server Instructions');
|
||||
expect(result).toContain('Only read/write files in allowed directories');
|
||||
});
|
||||
|
||||
it('should filter instructions by server names when provided', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
serverInstructions: 'Only read/write files in allowed directories',
|
||||
},
|
||||
database: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['db.js'],
|
||||
serverInstructions: 'Be careful with database operations',
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext(['github', 'database']);
|
||||
|
||||
expect(result).toContain('## github MCP Server Instructions');
|
||||
expect(result).toContain('Use GitHub API with care');
|
||||
expect(result).toContain('## database MCP Server Instructions');
|
||||
expect(result).toContain('Be careful with database operations');
|
||||
expect(result).not.toContain('files');
|
||||
expect(result).not.toContain('Only read/write files in allowed directories');
|
||||
});
|
||||
|
||||
it('should handle servers with null or undefined instructions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
serverInstructions: null,
|
||||
},
|
||||
database: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['db.js'],
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext();
|
||||
|
||||
expect(result).toContain('## github MCP Server Instructions');
|
||||
expect(result).toContain('Use GitHub API with care');
|
||||
expect(result).not.toContain('files');
|
||||
expect(result).not.toContain('database');
|
||||
});
|
||||
|
||||
it('should return empty string when filtered servers have no instructions', async () => {
|
||||
(mcpServersRegistry.getAllServerConfigs as jest.Mock).mockResolvedValue({
|
||||
github: {
|
||||
type: 'sse',
|
||||
url: 'https://api.github.com',
|
||||
serverInstructions: 'Use GitHub API with care',
|
||||
},
|
||||
files: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['files.js'],
|
||||
},
|
||||
});
|
||||
|
||||
const manager = await MCPManager.createInstance(newMCPServersConfig());
|
||||
const result = await manager.formatInstructionsForContext(['files']);
|
||||
|
||||
expect(result).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getServerToolFunctions', () => {
|
||||
it('should catch and handle errors gracefully', async () => {
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn(() => {
|
||||
throw new Error('Connection failed');
|
||||
}),
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => {
|
||||
throw new Error('Connection failed');
|
||||
});
|
||||
|
||||
mockAppConnections({
|
||||
|
|
@ -90,9 +318,7 @@ describe('MCPManager', () => {
|
|||
});
|
||||
|
||||
it('should catch synchronous errors from getUserConnections', async () => {
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn().mockResolvedValue({}),
|
||||
});
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn().mockResolvedValue({});
|
||||
|
||||
mockAppConnections({
|
||||
has: jest.fn().mockReturnValue(false),
|
||||
|
|
@ -126,9 +352,9 @@ describe('MCPManager', () => {
|
|||
},
|
||||
};
|
||||
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn().mockResolvedValue(expectedTools),
|
||||
});
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest
|
||||
.fn()
|
||||
.mockResolvedValue(expectedTools);
|
||||
|
||||
mockAppConnections({
|
||||
has: jest.fn().mockReturnValue(true),
|
||||
|
|
@ -145,10 +371,8 @@ describe('MCPManager', () => {
|
|||
it('should include specific server name in error messages', async () => {
|
||||
const specificServerName = 'github_mcp_server';
|
||||
|
||||
mockRegistry({
|
||||
getToolFunctions: jest.fn(() => {
|
||||
throw new Error('Server specific error');
|
||||
}),
|
||||
(MCPServerInspector.getToolFunctions as jest.Mock) = jest.fn(() => {
|
||||
throw new Error('Server specific error');
|
||||
});
|
||||
|
||||
mockAppConnections({
|
||||
|
|
|
|||
|
|
@ -1,595 +0,0 @@
|
|||
import { join } from 'path';
|
||||
import { readFileSync } from 'fs';
|
||||
import { load as yamlLoad } from 'js-yaml';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { OAuthDetectionResult } from '~/mcp/oauth/detectOAuth';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { MCPServersRegistry } from '~/mcp/MCPServersRegistry';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
|
||||
// Mock external dependencies
|
||||
jest.mock('../oauth/detectOAuth');
|
||||
jest.mock('../ConnectionsRepository');
|
||||
jest.mock('../connection');
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock processMCPEnv to verify it's called and adds a processed marker
|
||||
jest.mock('~/utils', () => ({
|
||||
...jest.requireActual('~/utils'),
|
||||
processMCPEnv: jest.fn(({ options }) => ({
|
||||
...options,
|
||||
_processed: true, // Simple marker to verify processing occurred
|
||||
})),
|
||||
}));
|
||||
|
||||
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
|
||||
typeof detectOAuthRequirement
|
||||
>;
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
|
||||
describe('MCPServersRegistry - Initialize Function', () => {
|
||||
let rawConfigs: t.MCPServers;
|
||||
let expectedParsedConfigs: Record<string, t.ParsedServerConfig>;
|
||||
let mockConnectionsRepo: jest.Mocked<ConnectionsRepository>;
|
||||
let mockConnections: Map<string, jest.Mocked<MCPConnection>>;
|
||||
|
||||
beforeEach(() => {
|
||||
// Load fixtures
|
||||
const rawConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.rawConfigs.yml');
|
||||
const parsedConfigsPath = join(__dirname, 'fixtures', 'MCPServersRegistry.parsedConfigs.yml');
|
||||
|
||||
rawConfigs = yamlLoad(readFileSync(rawConfigsPath, 'utf8')) as t.MCPServers;
|
||||
expectedParsedConfigs = yamlLoad(readFileSync(parsedConfigsPath, 'utf8')) as Record<
|
||||
string,
|
||||
t.ParsedServerConfig
|
||||
>;
|
||||
|
||||
// Setup mock connections
|
||||
mockConnections = new Map();
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
|
||||
serverNames.forEach((serverName) => {
|
||||
const mockClient = {
|
||||
listTools: jest.fn(),
|
||||
getInstructions: jest.fn(),
|
||||
getServerCapabilities: jest.fn(),
|
||||
};
|
||||
const mockConnection = {
|
||||
client: mockClient,
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
// Setup mock responses based on expected configs
|
||||
const expectedConfig = expectedParsedConfigs[serverName];
|
||||
|
||||
// Mock listTools response
|
||||
if (expectedConfig.tools) {
|
||||
const toolNames = expectedConfig.tools.split(', ');
|
||||
const tools = toolNames.map((name: string) => ({
|
||||
name,
|
||||
description: `Description for ${name}`,
|
||||
inputSchema: {
|
||||
type: 'object' as const,
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
},
|
||||
}));
|
||||
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools });
|
||||
} else {
|
||||
(mockClient.listTools as jest.Mock).mockResolvedValue({ tools: [] });
|
||||
}
|
||||
|
||||
// Mock getInstructions response
|
||||
if (expectedConfig.serverInstructions) {
|
||||
(mockClient.getInstructions as jest.Mock).mockReturnValue(
|
||||
expectedConfig.serverInstructions as string,
|
||||
);
|
||||
} else {
|
||||
(mockClient.getInstructions as jest.Mock).mockReturnValue(undefined);
|
||||
}
|
||||
|
||||
// Mock getServerCapabilities response
|
||||
if (expectedConfig.capabilities) {
|
||||
const capabilities = JSON.parse(expectedConfig.capabilities) as Record<string, unknown>;
|
||||
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(capabilities);
|
||||
} else {
|
||||
(mockClient.getServerCapabilities as jest.Mock).mockReturnValue(undefined);
|
||||
}
|
||||
|
||||
mockConnections.set(serverName, mockConnection);
|
||||
});
|
||||
|
||||
// Setup ConnectionsRepository mock
|
||||
mockConnectionsRepo = {
|
||||
get: jest.fn(),
|
||||
getLoaded: jest.fn(),
|
||||
disconnectAll: jest.fn(),
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<ConnectionsRepository>;
|
||||
|
||||
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||
const connection = mockConnections.get(serverName);
|
||||
if (!connection) {
|
||||
throw new Error(`Connection not found for server: ${serverName}`);
|
||||
}
|
||||
return Promise.resolve(connection);
|
||||
});
|
||||
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(mockConnections);
|
||||
|
||||
(ConnectionsRepository as jest.Mock).mockImplementation(() => mockConnectionsRepo);
|
||||
|
||||
// Setup OAuth detection mock with deterministic results
|
||||
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||
const oauthResults: Record<string, OAuthDetectionResult> = {
|
||||
'https://api.github.com/mcp': {
|
||||
requiresOAuth: true,
|
||||
method: 'protected-resource-metadata',
|
||||
metadata: {
|
||||
authorization_url: 'https://github.com/login/oauth/authorize',
|
||||
token_url: 'https://github.com/login/oauth/access_token',
|
||||
},
|
||||
},
|
||||
'https://api.disabled.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
'https://api.public.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
};
|
||||
|
||||
return Promise.resolve(
|
||||
oauthResults[url] || { requiresOAuth: false, method: 'no-metadata-found', metadata: null },
|
||||
);
|
||||
});
|
||||
|
||||
// Clear all mocks
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
delete process.env.MCP_INIT_TIMEOUT_MS;
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('initialize() method', () => {
|
||||
it('should only run initialization once', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
await registry.initialize(); // Second call should not re-run
|
||||
|
||||
// Verify that connections are only requested for servers that need them
|
||||
// (servers with serverInstructions=true or all servers for capabilities)
|
||||
expect(mockConnectionsRepo.get).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should set all public properties correctly after initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Verify initial state
|
||||
expect(registry.oauthServers.size).toBe(0);
|
||||
expect(registry.serverInstructions).toEqual({});
|
||||
expect(registry.toolFunctions).toEqual({});
|
||||
expect(registry.appServerConfigs).toEqual({});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Test oauthServers Set
|
||||
expect(registry.oauthServers).toEqual(
|
||||
new Set(['oauth_server', 'oauth_predefined', 'oauth_startup_enabled']),
|
||||
);
|
||||
|
||||
// Test serverInstructions - OAuth servers keep their original boolean value, non-OAuth fetch actual strings
|
||||
expect(registry.serverInstructions).toEqual({
|
||||
stdio_server: 'Follow these instructions for stdio server',
|
||||
oauth_server: true,
|
||||
non_oauth_server: 'Public API instructions',
|
||||
});
|
||||
|
||||
// Test appServerConfigs (startup enabled, non-OAuth servers only)
|
||||
expect(registry.appServerConfigs).toEqual({
|
||||
stdio_server: rawConfigs.stdio_server,
|
||||
websocket_server: rawConfigs.websocket_server,
|
||||
non_oauth_server: rawConfigs.non_oauth_server,
|
||||
});
|
||||
|
||||
// Test toolFunctions (only non-OAuth servers get their tools fetched during initialization)
|
||||
const expectedToolFunctions = {
|
||||
file_read_mcp_stdio_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_stdio_server',
|
||||
description: 'Description for file_read',
|
||||
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||
},
|
||||
},
|
||||
file_write_mcp_stdio_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_write_mcp_stdio_server',
|
||||
description: 'Description for file_write',
|
||||
parameters: { type: 'object', properties: { input: { type: 'string' } } },
|
||||
},
|
||||
},
|
||||
};
|
||||
expect(registry.toolFunctions).toEqual(expectedToolFunctions);
|
||||
});
|
||||
|
||||
it('should handle errors gracefully and continue initialization of other servers', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Make one specific server throw an error during OAuth detection
|
||||
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||
if (url === 'https://api.github.com/mcp') {
|
||||
return Promise.reject(new Error('OAuth detection failed'));
|
||||
}
|
||||
// Return normal responses for other servers
|
||||
const oauthResults: Record<string, OAuthDetectionResult> = {
|
||||
'https://api.disabled.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
'https://api.public.com/mcp': {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
};
|
||||
return Promise.resolve(
|
||||
oauthResults[url] ?? {
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Should still initialize successfully for other servers
|
||||
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||
expect(registry.toolFunctions).toBeDefined();
|
||||
|
||||
// The failed server should not be in oauthServers (since it failed OAuth detection)
|
||||
expect(registry.oauthServers.has('oauth_server')).toBe(false);
|
||||
|
||||
// But other servers should still be processed successfully
|
||||
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
|
||||
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||
|
||||
// Error should be logged as a warning at the higher level
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][oauth_server] Failed to initialize server:'),
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
|
||||
it('should disconnect individual connections after each server initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Verify disconnect was called for each server during initialization
|
||||
// All servers attempt to connect during initialization for metadata gathering
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
|
||||
});
|
||||
|
||||
it('should log configuration updates for each startup-enabled server', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
serverNames.forEach((serverName) => {
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] URL:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] OAuth Required:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] Capabilities:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] Tools:`),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`[MCP][${serverName}] Server Instructions:`),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should have parsedConfigs matching the expected fixture after initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Compare the actual parsedConfigs against the expected fixture
|
||||
expect(registry.parsedConfigs).toEqual(expectedParsedConfigs);
|
||||
});
|
||||
|
||||
it('should handle serverInstructions as string "true" correctly and fetch from server', async () => {
|
||||
// Create test config with serverInstructions as string "true"
|
||||
const testConfig: t.MCPServers = {
|
||||
test_server_string_true: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
serverInstructions: 'true', // Simulating string "true" from YAML parsing
|
||||
},
|
||||
test_server_custom_string: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
serverInstructions: 'Custom instructions here',
|
||||
},
|
||||
test_server_bool_true: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
serverInstructions: true,
|
||||
},
|
||||
};
|
||||
|
||||
const registry = new MCPServersRegistry(testConfig);
|
||||
|
||||
// Setup mock connection for servers that should fetch
|
||||
const mockClient = {
|
||||
listTools: jest.fn().mockResolvedValue({ tools: [] }),
|
||||
getInstructions: jest.fn().mockReturnValue('Fetched instructions from server'),
|
||||
getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }),
|
||||
};
|
||||
const mockConnection = {
|
||||
client: mockClient,
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
mockConnectionsRepo.get.mockResolvedValue(mockConnection);
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(
|
||||
new Map([
|
||||
['test_server_string_true', mockConnection],
|
||||
['test_server_bool_true', mockConnection],
|
||||
]),
|
||||
);
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Verify that string "true" was treated as fetch-from-server
|
||||
expect(registry.parsedConfigs['test_server_string_true'].serverInstructions).toBe(
|
||||
'Fetched instructions from server',
|
||||
);
|
||||
|
||||
// Verify that custom string was kept as-is
|
||||
expect(registry.parsedConfigs['test_server_custom_string'].serverInstructions).toBe(
|
||||
'Custom instructions here',
|
||||
);
|
||||
|
||||
// Verify that boolean true also fetched from server
|
||||
expect(registry.parsedConfigs['test_server_bool_true'].serverInstructions).toBe(
|
||||
'Fetched instructions from server',
|
||||
);
|
||||
|
||||
// Verify getInstructions was called for both "true" cases
|
||||
expect(mockClient.getInstructions).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should use Promise.allSettled for individual server initialization', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Spy on Promise.allSettled to verify it's being used
|
||||
const allSettledSpy = jest.spyOn(Promise, 'allSettled');
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Verify Promise.allSettled was called with an array of server initialization promises
|
||||
expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)]));
|
||||
|
||||
// Verify it was called with the correct number of server promises
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
expect(allSettledSpy).toHaveBeenCalledWith(
|
||||
expect.arrayContaining(new Array(serverNames.length).fill(expect.any(Promise))),
|
||||
);
|
||||
|
||||
allSettledSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should isolate server failures and not affect other servers', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Make multiple servers fail in different ways
|
||||
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||
if (serverName === 'stdio_server') {
|
||||
// First server fails
|
||||
throw new Error('Connection failed for stdio_server');
|
||||
}
|
||||
if (serverName === 'websocket_server') {
|
||||
// Second server fails
|
||||
throw new Error('Connection failed for websocket_server');
|
||||
}
|
||||
// Other servers succeed
|
||||
const connection = mockConnections.get(serverName);
|
||||
if (!connection) {
|
||||
throw new Error(`Connection not found for server: ${serverName}`);
|
||||
}
|
||||
return Promise.resolve(connection);
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Despite failures, initialization should complete
|
||||
expect(registry.oauthServers).toBeInstanceOf(Set);
|
||||
expect(registry.toolFunctions).toBeDefined();
|
||||
|
||||
// Successful servers should still be processed
|
||||
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||
|
||||
// Failed servers should not crash the whole initialization
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][stdio_server] Failed to fetch server capabilities:'),
|
||||
expect.any(Error),
|
||||
);
|
||||
expect(mockLogger.warn).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][websocket_server] Failed to fetch server capabilities:'),
|
||||
expect.any(Error),
|
||||
);
|
||||
});
|
||||
|
||||
it('should properly clean up connections even when some servers fail', async () => {
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Track disconnect failures but suppress unhandled rejections
|
||||
const disconnectErrors: Error[] = [];
|
||||
mockConnectionsRepo.disconnect.mockImplementation((serverName: string) => {
|
||||
if (serverName === 'stdio_server') {
|
||||
const error = new Error('Disconnect failed');
|
||||
disconnectErrors.push(error);
|
||||
return Promise.reject(error).catch(() => {}); // Suppress unhandled rejection
|
||||
}
|
||||
return Promise.resolve();
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
// Should still attempt to disconnect all servers during initialization
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
expect(mockConnectionsRepo.disconnect).toHaveBeenCalledTimes(serverNames.length);
|
||||
expect(disconnectErrors).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('should timeout individual server initialization after configured timeout', async () => {
|
||||
const timeout = 2000;
|
||||
// Create registry with a short timeout for testing
|
||||
process.env.MCP_INIT_TIMEOUT_MS = `${timeout}`;
|
||||
|
||||
const registry = new MCPServersRegistry(rawConfigs);
|
||||
|
||||
// Make one server hang indefinitely during OAuth detection
|
||||
mockDetectOAuthRequirement.mockImplementation((url: string) => {
|
||||
if (url === 'https://api.github.com/mcp') {
|
||||
// Slow init
|
||||
return new Promise((res) => setTimeout(res, timeout * 2));
|
||||
}
|
||||
// Return normal responses for other servers
|
||||
return Promise.resolve({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
});
|
||||
|
||||
const start = Date.now();
|
||||
await registry.initialize();
|
||||
const duration = Date.now() - start;
|
||||
|
||||
// Should complete within reasonable time despite one server hanging
|
||||
// Allow some buffer for test execution overhead
|
||||
expect(duration).toBeLessThan(timeout * 1.5);
|
||||
|
||||
// The timeout should prevent the hanging server from blocking initialization
|
||||
// Other servers should still be processed successfully
|
||||
expect(registry.appServerConfigs).toHaveProperty('stdio_server');
|
||||
expect(registry.appServerConfigs).toHaveProperty('non_oauth_server');
|
||||
}, 10_000); // 10 second Jest timeout
|
||||
|
||||
it('should skip tool function fetching if connection was not established', async () => {
|
||||
const testConfig: t.MCPServers = {
|
||||
server_with_connection: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
},
|
||||
server_without_connection: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'failing-command',
|
||||
},
|
||||
};
|
||||
|
||||
const registry = new MCPServersRegistry(testConfig);
|
||||
|
||||
const mockClient = {
|
||||
listTools: jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'test_tool',
|
||||
description: 'Test tool',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
getInstructions: jest.fn().mockReturnValue(undefined),
|
||||
getServerCapabilities: jest.fn().mockReturnValue({ tools: {} }),
|
||||
};
|
||||
const mockConnection = {
|
||||
client: mockClient,
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
mockConnectionsRepo.get.mockImplementation((serverName: string) => {
|
||||
if (serverName === 'server_with_connection') {
|
||||
return Promise.resolve(mockConnection);
|
||||
}
|
||||
throw new Error('Connection failed');
|
||||
});
|
||||
|
||||
// Mock getLoaded to return connections map - the real implementation returns all loaded connections at once
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(
|
||||
new Map([['server_with_connection', mockConnection]]),
|
||||
);
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
expect(registry.toolFunctions).toHaveProperty('test_tool_mcp_server_with_connection');
|
||||
expect(Object.keys(registry.toolFunctions)).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('should handle getLoaded returning empty map gracefully', async () => {
|
||||
const testConfig: t.MCPServers = {
|
||||
test_server: {
|
||||
type: 'stdio',
|
||||
args: [],
|
||||
command: 'test-command',
|
||||
},
|
||||
};
|
||||
|
||||
const registry = new MCPServersRegistry(testConfig);
|
||||
|
||||
mockConnectionsRepo.get.mockRejectedValue(new Error('All connections failed'));
|
||||
mockConnectionsRepo.getLoaded.mockResolvedValue(new Map());
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
metadata: null,
|
||||
});
|
||||
|
||||
await registry.initialize();
|
||||
|
||||
expect(registry.toolFunctions).toEqual({});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
# Expected parsed MCP server configurations after running initialize()
|
||||
# These represent the expected state of parsedConfigs after all fetch functions complete
|
||||
|
||||
oauth_server:
|
||||
_processed: true
|
||||
type: "streamable-http"
|
||||
url: "https://api.github.com/mcp"
|
||||
headers:
|
||||
Authorization: "Bearer {{GITHUB_TOKEN}}"
|
||||
serverInstructions: true
|
||||
requiresOAuth: true
|
||||
oauthMetadata:
|
||||
authorization_url: "https://github.com/login/oauth/authorize"
|
||||
token_url: "https://github.com/login/oauth/access_token"
|
||||
|
||||
oauth_predefined:
|
||||
_processed: true
|
||||
type: "sse"
|
||||
url: "https://api.example.com/sse"
|
||||
requiresOAuth: true
|
||||
oauthMetadata:
|
||||
authorization_url: "https://example.com/oauth/authorize"
|
||||
token_url: "https://example.com/oauth/token"
|
||||
|
||||
stdio_server:
|
||||
_processed: true
|
||||
command: "node"
|
||||
args: ["server.js"]
|
||||
env:
|
||||
API_KEY: "${TEST_API_KEY}"
|
||||
startup: true
|
||||
serverInstructions: "Follow these instructions for stdio server"
|
||||
requiresOAuth: false
|
||||
capabilities: '{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{}}'
|
||||
tools: "file_read, file_write"
|
||||
|
||||
websocket_server:
|
||||
_processed: true
|
||||
type: "websocket"
|
||||
url: "ws://localhost:3001/mcp"
|
||||
startup: true
|
||||
requiresOAuth: false
|
||||
oauthMetadata: null
|
||||
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||
tools: ""
|
||||
|
||||
disabled_server:
|
||||
_processed: true
|
||||
requiresOAuth: false
|
||||
type: "streamable-http"
|
||||
url: "https://api.disabled.com/mcp"
|
||||
startup: false
|
||||
|
||||
non_oauth_server:
|
||||
_processed: true
|
||||
type: "streamable-http"
|
||||
url: "https://api.public.com/mcp"
|
||||
requiresOAuth: false
|
||||
serverInstructions: "Public API instructions"
|
||||
capabilities: '{"tools":{},"resources":{},"prompts":{}}'
|
||||
tools: ""
|
||||
|
||||
oauth_startup_enabled:
|
||||
_processed: true
|
||||
type: "sse"
|
||||
url: "https://api.oauth-startup.com/sse"
|
||||
requiresOAuth: true
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
# Raw MCP server configurations used as input to MCPServersRegistry constructor
|
||||
# These configs test different code paths in the initialization process
|
||||
|
||||
# Test OAuth detection with URL - should trigger fetchOAuthRequirement
|
||||
oauth_server:
|
||||
type: "streamable-http"
|
||||
url: "https://api.github.com/mcp"
|
||||
headers:
|
||||
Authorization: "Bearer {{GITHUB_TOKEN}}"
|
||||
serverInstructions: true
|
||||
|
||||
# Test OAuth already specified - should skip OAuth detection
|
||||
oauth_predefined:
|
||||
type: "sse"
|
||||
url: "https://api.example.com/sse"
|
||||
requiresOAuth: true
|
||||
oauthMetadata:
|
||||
authorization_url: "https://example.com/oauth/authorize"
|
||||
token_url: "https://example.com/oauth/token"
|
||||
|
||||
# Test stdio server without URL - should set requiresOAuth to false
|
||||
stdio_server:
|
||||
command: "node"
|
||||
args: ["server.js"]
|
||||
env:
|
||||
API_KEY: "${TEST_API_KEY}"
|
||||
startup: true
|
||||
serverInstructions: "Follow these instructions for stdio server"
|
||||
|
||||
# Test websocket server with capabilities but no tools
|
||||
websocket_server:
|
||||
type: "websocket"
|
||||
url: "ws://localhost:3001/mcp"
|
||||
startup: true
|
||||
|
||||
# Test server with startup disabled - should not be included in appServerConfigs
|
||||
disabled_server:
|
||||
type: "streamable-http"
|
||||
url: "https://api.disabled.com/mcp"
|
||||
startup: false
|
||||
|
||||
# Test non-OAuth server - should be included in appServerConfigs
|
||||
non_oauth_server:
|
||||
type: "streamable-http"
|
||||
url: "https://api.public.com/mcp"
|
||||
requiresOAuth: false
|
||||
serverInstructions: true
|
||||
|
||||
# Test server with OAuth but startup enabled - should not be in appServerConfigs
|
||||
oauth_startup_enabled:
|
||||
type: "sse"
|
||||
url: "https://api.oauth-startup.com/sse"
|
||||
requiresOAuth: true
|
||||
|
|
@ -18,6 +18,7 @@ import type {
|
|||
Response as UndiciResponse,
|
||||
} from 'undici';
|
||||
import type { MCPOAuthTokens } from './oauth/types';
|
||||
import { withTimeout } from '~/utils/promise';
|
||||
import type * as t from './types';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
|
@ -457,15 +458,11 @@ export class MCPConnection extends EventEmitter {
|
|||
this.setupTransportDebugHandlers();
|
||||
|
||||
const connectTimeout = this.options.initTimeout ?? 120000;
|
||||
await Promise.race([
|
||||
await withTimeout(
|
||||
this.client.connect(this.transport),
|
||||
new Promise((_resolve, reject) =>
|
||||
setTimeout(
|
||||
() => reject(new Error(`Connection timeout after ${connectTimeout}ms`)),
|
||||
connectTimeout,
|
||||
),
|
||||
),
|
||||
]);
|
||||
connectTimeout,
|
||||
`Connection timeout after ${connectTimeout}ms`,
|
||||
);
|
||||
|
||||
this.connectionState = 'connected';
|
||||
this.emit('connectionChange', 'connected');
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import { TokenMethods } from '@librechat/data-schemas';
|
||||
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
|
||||
import { MCPManager } from '../MCPManager';
|
||||
import { mcpServersRegistry } from '../../mcp/registry/MCPServersRegistry';
|
||||
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
|
||||
|
|
@ -14,6 +15,12 @@ jest.mock('@librechat/data-schemas', () => ({
|
|||
}));
|
||||
|
||||
jest.mock('../MCPManager');
|
||||
jest.mock('../../mcp/registry/MCPServersRegistry', () => ({
|
||||
mcpServersRegistry: {
|
||||
getServerConfig: jest.fn(),
|
||||
getOAuthServers: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('OAuthReconnectionManager', () => {
|
||||
let flowManager: jest.Mocked<FlowStateManager<null>>;
|
||||
|
|
@ -51,10 +58,10 @@ describe('OAuthReconnectionManager', () => {
|
|||
getUserConnection: jest.fn(),
|
||||
getUserConnections: jest.fn(),
|
||||
disconnectUserConnection: jest.fn(),
|
||||
getRawConfig: jest.fn(),
|
||||
} as unknown as jest.Mocked<MCPManager>;
|
||||
|
||||
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
|
@ -152,7 +159,7 @@ describe('OAuthReconnectionManager', () => {
|
|||
it('should reconnect eligible servers', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2', 'server3']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// server1: has failed reconnection
|
||||
reconnectionTracker.setFailed(userId, 'server1');
|
||||
|
|
@ -186,7 +193,9 @@ describe('OAuthReconnectionManager', () => {
|
|||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue({
|
||||
initTimeout: 5000,
|
||||
} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
|
|
@ -215,7 +224,7 @@ describe('OAuthReconnectionManager', () => {
|
|||
it('should handle failed reconnection attempts', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// server1: has valid token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
|
|
@ -226,7 +235,9 @@ describe('OAuthReconnectionManager', () => {
|
|||
|
||||
// Mock failed connection
|
||||
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
|
|
@ -242,7 +253,7 @@ describe('OAuthReconnectionManager', () => {
|
|||
it('should not reconnect servers with expired tokens', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
// server1: has expired token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
|
|
@ -261,7 +272,7 @@ describe('OAuthReconnectionManager', () => {
|
|||
it('should handle connection that returns but is not connected', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
|
|
@ -277,7 +288,9 @@ describe('OAuthReconnectionManager', () => {
|
|||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockConnection as unknown as MCPConnection,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
|
|
@ -359,7 +372,7 @@ describe('OAuthReconnectionManager', () => {
|
|||
it('should not attempt to reconnect servers that have timed out during reconnection', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
|
@ -414,7 +427,7 @@ describe('OAuthReconnectionManager', () => {
|
|||
const userId = 'user-123';
|
||||
const serverName = 'server1';
|
||||
const oauthServers = new Set([serverName]);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
(mcpServersRegistry.getOAuthServers as jest.Mock).mockResolvedValue(oauthServers);
|
||||
|
||||
const now = Date.now();
|
||||
jest.setSystemTime(now);
|
||||
|
|
@ -428,7 +441,9 @@ describe('OAuthReconnectionManager', () => {
|
|||
|
||||
// First reconnect attempt - will fail
|
||||
mockMCPManager.getUserConnection.mockRejectedValueOnce(new Error('Connection failed'));
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
(mcpServersRegistry.getServerConfig as jest.Mock).mockResolvedValue(
|
||||
{} as unknown as MCPOptions,
|
||||
);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
await jest.runAllTimersAsync();
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import type { MCPOAuthTokens } from './types';
|
|||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
import { mcpServersRegistry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
|
||||
|
||||
|
|
@ -72,7 +73,7 @@ export class OAuthReconnectionManager {
|
|||
|
||||
// 1. derive the servers to reconnect
|
||||
const serversToReconnect = [];
|
||||
for (const serverName of this.mcpManager.getOAuthServers()) {
|
||||
for (const serverName of await mcpServersRegistry.getOAuthServers()) {
|
||||
const canReconnect = await this.canReconnect(userId, serverName);
|
||||
if (canReconnect) {
|
||||
serversToReconnect.push(serverName);
|
||||
|
|
@ -104,7 +105,7 @@ export class OAuthReconnectionManager {
|
|||
|
||||
logger.info(`${logPrefix} Attempting reconnection`);
|
||||
|
||||
const config = this.mcpManager.getRawConfig(serverName);
|
||||
const config = await mcpServersRegistry.getServerConfig(serverName, userId);
|
||||
|
||||
const cleanupOnFailedReconnect = () => {
|
||||
this.reconnectionsTracker.setFailed(userId, serverName);
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import { randomBytes } from 'crypto';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { FetchLike } from '@modelcontextprotocol/sdk/shared/transport';
|
||||
import { OAuthMetadataSchema } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import {
|
||||
registerClient,
|
||||
startAuthorization,
|
||||
|
|
@ -7,7 +9,6 @@ import {
|
|||
discoverAuthorizationServerMetadata,
|
||||
discoverOAuthProtectedResourceMetadata,
|
||||
} from '@modelcontextprotocol/sdk/client/auth.js';
|
||||
import { OAuthMetadataSchema } from '@modelcontextprotocol/sdk/shared/auth.js';
|
||||
import type { MCPOptions } from 'librechat-data-provider';
|
||||
import type { FlowStateManager } from '~/flow/manager';
|
||||
import type {
|
||||
|
|
@ -18,7 +19,6 @@ import type {
|
|||
OAuthMetadata,
|
||||
} from './types';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import { FetchLike } from '@modelcontextprotocol/sdk/shared/transport';
|
||||
|
||||
/** Type for the OAuth metadata from the SDK */
|
||||
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
|
||||
|
|
@ -439,9 +439,10 @@ export class MCPOAuthHandler {
|
|||
fetchFn: this.createOAuthFetch(oauthHeaders),
|
||||
});
|
||||
|
||||
logger.debug('[MCPOAuth] Raw tokens from exchange:', {
|
||||
access_token: tokens.access_token ? '[REDACTED]' : undefined,
|
||||
refresh_token: tokens.refresh_token ? '[REDACTED]' : undefined,
|
||||
logger.debug('[MCPOAuth] Token exchange successful', {
|
||||
flowId,
|
||||
has_access_token: !!tokens.access_token,
|
||||
has_refresh_token: !!tokens.refresh_token,
|
||||
expires_in: tokens.expires_in,
|
||||
token_type: tokens.token_type,
|
||||
scope: tokens.scope,
|
||||
|
|
|
|||
|
|
@ -217,7 +217,11 @@ export class MCPTokenStorage {
|
|||
}
|
||||
}
|
||||
|
||||
logger.debug(`${logPrefix} Stored OAuth tokens`);
|
||||
logger.debug(`${logPrefix} Stored OAuth tokens`, {
|
||||
client_id: clientInfo?.client_id,
|
||||
has_refresh_token: !!tokens.refresh_token,
|
||||
expires_at: 'expires_at' in tokens ? tokens.expires_at : 'N/A',
|
||||
});
|
||||
} catch (error) {
|
||||
const logPrefix = this.getLogPrefix(userId, serverName);
|
||||
logger.error(`${logPrefix} Failed to store tokens`, error);
|
||||
|
|
|
|||
123
packages/api/src/mcp/registry/MCPServerInspector.ts
Normal file
123
packages/api/src/mcp/registry/MCPServerInspector.ts
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
import { Constants } from 'librechat-data-provider';
|
||||
import type { JsonSchemaType } from '@librechat/data-schemas';
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { isEnabled } from '~/utils';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
|
||||
/**
|
||||
* Inspects MCP servers to discover their metadata, capabilities, and tools.
|
||||
* Connects to servers and populates configuration with OAuth requirements,
|
||||
* server instructions, capabilities, and available tools.
|
||||
*/
|
||||
export class MCPServerInspector {
|
||||
private constructor(
|
||||
private readonly serverName: string,
|
||||
private readonly config: t.ParsedServerConfig,
|
||||
private connection: MCPConnection | undefined,
|
||||
) {}
|
||||
|
||||
/**
|
||||
* Inspects a server and returns an enriched configuration with metadata.
|
||||
* Detects OAuth requirements and fetches server capabilities.
|
||||
* @param serverName - The name of the server (used for tool function naming)
|
||||
* @param rawConfig - The raw server configuration
|
||||
* @param connection - The MCP connection
|
||||
* @returns A fully processed and enriched configuration with server metadata
|
||||
*/
|
||||
public static async inspect(
|
||||
serverName: string,
|
||||
rawConfig: t.MCPOptions,
|
||||
connection?: MCPConnection,
|
||||
): Promise<t.ParsedServerConfig> {
|
||||
const start = Date.now();
|
||||
const inspector = new MCPServerInspector(serverName, rawConfig, connection);
|
||||
await inspector.inspectServer();
|
||||
inspector.config.initDuration = Date.now() - start;
|
||||
return inspector.config;
|
||||
}
|
||||
|
||||
private async inspectServer(): Promise<void> {
|
||||
await this.detectOAuth();
|
||||
|
||||
if (this.config.startup !== false && !this.config.requiresOAuth) {
|
||||
let tempConnection = false;
|
||||
if (!this.connection) {
|
||||
tempConnection = true;
|
||||
this.connection = await MCPConnectionFactory.create({
|
||||
serverName: this.serverName,
|
||||
serverConfig: this.config,
|
||||
});
|
||||
}
|
||||
|
||||
await Promise.allSettled([
|
||||
this.fetchServerInstructions(),
|
||||
this.fetchServerCapabilities(),
|
||||
this.fetchToolFunctions(),
|
||||
]);
|
||||
|
||||
if (tempConnection) await this.connection.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
private async detectOAuth(): Promise<void> {
|
||||
if (this.config.requiresOAuth != null) return;
|
||||
if (this.config.url == null || this.config.startup === false) {
|
||||
this.config.requiresOAuth = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await detectOAuthRequirement(this.config.url);
|
||||
this.config.requiresOAuth = result.requiresOAuth;
|
||||
this.config.oauthMetadata = result.metadata;
|
||||
}
|
||||
|
||||
private async fetchServerInstructions(): Promise<void> {
|
||||
if (isEnabled(this.config.serverInstructions)) {
|
||||
this.config.serverInstructions = this.connection!.client.getInstructions();
|
||||
}
|
||||
}
|
||||
|
||||
private async fetchServerCapabilities(): Promise<void> {
|
||||
const capabilities = this.connection!.client.getServerCapabilities();
|
||||
this.config.capabilities = JSON.stringify(capabilities);
|
||||
const tools = await this.connection!.client.listTools();
|
||||
this.config.tools = tools.tools.map((tool) => tool.name).join(', ');
|
||||
}
|
||||
|
||||
private async fetchToolFunctions(): Promise<void> {
|
||||
this.config.toolFunctions = await MCPServerInspector.getToolFunctions(
|
||||
this.serverName,
|
||||
this.connection!,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts server tools to LibreChat-compatible tool functions format.
|
||||
* @param serverName - The name of the server
|
||||
* @param connection - The MCP connection
|
||||
* @returns Tool functions formatted for LibreChat
|
||||
*/
|
||||
public static async getToolFunctions(
|
||||
serverName: string,
|
||||
connection: MCPConnection,
|
||||
): Promise<t.LCAvailableTools> {
|
||||
const { tools }: t.MCPToolListResponse = await connection.client.listTools();
|
||||
|
||||
const toolFunctions: t.LCAvailableTools = {};
|
||||
tools.forEach((tool) => {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
toolFunctions[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema as JsonSchemaType,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
return toolFunctions;
|
||||
}
|
||||
}
|
||||
96
packages/api/src/mcp/registry/MCPServersInitializer.ts
Normal file
96
packages/api/src/mcp/registry/MCPServersInitializer.ts
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import { registryStatusCache as statusCache } from './cache/RegistryStatusCache';
|
||||
import { isLeader } from '~/cluster';
|
||||
import { withTimeout } from '~/utils';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { MCPServerInspector } from './MCPServerInspector';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { mcpServersRegistry as registry } from './MCPServersRegistry';
|
||||
|
||||
const MCP_INIT_TIMEOUT_MS =
|
||||
process.env.MCP_INIT_TIMEOUT_MS != null ? parseInt(process.env.MCP_INIT_TIMEOUT_MS) : 30_000;
|
||||
|
||||
/**
|
||||
* Handles initialization of MCP servers at application startup with distributed coordination.
|
||||
* In cluster environments, ensures only the leader node performs initialization while followers wait.
|
||||
* Connects to each configured MCP server, inspects capabilities and tools, then caches the results.
|
||||
* Categorizes servers as either shared app servers (auto-started) or shared user servers (OAuth/on-demand).
|
||||
* Uses a timeout mechanism to prevent hanging on unresponsive servers during initialization.
|
||||
*/
|
||||
export class MCPServersInitializer {
|
||||
/**
|
||||
* Initializes MCP servers with distributed leader-follower coordination.
|
||||
*
|
||||
* Design rationale:
|
||||
* - Handles leader crash scenarios: If the leader crashes during initialization, all followers
|
||||
* will independently attempt initialization after a 3-second delay. The first to become leader
|
||||
* will complete the initialization.
|
||||
* - Only the leader performs the actual initialization work (reset caches, inspect servers).
|
||||
* When complete, the leader signals completion via `statusCache`, allowing followers to proceed.
|
||||
* - Followers wait and poll `statusCache` until the leader finishes, ensuring only one node
|
||||
* performs the expensive initialization operations.
|
||||
*/
|
||||
public static async initialize(rawConfigs: t.MCPServers): Promise<void> {
|
||||
if (await statusCache.isInitialized()) return;
|
||||
|
||||
if (await isLeader()) {
|
||||
// Leader performs initialization
|
||||
await statusCache.reset();
|
||||
await registry.reset();
|
||||
const serverNames = Object.keys(rawConfigs);
|
||||
await Promise.allSettled(
|
||||
serverNames.map((serverName) =>
|
||||
withTimeout(
|
||||
MCPServersInitializer.initializeServer(serverName, rawConfigs[serverName]),
|
||||
MCP_INIT_TIMEOUT_MS,
|
||||
`${MCPServersInitializer.prefix(serverName)} Server initialization timed out`,
|
||||
logger.error,
|
||||
),
|
||||
),
|
||||
);
|
||||
await statusCache.setInitialized(true);
|
||||
} else {
|
||||
// Followers try again after a delay if not initialized
|
||||
await new Promise((resolve) => setTimeout(resolve, 3000));
|
||||
await this.initialize(rawConfigs);
|
||||
}
|
||||
}
|
||||
|
||||
/** Initializes a single server with all its metadata and adds it to appropriate collections */
|
||||
private static async initializeServer(
|
||||
serverName: string,
|
||||
rawConfig: t.MCPOptions,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const config = await MCPServerInspector.inspect(serverName, rawConfig);
|
||||
|
||||
if (config.startup === false || config.requiresOAuth) {
|
||||
await registry.sharedUserServers.add(serverName, config);
|
||||
} else {
|
||||
await registry.sharedAppServers.add(serverName, config);
|
||||
}
|
||||
MCPServersInitializer.logParsedConfig(serverName, config);
|
||||
} catch (error) {
|
||||
logger.error(`${MCPServersInitializer.prefix(serverName)} Failed to initialize:`, error);
|
||||
}
|
||||
}
|
||||
|
||||
// Logs server configuration summary after initialization
|
||||
private static logParsedConfig(serverName: string, config: ParsedServerConfig): void {
|
||||
const prefix = MCPServersInitializer.prefix(serverName);
|
||||
logger.info(`${prefix} -------------------------------------------------┐`);
|
||||
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
|
||||
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
|
||||
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
|
||||
logger.info(`${prefix} Tools: ${config.tools}`);
|
||||
logger.info(`${prefix} Server Instructions: ${config.serverInstructions}`);
|
||||
logger.info(`${prefix} Initialized in: ${config.initDuration ?? 'N/A'}ms`);
|
||||
logger.info(`${prefix} -------------------------------------------------┘`);
|
||||
}
|
||||
|
||||
// Returns formatted log prefix for server messages
|
||||
private static prefix(serverName: string): string {
|
||||
return `[MCP][${serverName}]`;
|
||||
}
|
||||
}
|
||||
91
packages/api/src/mcp/registry/MCPServersRegistry.ts
Normal file
91
packages/api/src/mcp/registry/MCPServersRegistry.ts
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
import type * as t from '~/mcp/types';
|
||||
import {
|
||||
ServerConfigsCacheFactory,
|
||||
type ServerConfigsCache,
|
||||
} from './cache/ServerConfigsCacheFactory';
|
||||
|
||||
/**
|
||||
* Central registry for managing MCP server configurations across different scopes and users.
|
||||
* Maintains three categories of server configurations:
|
||||
* - Shared App Servers: Auto-started servers available to all users (initialized at startup)
|
||||
* - Shared User Servers: User-scope servers that require OAuth or on-demand startup
|
||||
* - Private User Servers: Per-user configurations dynamically added during runtime
|
||||
*
|
||||
* Provides a unified interface for retrieving server configs with proper fallback hierarchy:
|
||||
* checks shared app servers first, then shared user servers, then private user servers.
|
||||
* Handles server lifecycle operations including adding, removing, and querying configurations.
|
||||
*/
|
||||
class MCPServersRegistry {
|
||||
public readonly sharedAppServers = ServerConfigsCacheFactory.create('App', true);
|
||||
public readonly sharedUserServers = ServerConfigsCacheFactory.create('User', true);
|
||||
private readonly privateUserServers: Map<string | undefined, ServerConfigsCache> = new Map();
|
||||
|
||||
public async addPrivateUserServer(
|
||||
userId: string,
|
||||
serverName: string,
|
||||
config: t.ParsedServerConfig,
|
||||
): Promise<void> {
|
||||
if (!this.privateUserServers.has(userId)) {
|
||||
const cache = ServerConfigsCacheFactory.create(`User(${userId})`, false);
|
||||
this.privateUserServers.set(userId, cache);
|
||||
}
|
||||
await this.privateUserServers.get(userId)!.add(serverName, config);
|
||||
}
|
||||
|
||||
public async updatePrivateUserServer(
|
||||
userId: string,
|
||||
serverName: string,
|
||||
config: t.ParsedServerConfig,
|
||||
): Promise<void> {
|
||||
const userCache = this.privateUserServers.get(userId);
|
||||
if (!userCache) throw new Error(`No private servers found for user "${userId}".`);
|
||||
await userCache.update(serverName, config);
|
||||
}
|
||||
|
||||
public async removePrivateUserServer(userId: string, serverName: string): Promise<void> {
|
||||
await this.privateUserServers.get(userId)?.remove(serverName);
|
||||
}
|
||||
|
||||
public async getServerConfig(
|
||||
serverName: string,
|
||||
userId?: string,
|
||||
): Promise<t.ParsedServerConfig | undefined> {
|
||||
const sharedAppServer = await this.sharedAppServers.get(serverName);
|
||||
if (sharedAppServer) return sharedAppServer;
|
||||
|
||||
const sharedUserServer = await this.sharedUserServers.get(serverName);
|
||||
if (sharedUserServer) return sharedUserServer;
|
||||
|
||||
const privateUserServer = await this.privateUserServers.get(userId)?.get(serverName);
|
||||
if (privateUserServer) return privateUserServer;
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
public async getAllServerConfigs(userId?: string): Promise<Record<string, t.ParsedServerConfig>> {
|
||||
return {
|
||||
...(await this.sharedAppServers.getAll()),
|
||||
...(await this.sharedUserServers.getAll()),
|
||||
...((await this.privateUserServers.get(userId)?.getAll()) ?? {}),
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: This is currently used to determine if a server requires OAuth. However, this info can
|
||||
// can be determined through config.requiresOAuth. Refactor usages and remove this method.
|
||||
public async getOAuthServers(userId?: string): Promise<Set<string>> {
|
||||
const allServers = await this.getAllServerConfigs(userId);
|
||||
const oauthServers = Object.entries(allServers).filter(([, config]) => config.requiresOAuth);
|
||||
return new Set(oauthServers.map(([name]) => name));
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
await this.sharedAppServers.reset();
|
||||
await this.sharedUserServers.reset();
|
||||
for (const cache of this.privateUserServers.values()) {
|
||||
await cache.reset();
|
||||
}
|
||||
this.privateUserServers.clear();
|
||||
}
|
||||
}
|
||||
|
||||
export const mcpServersRegistry = new MCPServersRegistry();
|
||||
|
|
@ -0,0 +1,338 @@
|
|||
import type { MCPConnection } from '~/mcp/connection';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { createMockConnection } from './mcpConnectionsMock.helper';
|
||||
|
||||
// Mock external dependencies
|
||||
jest.mock('../../oauth/detectOAuth');
|
||||
jest.mock('../../MCPConnectionFactory');
|
||||
|
||||
const mockDetectOAuthRequirement = detectOAuthRequirement as jest.MockedFunction<
|
||||
typeof detectOAuthRequirement
|
||||
>;
|
||||
|
||||
describe('MCPServerInspector', () => {
|
||||
let mockConnection: jest.Mocked<MCPConnection>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConnection = createMockConnection('test_server');
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('inspect()', () => {
|
||||
it('should process env and fetch all metadata for non-OAuth stdio server with serverInstructions=true', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: true,
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'instructions for test_server',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: {
|
||||
listFiles_mcp_test_server: expect.objectContaining({
|
||||
type: 'function',
|
||||
function: expect.objectContaining({
|
||||
name: 'listFiles_mcp_test_server',
|
||||
}),
|
||||
}),
|
||||
},
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should detect OAuth and skip capabilities fetch for streamable-http server', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: true,
|
||||
method: 'protected-resource-metadata',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
requiresOAuth: true,
|
||||
oauthMetadata: undefined,
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should skip capabilities fetch when startup=false', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
startup: false,
|
||||
};
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
startup: false,
|
||||
requiresOAuth: false,
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should keep custom serverInstructions string and not fetch from server', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'Custom instructions here',
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'Custom instructions here',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: expect.any(Object),
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle serverInstructions as string "true" and fetch from server', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'true', // String "true" from YAML
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'instructions for test_server',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: expect.any(Object),
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle predefined requiresOAuth without detection', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'sse',
|
||||
url: 'https://api.example.com/sse',
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'sse',
|
||||
url: 'https://api.example.com/sse',
|
||||
requiresOAuth: true,
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should fetch capabilities when server has no tools', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
// Mock server with no tools
|
||||
mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] });
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: '',
|
||||
toolFunctions: {},
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should create temporary connection when no connection is provided', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: true,
|
||||
};
|
||||
|
||||
const tempMockConnection = createMockConnection('test_server');
|
||||
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(tempMockConnection);
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.inspect('test_server', rawConfig);
|
||||
|
||||
// Verify factory was called to create connection
|
||||
expect(MCPConnectionFactory.create).toHaveBeenCalledWith({
|
||||
serverName: 'test_server',
|
||||
serverConfig: expect.objectContaining({ type: 'stdio', command: 'node' }),
|
||||
});
|
||||
|
||||
// Verify temporary connection was disconnected
|
||||
expect(tempMockConnection.disconnect).toHaveBeenCalled();
|
||||
|
||||
// Verify result is correct
|
||||
expect(result).toEqual({
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: 'instructions for test_server',
|
||||
requiresOAuth: false,
|
||||
capabilities:
|
||||
'{"tools":{"listChanged":true},"resources":{"listChanged":true},"prompts":{"get":"getPrompts for test_server"}}',
|
||||
tools: 'listFiles',
|
||||
toolFunctions: expect.any(Object),
|
||||
initDuration: expect.any(Number),
|
||||
});
|
||||
});
|
||||
|
||||
it('should not create temporary connection when connection is provided', async () => {
|
||||
const rawConfig: t.MCPOptions = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['server.js'],
|
||||
serverInstructions: true,
|
||||
};
|
||||
|
||||
mockDetectOAuthRequirement.mockResolvedValue({
|
||||
requiresOAuth: false,
|
||||
method: 'no-metadata-found',
|
||||
});
|
||||
|
||||
await MCPServerInspector.inspect('test_server', rawConfig, mockConnection);
|
||||
|
||||
// Verify factory was NOT called
|
||||
expect(MCPConnectionFactory.create).not.toHaveBeenCalled();
|
||||
|
||||
// Verify provided connection was NOT disconnected
|
||||
expect(mockConnection.disconnect).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getToolFunctions()', () => {
|
||||
it('should convert MCP tools to LibreChat tool functions format', async () => {
|
||||
mockConnection.client.listTools = jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'file_read',
|
||||
description: 'Read a file',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: { path: { type: 'string' } },
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'file_write',
|
||||
description: 'Write a file',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
path: { type: 'string' },
|
||||
content: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection);
|
||||
|
||||
expect(result).toEqual({
|
||||
file_read_mcp_my_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_my_server',
|
||||
description: 'Read a file',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: { path: { type: 'string' } },
|
||||
},
|
||||
},
|
||||
},
|
||||
file_write_mcp_my_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_write_mcp_my_server',
|
||||
description: 'Write a file',
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
path: { type: 'string' },
|
||||
content: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty tools list', async () => {
|
||||
mockConnection.client.listTools = jest.fn().mockResolvedValue({ tools: [] });
|
||||
|
||||
const result = await MCPServerInspector.getToolFunctions('my_server', mockConnection);
|
||||
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,301 @@
|
|||
import { expect } from '@playwright/test';
|
||||
import type * as t from '~/mcp/types';
|
||||
import type { MCPConnection } from '~/mcp/connection';
|
||||
|
||||
// Mock isLeader to always return true to avoid lock contention during parallel operations
|
||||
jest.mock('~/cluster', () => ({
|
||||
...jest.requireActual('~/cluster'),
|
||||
isLeader: jest.fn().mockResolvedValue(true),
|
||||
}));
|
||||
|
||||
describe('MCPServersInitializer Redis Integration Tests', () => {
|
||||
let MCPServersInitializer: typeof import('../MCPServersInitializer').MCPServersInitializer;
|
||||
let registry: typeof import('../MCPServersRegistry').mcpServersRegistry;
|
||||
let registryStatusCache: typeof import('../cache/RegistryStatusCache').registryStatusCache;
|
||||
let MCPServerInspector: typeof import('../MCPServerInspector').MCPServerInspector;
|
||||
let MCPConnectionFactory: typeof import('~/mcp/MCPConnectionFactory').MCPConnectionFactory;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
|
||||
|
||||
const testConfigs: t.MCPServers = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
},
|
||||
};
|
||||
|
||||
const testParsedConfigs: Record<string, t.ParsedServerConfig> = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
requiresOAuth: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
requiresOAuth: true,
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for search_tools_server',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
tools: 'search',
|
||||
toolFunctions: {
|
||||
search_mcp_search_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'search_mcp_search_tools_server',
|
||||
description: 'Search tool',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'MCPServersInitializer-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const initializerModule = await import('../MCPServersInitializer');
|
||||
const registryModule = await import('../MCPServersRegistry');
|
||||
const statusCacheModule = await import('../cache/RegistryStatusCache');
|
||||
const inspectorModule = await import('../MCPServerInspector');
|
||||
const connectionFactoryModule = await import('~/mcp/MCPConnectionFactory');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
|
||||
MCPServersInitializer = initializerModule.MCPServersInitializer;
|
||||
registry = registryModule.mcpServersRegistry;
|
||||
registryStatusCache = statusCacheModule.registryStatusCache;
|
||||
MCPServerInspector = inspectorModule.MCPServerInspector;
|
||||
MCPConnectionFactory = connectionFactoryModule.MCPConnectionFactory;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Become leader so we can perform write operations
|
||||
leaderInstance = new LeaderElection();
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Ensure we're still the leader
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
if (!isLeader) {
|
||||
throw new Error('Lost leader status before test');
|
||||
}
|
||||
|
||||
// Mock MCPServerInspector.inspect to return parsed config
|
||||
jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => {
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
// Mock MCPConnection
|
||||
const mockConnection = {
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
// Mock MCPConnectionFactory
|
||||
jest.spyOn(MCPConnectionFactory, 'create').mockResolvedValue(mockConnection);
|
||||
|
||||
// Reset caches before each test
|
||||
await registryStatusCache.reset();
|
||||
await registry.reset();
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: clear all test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*MCPServersInitializer-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Resign as leader
|
||||
if (leaderInstance) await leaderInstance.resign();
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('initialize()', () => {
|
||||
it('should reset registry and status cache before initialization', async () => {
|
||||
// Pre-populate registry with some old servers
|
||||
await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server);
|
||||
await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server);
|
||||
|
||||
// Initialize with new configs (this should reset first)
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify old servers are gone
|
||||
expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined();
|
||||
expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined();
|
||||
|
||||
// Verify new servers are present
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
|
||||
it('should skip initialization if already initialized', async () => {
|
||||
// First initialization
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Clear mock calls
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Second initialization should skip due to static flag
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify inspect was not called again
|
||||
expect(MCPServerInspector.inspect).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should add disabled servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
expect(disabledServer).toMatchObject({
|
||||
...testParsedConfigs.disabled_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add OAuth servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
expect(oauthServer).toMatchObject({
|
||||
...testParsedConfigs.oauth_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add enabled non-OAuth servers to sharedAppServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeDefined();
|
||||
expect(fileToolsServer).toMatchObject({
|
||||
...testParsedConfigs.file_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
expect(searchToolsServer).toMatchObject({
|
||||
...testParsedConfigs.search_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should successfully initialize all servers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify all servers were added to appropriate registries
|
||||
expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle inspection failures gracefully', async () => {
|
||||
// Mock inspection failure for one server
|
||||
jest.spyOn(MCPServerInspector, 'inspect').mockImplementation(async (serverName: string) => {
|
||||
if (serverName === 'file_tools_server') {
|
||||
throw new Error('Inspection failed');
|
||||
}
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify other servers were still processed
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
|
||||
// Verify file_tools_server was not added (due to inspection failure)
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should set initialized status after completion', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,292 @@
|
|||
import { logger } from '@librechat/data-schemas';
|
||||
import * as t from '~/mcp/types';
|
||||
import { MCPConnectionFactory } from '~/mcp/MCPConnectionFactory';
|
||||
import { MCPServersInitializer } from '~/mcp/registry/MCPServersInitializer';
|
||||
import { MCPConnection } from '~/mcp/connection';
|
||||
import { registryStatusCache } from '~/mcp/registry/cache/RegistryStatusCache';
|
||||
import { MCPServerInspector } from '~/mcp/registry/MCPServerInspector';
|
||||
import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
// Mock external dependencies
|
||||
jest.mock('../../MCPConnectionFactory');
|
||||
jest.mock('../../connection');
|
||||
jest.mock('../../registry/MCPServerInspector');
|
||||
jest.mock('~/cluster', () => ({
|
||||
isLeader: jest.fn().mockResolvedValue(true),
|
||||
}));
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockLogger = logger as jest.Mocked<typeof logger>;
|
||||
const mockInspect = MCPServerInspector.inspect as jest.MockedFunction<
|
||||
typeof MCPServerInspector.inspect
|
||||
>;
|
||||
|
||||
describe('MCPServersInitializer', () => {
|
||||
let mockConnection: jest.Mocked<MCPConnection>;
|
||||
|
||||
const testConfigs: t.MCPServers = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
},
|
||||
};
|
||||
|
||||
const testParsedConfigs: Record<string, t.ParsedServerConfig> = {
|
||||
disabled_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['disabled.js'],
|
||||
startup: false,
|
||||
requiresOAuth: false,
|
||||
},
|
||||
oauth_server: {
|
||||
type: 'streamable-http',
|
||||
url: 'https://api.example.com/mcp',
|
||||
requiresOAuth: true,
|
||||
},
|
||||
file_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
search_tools_server: {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['instructions.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for search_tools_server',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
tools: 'search',
|
||||
toolFunctions: {
|
||||
search_mcp_search_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'search_mcp_search_tools_server',
|
||||
description: 'Search tool',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
// Setup MCPConnection mock
|
||||
mockConnection = {
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
|
||||
// Setup MCPConnectionFactory mock
|
||||
(MCPConnectionFactory.create as jest.Mock).mockResolvedValue(mockConnection);
|
||||
|
||||
// Mock MCPServerInspector.inspect to return parsed config
|
||||
mockInspect.mockImplementation(async (serverName: string) => {
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
// Reset caches before each test
|
||||
await registryStatusCache.reset();
|
||||
await registry.sharedAppServers.reset();
|
||||
await registry.sharedUserServers.reset();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('initialize()', () => {
|
||||
it('should reset registry and status cache before initialization', async () => {
|
||||
// Pre-populate registry with some old servers
|
||||
await registry.sharedAppServers.add('old_app_server', testParsedConfigs.file_tools_server);
|
||||
await registry.sharedUserServers.add('old_user_server', testParsedConfigs.oauth_server);
|
||||
|
||||
// Initialize with new configs (this should reset first)
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify old servers are gone
|
||||
expect(await registry.sharedAppServers.get('old_app_server')).toBeUndefined();
|
||||
expect(await registry.sharedUserServers.get('old_user_server')).toBeUndefined();
|
||||
|
||||
// Verify new servers are present
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
|
||||
it('should skip initialization if already initialized (Redis flag)', async () => {
|
||||
// First initialization
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Second initialization should skip due to Redis cache flag
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(mockInspect).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should process all server configs through inspector', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify all configs were processed by inspector (without connection parameter)
|
||||
expect(mockInspect).toHaveBeenCalledTimes(4);
|
||||
expect(mockInspect).toHaveBeenCalledWith('disabled_server', testConfigs.disabled_server);
|
||||
expect(mockInspect).toHaveBeenCalledWith('oauth_server', testConfigs.oauth_server);
|
||||
expect(mockInspect).toHaveBeenCalledWith('file_tools_server', testConfigs.file_tools_server);
|
||||
expect(mockInspect).toHaveBeenCalledWith(
|
||||
'search_tools_server',
|
||||
testConfigs.search_tools_server,
|
||||
);
|
||||
});
|
||||
|
||||
it('should add disabled servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
expect(disabledServer).toMatchObject({
|
||||
...testParsedConfigs.disabled_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add OAuth servers to sharedUserServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
expect(oauthServer).toMatchObject({
|
||||
...testParsedConfigs.oauth_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should add enabled non-OAuth servers to sharedAppServers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeDefined();
|
||||
expect(fileToolsServer).toMatchObject({
|
||||
...testParsedConfigs.file_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
expect(searchToolsServer).toMatchObject({
|
||||
...testParsedConfigs.search_tools_server,
|
||||
_processedByInspector: true,
|
||||
});
|
||||
});
|
||||
|
||||
it('should successfully initialize all servers', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify all servers were added to appropriate registries
|
||||
expect(await registry.sharedUserServers.get('disabled_server')).toBeDefined();
|
||||
expect(await registry.sharedUserServers.get('oauth_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('file_tools_server')).toBeDefined();
|
||||
expect(await registry.sharedAppServers.get('search_tools_server')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should handle inspection failures gracefully', async () => {
|
||||
// Mock inspection failure for one server
|
||||
mockInspect.mockImplementation(async (serverName: string) => {
|
||||
if (serverName === 'file_tools_server') {
|
||||
throw new Error('Inspection failed');
|
||||
}
|
||||
return {
|
||||
...testParsedConfigs[serverName],
|
||||
_processedByInspector: true,
|
||||
} as unknown as t.ParsedServerConfig;
|
||||
});
|
||||
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify other servers were still processed
|
||||
const disabledServer = await registry.sharedUserServers.get('disabled_server');
|
||||
expect(disabledServer).toBeDefined();
|
||||
|
||||
const oauthServer = await registry.sharedUserServers.get('oauth_server');
|
||||
expect(oauthServer).toBeDefined();
|
||||
|
||||
const searchToolsServer = await registry.sharedAppServers.get('search_tools_server');
|
||||
expect(searchToolsServer).toBeDefined();
|
||||
|
||||
// Verify file_tools_server was not added (due to inspection failure)
|
||||
const fileToolsServer = await registry.sharedAppServers.get('file_tools_server');
|
||||
expect(fileToolsServer).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should log server configuration after initialization', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
// Verify logging occurred for each server
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][disabled_server]'),
|
||||
);
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(expect.stringContaining('[MCP][oauth_server]'));
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringContaining('[MCP][file_tools_server]'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use Promise.allSettled for parallel server initialization', async () => {
|
||||
const allSettledSpy = jest.spyOn(Promise, 'allSettled');
|
||||
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(allSettledSpy).toHaveBeenCalledWith(expect.arrayContaining([expect.any(Promise)]));
|
||||
expect(allSettledSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
allSettledSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should set initialized status after completion', async () => {
|
||||
await MCPServersInitializer.initialize(testConfigs);
|
||||
|
||||
expect(await registryStatusCache.isInitialized()).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,227 @@
|
|||
import { expect } from '@playwright/test';
|
||||
import type * as t from '~/mcp/types';
|
||||
|
||||
/**
|
||||
* Integration tests for MCPServersRegistry using Redis-backed cache.
|
||||
* For unit tests using in-memory cache, see MCPServersRegistry.test.ts
|
||||
*/
|
||||
describe('MCPServersRegistry Redis Integration Tests', () => {
|
||||
let registry: typeof import('../MCPServersRegistry').mcpServersRegistry;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
|
||||
|
||||
const testParsedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'MCPServersRegistry-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const registryModule = await import('../MCPServersRegistry');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
|
||||
registry = registryModule.mcpServersRegistry;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Become leader so we can perform write operations
|
||||
leaderInstance = new LeaderElection();
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: reset registry to clear all test data
|
||||
await registry.reset();
|
||||
|
||||
// Also clean up any remaining test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*MCPServersRegistry-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Resign as leader
|
||||
if (leaderInstance) await leaderInstance.resign();
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('private user servers', () => {
|
||||
it('should add and remove private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Verify server was added
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(testParsedConfig);
|
||||
|
||||
// Remove private user server
|
||||
await registry.removePrivateUserServer(userId, serverName);
|
||||
|
||||
// Verify server was removed
|
||||
const configAfterRemoval = await registry.getServerConfig(serverName, userId);
|
||||
expect(configAfterRemoval).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
await expect(
|
||||
registry.addPrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should update an existing private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
const updatedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'python',
|
||||
args: ['updated.py'],
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Update the server config
|
||||
await registry.updatePrivateUserServer(userId, serverName, updatedConfig);
|
||||
|
||||
// Verify server was updated
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(updatedConfig);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add a user cache first
|
||||
await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig);
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error when updating server for non-existent user', async () => {
|
||||
const userId = 'nonexistent_user';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow('No private servers found for user "nonexistent_user".');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllServerConfigs', () => {
|
||||
it('should return correct servers based on userId', async () => {
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig);
|
||||
|
||||
// Without userId: should return only shared app + shared user servers
|
||||
const configsNoUser = await registry.getAllServerConfigs();
|
||||
expect(Object.keys(configsNoUser)).toHaveLength(2);
|
||||
expect(configsNoUser).toHaveProperty('app_server');
|
||||
expect(configsNoUser).toHaveProperty('user_server');
|
||||
|
||||
// With userId 'abc': should return shared app + shared user + abc's private servers
|
||||
const configsAbc = await registry.getAllServerConfigs('abc');
|
||||
expect(Object.keys(configsAbc)).toHaveLength(3);
|
||||
expect(configsAbc).toHaveProperty('app_server');
|
||||
expect(configsAbc).toHaveProperty('user_server');
|
||||
expect(configsAbc).toHaveProperty('abc_private_server');
|
||||
|
||||
// With userId 'xyz': should return shared app + shared user + xyz's private servers
|
||||
const configsXyz = await registry.getAllServerConfigs('xyz');
|
||||
expect(Object.keys(configsXyz)).toHaveLength(3);
|
||||
expect(configsXyz).toHaveProperty('app_server');
|
||||
expect(configsXyz).toHaveProperty('user_server');
|
||||
expect(configsXyz).toHaveProperty('xyz_private_server');
|
||||
});
|
||||
});
|
||||
|
||||
describe('reset', () => {
|
||||
it('should clear all servers from all caches (shared app, shared user, and private user)', async () => {
|
||||
const userId = 'user123';
|
||||
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig);
|
||||
|
||||
// Verify all servers are accessible before reset
|
||||
const appConfigBefore = await registry.getServerConfig('app_server');
|
||||
const userConfigBefore = await registry.getServerConfig('user_server');
|
||||
const privateConfigBefore = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsBefore = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigBefore).toEqual(testParsedConfig);
|
||||
expect(userConfigBefore).toEqual(testParsedConfig);
|
||||
expect(privateConfigBefore).toEqual(testParsedConfig);
|
||||
expect(Object.keys(allConfigsBefore)).toHaveLength(3);
|
||||
|
||||
// Reset everything
|
||||
await registry.reset();
|
||||
|
||||
// Verify all servers are cleared after reset
|
||||
const appConfigAfter = await registry.getServerConfig('app_server');
|
||||
const userConfigAfter = await registry.getServerConfig('user_server');
|
||||
const privateConfigAfter = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsAfter = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigAfter).toBeUndefined();
|
||||
expect(userConfigAfter).toBeUndefined();
|
||||
expect(privateConfigAfter).toBeUndefined();
|
||||
expect(Object.keys(allConfigsAfter)).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
import * as t from '~/mcp/types';
|
||||
import { mcpServersRegistry as registry } from '~/mcp/registry/MCPServersRegistry';
|
||||
|
||||
/**
|
||||
* Unit tests for MCPServersRegistry using in-memory cache.
|
||||
* For integration tests using Redis-backed cache, see MCPServersRegistry.cache_integration.spec.ts
|
||||
*/
|
||||
describe('MCPServersRegistry', () => {
|
||||
const testParsedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'node',
|
||||
args: ['tools.js'],
|
||||
requiresOAuth: false,
|
||||
serverInstructions: 'Instructions for file_tools_server',
|
||||
tools: 'file_read, file_write',
|
||||
capabilities: '{"tools":{"listChanged":true}}',
|
||||
toolFunctions: {
|
||||
file_read_mcp_file_tools_server: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'file_read_mcp_file_tools_server',
|
||||
description: 'Read a file',
|
||||
parameters: { type: 'object' },
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(async () => {
|
||||
await registry.reset();
|
||||
});
|
||||
|
||||
describe('private user servers', () => {
|
||||
it('should add and remove private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Verify server was added
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(testParsedConfig);
|
||||
|
||||
// Remove private user server
|
||||
await registry.removePrivateUserServer(userId, serverName);
|
||||
|
||||
// Verify server was removed
|
||||
const configAfterRemoval = await registry.getServerConfig(serverName, userId);
|
||||
expect(configAfterRemoval).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
await expect(
|
||||
registry.addPrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should update an existing private user server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
const updatedConfig: t.ParsedServerConfig = {
|
||||
type: 'stdio',
|
||||
command: 'python',
|
||||
args: ['updated.py'],
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
// Add private user server
|
||||
await registry.addPrivateUserServer(userId, serverName, testParsedConfig);
|
||||
|
||||
// Update the server config
|
||||
await registry.updatePrivateUserServer(userId, serverName, updatedConfig);
|
||||
|
||||
// Verify server was updated
|
||||
const retrievedConfig = await registry.getServerConfig(serverName, userId);
|
||||
expect(retrievedConfig).toEqual(updatedConfig);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
const userId = 'user123';
|
||||
const serverName = 'private_server';
|
||||
|
||||
// Add a user cache first
|
||||
await registry.addPrivateUserServer(userId, 'other_server', testParsedConfig);
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow(
|
||||
'Server "private_server" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error when updating server for non-existent user', async () => {
|
||||
const userId = 'nonexistent_user';
|
||||
const serverName = 'private_server';
|
||||
|
||||
await expect(
|
||||
registry.updatePrivateUserServer(userId, serverName, testParsedConfig),
|
||||
).rejects.toThrow('No private servers found for user "nonexistent_user".');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAllServerConfigs', () => {
|
||||
it('should return correct servers based on userId', async () => {
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('abc', 'abc_private_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer('xyz', 'xyz_private_server', testParsedConfig);
|
||||
|
||||
// Without userId: should return only shared app + shared user servers
|
||||
const configsNoUser = await registry.getAllServerConfigs();
|
||||
expect(Object.keys(configsNoUser)).toHaveLength(2);
|
||||
expect(configsNoUser).toHaveProperty('app_server');
|
||||
expect(configsNoUser).toHaveProperty('user_server');
|
||||
|
||||
// With userId 'abc': should return shared app + shared user + abc's private servers
|
||||
const configsAbc = await registry.getAllServerConfigs('abc');
|
||||
expect(Object.keys(configsAbc)).toHaveLength(3);
|
||||
expect(configsAbc).toHaveProperty('app_server');
|
||||
expect(configsAbc).toHaveProperty('user_server');
|
||||
expect(configsAbc).toHaveProperty('abc_private_server');
|
||||
|
||||
// With userId 'xyz': should return shared app + shared user + xyz's private servers
|
||||
const configsXyz = await registry.getAllServerConfigs('xyz');
|
||||
expect(Object.keys(configsXyz)).toHaveLength(3);
|
||||
expect(configsXyz).toHaveProperty('app_server');
|
||||
expect(configsXyz).toHaveProperty('user_server');
|
||||
expect(configsXyz).toHaveProperty('xyz_private_server');
|
||||
});
|
||||
});
|
||||
|
||||
describe('reset', () => {
|
||||
it('should clear all servers from all caches (shared app, shared user, and private user)', async () => {
|
||||
const userId = 'user123';
|
||||
|
||||
// Add servers to all three caches
|
||||
await registry.sharedAppServers.add('app_server', testParsedConfig);
|
||||
await registry.sharedUserServers.add('user_server', testParsedConfig);
|
||||
await registry.addPrivateUserServer(userId, 'private_server', testParsedConfig);
|
||||
|
||||
// Verify all servers are accessible before reset
|
||||
const appConfigBefore = await registry.getServerConfig('app_server');
|
||||
const userConfigBefore = await registry.getServerConfig('user_server');
|
||||
const privateConfigBefore = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsBefore = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigBefore).toEqual(testParsedConfig);
|
||||
expect(userConfigBefore).toEqual(testParsedConfig);
|
||||
expect(privateConfigBefore).toEqual(testParsedConfig);
|
||||
expect(Object.keys(allConfigsBefore)).toHaveLength(3);
|
||||
|
||||
// Reset everything
|
||||
await registry.reset();
|
||||
|
||||
// Verify all servers are cleared after reset
|
||||
const appConfigAfter = await registry.getServerConfig('app_server');
|
||||
const userConfigAfter = await registry.getServerConfig('user_server');
|
||||
const privateConfigAfter = await registry.getServerConfig('private_server', userId);
|
||||
const allConfigsAfter = await registry.getAllServerConfigs(userId);
|
||||
|
||||
expect(appConfigAfter).toBeUndefined();
|
||||
expect(userConfigAfter).toBeUndefined();
|
||||
expect(privateConfigAfter).toBeUndefined();
|
||||
expect(Object.keys(allConfigsAfter)).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
import type { MCPConnection } from '~/mcp/connection';
|
||||
|
||||
/**
|
||||
* Creates a single mock MCP connection for testing.
|
||||
* The connection has a client with mocked methods that return server-specific data.
|
||||
* @param serverName - Name of the server to create mock connection for
|
||||
* @returns Mocked MCPConnection instance
|
||||
*/
|
||||
export function createMockConnection(serverName: string): jest.Mocked<MCPConnection> {
|
||||
const mockClient = {
|
||||
getInstructions: jest.fn().mockReturnValue(`instructions for ${serverName}`),
|
||||
getServerCapabilities: jest.fn().mockReturnValue({
|
||||
tools: { listChanged: true },
|
||||
resources: { listChanged: true },
|
||||
prompts: { get: `getPrompts for ${serverName}` },
|
||||
}),
|
||||
listTools: jest.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'listFiles',
|
||||
description: `Description for ${serverName}'s listFiles tool`,
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
};
|
||||
|
||||
return {
|
||||
client: mockClient,
|
||||
disconnect: jest.fn().mockResolvedValue(undefined),
|
||||
} as unknown as jest.Mocked<MCPConnection>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates mock MCP connections for testing.
|
||||
* Each connection has a client with mocked methods that return server-specific data.
|
||||
* @param serverNames - Array of server names to create mock connections for
|
||||
* @returns Map of server names to mocked MCPConnection instances
|
||||
*/
|
||||
export function createMockConnectionsMap(
|
||||
serverNames: string[],
|
||||
): Map<string, jest.Mocked<MCPConnection>> {
|
||||
const mockConnections = new Map<string, jest.Mocked<MCPConnection>>();
|
||||
|
||||
serverNames.forEach((serverName) => {
|
||||
mockConnections.set(serverName, createMockConnection(serverName));
|
||||
});
|
||||
|
||||
return mockConnections;
|
||||
}
|
||||
26
packages/api/src/mcp/registry/cache/BaseRegistryCache.ts
vendored
Normal file
26
packages/api/src/mcp/registry/cache/BaseRegistryCache.ts
vendored
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import type Keyv from 'keyv';
|
||||
import { isLeader } from '~/cluster';
|
||||
|
||||
/**
|
||||
* Base class for MCP registry caches that require distributed leader coordination.
|
||||
* Provides helper methods for leader-only operations and success validation.
|
||||
* All concrete implementations must provide their own Keyv cache instance.
|
||||
*/
|
||||
export abstract class BaseRegistryCache {
|
||||
protected readonly PREFIX = 'MCP::ServersRegistry';
|
||||
protected abstract readonly cache: Keyv;
|
||||
|
||||
protected async leaderCheck(action: string): Promise<void> {
|
||||
if (!(await isLeader())) throw new Error(`Only leader can ${action}.`);
|
||||
}
|
||||
|
||||
protected successCheck(action: string, success: boolean): true {
|
||||
if (!success) throw new Error(`Failed to ${action} in cache.`);
|
||||
return true;
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
await this.leaderCheck(`reset ${this.cache.namespace} cache`);
|
||||
await this.cache.clear();
|
||||
}
|
||||
}
|
||||
37
packages/api/src/mcp/registry/cache/RegistryStatusCache.ts
vendored
Normal file
37
packages/api/src/mcp/registry/cache/RegistryStatusCache.ts
vendored
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import { standardCache } from '~/cache';
|
||||
import { BaseRegistryCache } from './BaseRegistryCache';
|
||||
|
||||
// Status keys
|
||||
const INITIALIZED = 'INITIALIZED';
|
||||
|
||||
/**
|
||||
* Cache for tracking MCP Servers Registry metadata and status across distributed instances.
|
||||
* Uses Redis-backed storage to coordinate state between leader and follower nodes.
|
||||
* Currently, tracks initialization status to ensure only the leader performs initialization
|
||||
* while followers wait for completion. Designed to be extended with additional registry
|
||||
* metadata as needed (e.g., last update timestamps, version info, health status).
|
||||
* This cache is only meant to be used internally by registry management components.
|
||||
*/
|
||||
class RegistryStatusCache extends BaseRegistryCache {
|
||||
protected readonly cache = standardCache(`${this.PREFIX}::Status`);
|
||||
|
||||
public async isInitialized(): Promise<boolean> {
|
||||
return (await this.get(INITIALIZED)) === true;
|
||||
}
|
||||
|
||||
public async setInitialized(value: boolean): Promise<void> {
|
||||
await this.set(INITIALIZED, value);
|
||||
}
|
||||
|
||||
private async get<T = unknown>(key: string): Promise<T | undefined> {
|
||||
return this.cache.get(key);
|
||||
}
|
||||
|
||||
private async set(key: string, value: string | number | boolean, ttl?: number): Promise<void> {
|
||||
await this.leaderCheck('set MCP Servers Registry status');
|
||||
const success = await this.cache.set(key, value, ttl);
|
||||
this.successCheck(`set status key "${key}"`, success);
|
||||
}
|
||||
}
|
||||
|
||||
export const registryStatusCache = new RegistryStatusCache();
|
||||
31
packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts
vendored
Normal file
31
packages/api/src/mcp/registry/cache/ServerConfigsCacheFactory.ts
vendored
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
import { cacheConfig } from '~/cache';
|
||||
import { ServerConfigsCacheInMemory } from './ServerConfigsCacheInMemory';
|
||||
import { ServerConfigsCacheRedis } from './ServerConfigsCacheRedis';
|
||||
|
||||
export type ServerConfigsCache = ServerConfigsCacheInMemory | ServerConfigsCacheRedis;
|
||||
|
||||
/**
|
||||
* Factory for creating the appropriate ServerConfigsCache implementation based on deployment mode.
|
||||
* Automatically selects between in-memory and Redis-backed storage depending on USE_REDIS config.
|
||||
* In single-instance mode (USE_REDIS=false), returns lightweight in-memory cache.
|
||||
* In cluster mode (USE_REDIS=true), returns Redis-backed cache with distributed coordination.
|
||||
* Provides a unified interface regardless of the underlying storage mechanism.
|
||||
*/
|
||||
export class ServerConfigsCacheFactory {
|
||||
/**
|
||||
* Create a ServerConfigsCache instance.
|
||||
* Returns Redis implementation if Redis is configured, otherwise in-memory implementation.
|
||||
*
|
||||
* @param owner - The owner of the cache (e.g., 'user', 'global') - only used for Redis namespacing
|
||||
* @param leaderOnly - Whether operations should only be performed by the leader (only applies to Redis)
|
||||
* @returns ServerConfigsCache instance
|
||||
*/
|
||||
static create(owner: string, leaderOnly: boolean): ServerConfigsCache {
|
||||
if (cacheConfig.USE_REDIS) {
|
||||
return new ServerConfigsCacheRedis(owner, leaderOnly);
|
||||
}
|
||||
|
||||
// In-memory mode uses a simple Map - doesn't need owner/namespace
|
||||
return new ServerConfigsCacheInMemory();
|
||||
}
|
||||
}
|
||||
46
packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts
vendored
Normal file
46
packages/api/src/mcp/registry/cache/ServerConfigsCacheInMemory.ts
vendored
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
|
||||
/**
|
||||
* In-memory implementation of MCP server configurations cache for single-instance deployments.
|
||||
* Uses a native JavaScript Map for fast, local storage without Redis dependencies.
|
||||
* Suitable for development environments or single-server production deployments.
|
||||
* Does not require leader checks or distributed coordination since data is instance-local.
|
||||
* Data is lost on server restart and not shared across multiple server instances.
|
||||
*/
|
||||
export class ServerConfigsCacheInMemory {
|
||||
private readonly cache: Map<string, ParsedServerConfig> = new Map();
|
||||
|
||||
public async add(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (this.cache.has(serverName))
|
||||
throw new Error(
|
||||
`Server "${serverName}" already exists in cache. Use update() to modify existing configs.`,
|
||||
);
|
||||
this.cache.set(serverName, config);
|
||||
}
|
||||
|
||||
public async update(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (!this.cache.has(serverName))
|
||||
throw new Error(
|
||||
`Server "${serverName}" does not exist in cache. Use add() to create new configs.`,
|
||||
);
|
||||
this.cache.set(serverName, config);
|
||||
}
|
||||
|
||||
public async remove(serverName: string): Promise<void> {
|
||||
if (!this.cache.delete(serverName)) {
|
||||
throw new Error(`Failed to remove server "${serverName}" in cache.`);
|
||||
}
|
||||
}
|
||||
|
||||
public async get(serverName: string): Promise<ParsedServerConfig | undefined> {
|
||||
return this.cache.get(serverName);
|
||||
}
|
||||
|
||||
public async getAll(): Promise<Record<string, ParsedServerConfig>> {
|
||||
return Object.fromEntries(this.cache);
|
||||
}
|
||||
|
||||
public async reset(): Promise<void> {
|
||||
this.cache.clear();
|
||||
}
|
||||
}
|
||||
80
packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts
vendored
Normal file
80
packages/api/src/mcp/registry/cache/ServerConfigsCacheRedis.ts
vendored
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import type Keyv from 'keyv';
|
||||
import { fromPairs } from 'lodash';
|
||||
import { standardCache, keyvRedisClient } from '~/cache';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
import { BaseRegistryCache } from './BaseRegistryCache';
|
||||
|
||||
/**
|
||||
* Redis-backed implementation of MCP server configurations cache for distributed deployments.
|
||||
* Stores server configs in Redis with namespace isolation by owner (App, User, or specific user ID).
|
||||
* Enables data sharing across multiple server instances in a cluster environment.
|
||||
* Supports optional leader-only write operations to prevent race conditions during initialization.
|
||||
* Data persists across server restarts and is accessible from any instance in the cluster.
|
||||
*/
|
||||
export class ServerConfigsCacheRedis extends BaseRegistryCache {
|
||||
protected readonly cache: Keyv;
|
||||
private readonly owner: string;
|
||||
private readonly leaderOnly: boolean;
|
||||
|
||||
constructor(owner: string, leaderOnly: boolean) {
|
||||
super();
|
||||
this.owner = owner;
|
||||
this.leaderOnly = leaderOnly;
|
||||
this.cache = standardCache(`${this.PREFIX}::Servers::${owner}`);
|
||||
}
|
||||
|
||||
public async add(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (this.leaderOnly) await this.leaderCheck(`add ${this.owner} MCP servers`);
|
||||
const exists = await this.cache.has(serverName);
|
||||
if (exists)
|
||||
throw new Error(
|
||||
`Server "${serverName}" already exists in cache. Use update() to modify existing configs.`,
|
||||
);
|
||||
const success = await this.cache.set(serverName, config);
|
||||
this.successCheck(`add ${this.owner} server "${serverName}"`, success);
|
||||
}
|
||||
|
||||
public async update(serverName: string, config: ParsedServerConfig): Promise<void> {
|
||||
if (this.leaderOnly) await this.leaderCheck(`update ${this.owner} MCP servers`);
|
||||
const exists = await this.cache.has(serverName);
|
||||
if (!exists)
|
||||
throw new Error(
|
||||
`Server "${serverName}" does not exist in cache. Use add() to create new configs.`,
|
||||
);
|
||||
const success = await this.cache.set(serverName, config);
|
||||
this.successCheck(`update ${this.owner} server "${serverName}"`, success);
|
||||
}
|
||||
|
||||
public async remove(serverName: string): Promise<void> {
|
||||
if (this.leaderOnly) await this.leaderCheck(`remove ${this.owner} MCP servers`);
|
||||
const success = await this.cache.delete(serverName);
|
||||
this.successCheck(`remove ${this.owner} server "${serverName}"`, success);
|
||||
}
|
||||
|
||||
public async get(serverName: string): Promise<ParsedServerConfig | undefined> {
|
||||
return this.cache.get(serverName);
|
||||
}
|
||||
|
||||
public async getAll(): Promise<Record<string, ParsedServerConfig>> {
|
||||
// Use Redis SCAN iterator directly (non-blocking, production-ready)
|
||||
// Note: Keyv uses a single colon ':' between namespace and key, even if GLOBAL_PREFIX_SEPARATOR is '::'
|
||||
const pattern = `*${this.cache.namespace}:*`;
|
||||
const entries: Array<[string, ParsedServerConfig]> = [];
|
||||
|
||||
// Use scanIterator from Redis client
|
||||
if (keyvRedisClient && 'scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
// Extract the actual key name (last part after final colon)
|
||||
// Full key format: "prefix::namespace:keyName"
|
||||
const lastColonIndex = key.lastIndexOf(':');
|
||||
const keyName = key.substring(lastColonIndex + 1);
|
||||
const value = await this.cache.get(keyName);
|
||||
if (value) {
|
||||
entries.push([keyName, value as ParsedServerConfig]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fromPairs(entries);
|
||||
}
|
||||
}
|
||||
73
packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts
vendored
Normal file
73
packages/api/src/mcp/registry/cache/__tests__/RegistryStatusCache.cache_integration.spec.ts
vendored
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
import { expect } from '@playwright/test';
|
||||
|
||||
describe('RegistryStatusCache Integration Tests', () => {
|
||||
let registryStatusCache: typeof import('../RegistryStatusCache').registryStatusCache;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let leaderInstance: InstanceType<typeof import('~/cluster/LeaderElection').LeaderElection>;
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'RegistryStatusCache-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const statusCacheModule = await import('../RegistryStatusCache');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
|
||||
registryStatusCache = statusCacheModule.registryStatusCache;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Become leader so we can perform write operations
|
||||
leaderInstance = new LeaderElection();
|
||||
const isLeader = await leaderInstance.isLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: clear all test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*RegistryStatusCache-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Resign as leader
|
||||
if (leaderInstance) await leaderInstance.resign();
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('Initialization status tracking', () => {
|
||||
it('should return false for isInitialized when not set', async () => {
|
||||
const initialized = await registryStatusCache.isInitialized();
|
||||
expect(initialized).toBe(false);
|
||||
});
|
||||
|
||||
it('should set and get initialized status', async () => {
|
||||
await registryStatusCache.setInitialized(true);
|
||||
const initialized = await registryStatusCache.isInitialized();
|
||||
expect(initialized).toBe(true);
|
||||
|
||||
await registryStatusCache.setInitialized(false);
|
||||
const uninitialized = await registryStatusCache.isInitialized();
|
||||
expect(uninitialized).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
70
packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts
vendored
Normal file
70
packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheFactory.test.ts
vendored
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
import { ServerConfigsCacheFactory } from '../ServerConfigsCacheFactory';
|
||||
import { ServerConfigsCacheInMemory } from '../ServerConfigsCacheInMemory';
|
||||
import { ServerConfigsCacheRedis } from '../ServerConfigsCacheRedis';
|
||||
import { cacheConfig } from '~/cache';
|
||||
|
||||
// Mock the cache implementations
|
||||
jest.mock('../ServerConfigsCacheInMemory');
|
||||
jest.mock('../ServerConfigsCacheRedis');
|
||||
|
||||
// Mock the cache config module
|
||||
jest.mock('~/cache', () => ({
|
||||
cacheConfig: {
|
||||
USE_REDIS: false,
|
||||
},
|
||||
}));
|
||||
|
||||
describe('ServerConfigsCacheFactory', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('create()', () => {
|
||||
it('should return ServerConfigsCacheRedis when USE_REDIS is true', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = true;
|
||||
|
||||
// Act
|
||||
const cache = ServerConfigsCacheFactory.create('TestOwner', true);
|
||||
|
||||
// Assert
|
||||
expect(cache).toBeInstanceOf(ServerConfigsCacheRedis);
|
||||
expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('TestOwner', true);
|
||||
});
|
||||
|
||||
it('should return ServerConfigsCacheInMemory when USE_REDIS is false', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = false;
|
||||
|
||||
// Act
|
||||
const cache = ServerConfigsCacheFactory.create('TestOwner', false);
|
||||
|
||||
// Assert
|
||||
expect(cache).toBeInstanceOf(ServerConfigsCacheInMemory);
|
||||
expect(ServerConfigsCacheInMemory).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should pass correct parameters to ServerConfigsCacheRedis', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = true;
|
||||
|
||||
// Act
|
||||
ServerConfigsCacheFactory.create('App', true);
|
||||
|
||||
// Assert
|
||||
expect(ServerConfigsCacheRedis).toHaveBeenCalledWith('App', true);
|
||||
});
|
||||
|
||||
it('should create ServerConfigsCacheInMemory without parameters when USE_REDIS is false', () => {
|
||||
// Arrange
|
||||
cacheConfig.USE_REDIS = false;
|
||||
|
||||
// Act
|
||||
ServerConfigsCacheFactory.create('User', false);
|
||||
|
||||
// Assert
|
||||
// In-memory cache doesn't use owner/leaderOnly parameters
|
||||
expect(ServerConfigsCacheInMemory).toHaveBeenCalledWith();
|
||||
});
|
||||
});
|
||||
});
|
||||
173
packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts
vendored
Normal file
173
packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheInMemory.test.ts
vendored
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
import { expect } from '@playwright/test';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
|
||||
describe('ServerConfigsCacheInMemory Integration Tests', () => {
|
||||
let ServerConfigsCacheInMemory: typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory;
|
||||
let cache: InstanceType<
|
||||
typeof import('../ServerConfigsCacheInMemory').ServerConfigsCacheInMemory
|
||||
>;
|
||||
|
||||
// Test data
|
||||
const mockConfig1: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server1.js'],
|
||||
env: { TEST: 'value1' },
|
||||
};
|
||||
|
||||
const mockConfig2: ParsedServerConfig = {
|
||||
command: 'python',
|
||||
args: ['server2.py'],
|
||||
env: { TEST: 'value2' },
|
||||
};
|
||||
|
||||
const mockConfig3: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server3.js'],
|
||||
url: 'http://localhost:3000',
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Import modules
|
||||
const cacheModule = await import('../ServerConfigsCacheInMemory');
|
||||
ServerConfigsCacheInMemory = cacheModule.ServerConfigsCacheInMemory;
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a fresh instance for each test
|
||||
cache = new ServerConfigsCacheInMemory();
|
||||
});
|
||||
|
||||
describe('add and get operations', () => {
|
||||
it('should add and retrieve a server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig1);
|
||||
});
|
||||
|
||||
it('should return undefined for non-existent server', async () => {
|
||||
const result = await cache.get('non-existent');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await expect(cache.add('server1', mockConfig2)).rejects.toThrow(
|
||||
'Server "server1" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result1 = await cache.get('server1');
|
||||
const result2 = await cache.get('server2');
|
||||
const result3 = await cache.get('server3');
|
||||
|
||||
expect(result1).toEqual(mockConfig1);
|
||||
expect(result2).toEqual(mockConfig2);
|
||||
expect(result3).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAll operation', () => {
|
||||
it('should return empty object when no servers exist', async () => {
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('should return all server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({
|
||||
server1: mockConfig1,
|
||||
server2: mockConfig2,
|
||||
server3: mockConfig3,
|
||||
});
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.add('server3', mockConfig3);
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(3);
|
||||
expect(result.server3).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('update operation', () => {
|
||||
it('should update an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.update('server1', mockConfig2);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow(
|
||||
'Server "non-existent" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
await cache.update('server1', mockConfig3);
|
||||
const result = await cache.getAll();
|
||||
expect(result.server1).toEqual(mockConfig3);
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('remove operation', () => {
|
||||
it('should remove an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.remove('server1');
|
||||
expect(await cache.get('server1')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when removing non-existent server', async () => {
|
||||
await expect(cache.remove('non-existent')).rejects.toThrow(
|
||||
'Failed to remove server "non-existent" in cache.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove server from getAll results', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.remove('server1');
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(1);
|
||||
expect(result.server1).toBeUndefined();
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should allow re-adding a removed server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.remove('server1');
|
||||
await cache.add('server1', mockConfig3);
|
||||
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
});
|
||||
278
packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts
vendored
Normal file
278
packages/api/src/mcp/registry/cache/__tests__/ServerConfigsCacheRedis.cache_integration.spec.ts
vendored
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
import { expect } from '@playwright/test';
|
||||
import { ParsedServerConfig } from '~/mcp/types';
|
||||
|
||||
describe('ServerConfigsCacheRedis Integration Tests', () => {
|
||||
let ServerConfigsCacheRedis: typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis;
|
||||
let keyvRedisClient: Awaited<typeof import('~/cache/redisClients')>['keyvRedisClient'];
|
||||
let LeaderElection: typeof import('~/cluster/LeaderElection').LeaderElection;
|
||||
let checkIsLeader: () => Promise<boolean>;
|
||||
let cache: InstanceType<typeof import('../ServerConfigsCacheRedis').ServerConfigsCacheRedis>;
|
||||
|
||||
// Test data
|
||||
const mockConfig1: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server1.js'],
|
||||
env: { TEST: 'value1' },
|
||||
};
|
||||
|
||||
const mockConfig2: ParsedServerConfig = {
|
||||
command: 'python',
|
||||
args: ['server2.py'],
|
||||
env: { TEST: 'value2' },
|
||||
};
|
||||
|
||||
const mockConfig3: ParsedServerConfig = {
|
||||
command: 'node',
|
||||
args: ['server3.js'],
|
||||
url: 'http://localhost:3000',
|
||||
requiresOAuth: true,
|
||||
};
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up environment variables for Redis (only if not already set)
|
||||
process.env.USE_REDIS = process.env.USE_REDIS ?? 'true';
|
||||
process.env.REDIS_URI = process.env.REDIS_URI ?? 'redis://127.0.0.1:6379';
|
||||
process.env.REDIS_KEY_PREFIX =
|
||||
process.env.REDIS_KEY_PREFIX ?? 'ServerConfigsCacheRedis-IntegrationTest';
|
||||
|
||||
// Import modules after setting env vars
|
||||
const cacheModule = await import('../ServerConfigsCacheRedis');
|
||||
const redisClients = await import('~/cache/redisClients');
|
||||
const leaderElectionModule = await import('~/cluster/LeaderElection');
|
||||
const clusterModule = await import('~/cluster');
|
||||
|
||||
ServerConfigsCacheRedis = cacheModule.ServerConfigsCacheRedis;
|
||||
keyvRedisClient = redisClients.keyvRedisClient;
|
||||
LeaderElection = leaderElectionModule.LeaderElection;
|
||||
checkIsLeader = clusterModule.isLeader;
|
||||
|
||||
// Ensure Redis is connected
|
||||
if (!keyvRedisClient) throw new Error('Redis client is not initialized');
|
||||
|
||||
// Wait for Redis to be ready
|
||||
if (!keyvRedisClient.isOpen) await keyvRedisClient.connect();
|
||||
|
||||
// Clear any existing leader key to ensure clean state
|
||||
await keyvRedisClient.del(LeaderElection.LEADER_KEY);
|
||||
|
||||
// Become leader so we can perform write operations (using default election instance)
|
||||
const isLeader = await checkIsLeader();
|
||||
expect(isLeader).toBe(true);
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
// Create a fresh instance for each test with leaderOnly=true
|
||||
cache = new ServerConfigsCacheRedis('test-user', true);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// Clean up: clear all test keys from Redis
|
||||
if (keyvRedisClient) {
|
||||
const pattern = '*ServerConfigsCacheRedis-IntegrationTest*';
|
||||
if ('scanIterator' in keyvRedisClient) {
|
||||
for await (const key of keyvRedisClient.scanIterator({ MATCH: pattern })) {
|
||||
await keyvRedisClient.del(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Clear leader key to allow other tests to become leader
|
||||
if (keyvRedisClient) await keyvRedisClient.del(LeaderElection.LEADER_KEY);
|
||||
|
||||
// Close Redis connection
|
||||
if (keyvRedisClient?.isOpen) await keyvRedisClient.disconnect();
|
||||
});
|
||||
|
||||
describe('add and get operations', () => {
|
||||
it('should add and retrieve a server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig1);
|
||||
});
|
||||
|
||||
it('should return undefined for non-existent server', async () => {
|
||||
const result = await cache.get('non-existent');
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when adding duplicate server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await expect(cache.add('server1', mockConfig2)).rejects.toThrow(
|
||||
'Server "server1" already exists in cache. Use update() to modify existing configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multiple server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result1 = await cache.get('server1');
|
||||
const result2 = await cache.get('server2');
|
||||
const result3 = await cache.get('server3');
|
||||
|
||||
expect(result1).toEqual(mockConfig1);
|
||||
expect(result2).toEqual(mockConfig2);
|
||||
expect(result3).toEqual(mockConfig3);
|
||||
});
|
||||
|
||||
it('should isolate caches by owner namespace', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await globalCache.add('server1', mockConfig2);
|
||||
|
||||
const userResult = await userCache.get('server1');
|
||||
const globalResult = await globalCache.get('server1');
|
||||
|
||||
expect(userResult).toEqual(mockConfig1);
|
||||
expect(globalResult).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAll operation', () => {
|
||||
it('should return empty object when no servers exist', async () => {
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({});
|
||||
});
|
||||
|
||||
it('should return all server configs', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
await cache.add('server3', mockConfig3);
|
||||
|
||||
const result = await cache.getAll();
|
||||
expect(result).toEqual({
|
||||
server1: mockConfig1,
|
||||
server2: mockConfig2,
|
||||
server3: mockConfig3,
|
||||
});
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.add('server3', mockConfig3);
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(3);
|
||||
expect(result.server3).toEqual(mockConfig3);
|
||||
});
|
||||
|
||||
it('should only return configs for the specific owner', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await userCache.add('server2', mockConfig2);
|
||||
await globalCache.add('server3', mockConfig3);
|
||||
|
||||
const userResult = await userCache.getAll();
|
||||
const globalResult = await globalCache.getAll();
|
||||
|
||||
expect(Object.keys(userResult).length).toBe(2);
|
||||
expect(Object.keys(globalResult).length).toBe(1);
|
||||
expect(userResult.server1).toEqual(mockConfig1);
|
||||
expect(userResult.server3).toBeUndefined();
|
||||
expect(globalResult.server3).toEqual(mockConfig3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('update operation', () => {
|
||||
it('should update an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.update('server1', mockConfig2);
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should throw error when updating non-existent server', async () => {
|
||||
await expect(cache.update('non-existent', mockConfig1)).rejects.toThrow(
|
||||
'Server "non-existent" does not exist in cache. Use add() to create new configs.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should reflect updates in getAll', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
await cache.update('server1', mockConfig3);
|
||||
const result = await cache.getAll();
|
||||
expect(result.server1).toEqual(mockConfig3);
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should only update in the specific owner namespace', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await globalCache.add('server1', mockConfig2);
|
||||
|
||||
await userCache.update('server1', mockConfig3);
|
||||
|
||||
expect(await userCache.get('server1')).toEqual(mockConfig3);
|
||||
expect(await globalCache.get('server1')).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('remove operation', () => {
|
||||
it('should remove an existing server config', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
expect(await cache.get('server1')).toEqual(mockConfig1);
|
||||
|
||||
await cache.remove('server1');
|
||||
expect(await cache.get('server1')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when removing non-existent server', async () => {
|
||||
await expect(cache.remove('non-existent')).rejects.toThrow(
|
||||
'Failed to remove test-user server "non-existent"',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove server from getAll results', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.add('server2', mockConfig2);
|
||||
|
||||
let result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(2);
|
||||
|
||||
await cache.remove('server1');
|
||||
result = await cache.getAll();
|
||||
expect(Object.keys(result).length).toBe(1);
|
||||
expect(result.server1).toBeUndefined();
|
||||
expect(result.server2).toEqual(mockConfig2);
|
||||
});
|
||||
|
||||
it('should allow re-adding a removed server', async () => {
|
||||
await cache.add('server1', mockConfig1);
|
||||
await cache.remove('server1');
|
||||
await cache.add('server1', mockConfig3);
|
||||
|
||||
const result = await cache.get('server1');
|
||||
expect(result).toEqual(mockConfig3);
|
||||
});
|
||||
|
||||
it('should only remove from the specific owner namespace', async () => {
|
||||
const userCache = new ServerConfigsCacheRedis('user1', true);
|
||||
const globalCache = new ServerConfigsCacheRedis('global', true);
|
||||
|
||||
await userCache.add('server1', mockConfig1);
|
||||
await globalCache.add('server1', mockConfig2);
|
||||
|
||||
await userCache.remove('server1');
|
||||
|
||||
expect(await userCache.get('server1')).toBeUndefined();
|
||||
expect(await globalCache.get('server1')).toEqual(mockConfig2);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -151,6 +151,8 @@ export type ParsedServerConfig = MCPOptions & {
|
|||
oauthMetadata?: Record<string, unknown> | null;
|
||||
capabilities?: string;
|
||||
tools?: string;
|
||||
toolFunctions?: LCAvailableTools;
|
||||
initDuration?: number;
|
||||
};
|
||||
|
||||
export interface BasicConnectionOptions {
|
||||
|
|
|
|||
|
|
@ -54,6 +54,12 @@ export interface AnthropicConfigOptions {
|
|||
proxy?: string | null;
|
||||
/** URL for a reverse proxy, if used */
|
||||
reverseProxyUrl?: string | null;
|
||||
/** Default parameters to apply only if fields are undefined */
|
||||
defaultParams?: Record<string, unknown>;
|
||||
/** Additional parameters to add to the configuration */
|
||||
addParams?: Record<string, unknown>;
|
||||
/** Parameters to drop/exclude from the configuration */
|
||||
dropParams?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import type { IMongoFile } from '@librechat/data-schemas';
|
||||
import type { ServerRequest } from './http';
|
||||
import type { Readable } from 'stream';
|
||||
import type { Request } from 'express';
|
||||
export interface STTService {
|
||||
getInstance(): Promise<STTService>;
|
||||
getProviderSchema(req: ServerRequest): Promise<[string, object]>;
|
||||
|
|
@ -131,5 +130,5 @@ export interface ProcessedFile {
|
|||
}
|
||||
|
||||
export interface StrategyFunctions {
|
||||
getDownloadStream: (req: Request, filepath: string) => Promise<Readable>;
|
||||
getDownloadStream: (req: ServerRequest, filepath: string) => Promise<Readable>;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ export interface GoogleConfigOptions {
|
|||
proxy?: string;
|
||||
streaming?: boolean;
|
||||
authHeader?: boolean;
|
||||
/** Default parameters to apply only if fields are undefined */
|
||||
defaultParams?: Record<string, unknown>;
|
||||
addParams?: Record<string, unknown>;
|
||||
dropParams?: string[];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,11 +7,12 @@ import type { IUser, AppConfig } from '@librechat/data-schemas';
|
|||
*/
|
||||
export type RequestBody = {
|
||||
messageId?: string;
|
||||
fileTokenLimit?: number;
|
||||
conversationId?: string;
|
||||
parentMessageId?: string;
|
||||
};
|
||||
|
||||
export type ServerRequest = Request & {
|
||||
export type ServerRequest = Request<unknown, unknown, RequestBody> & {
|
||||
user?: IUser;
|
||||
config?: AppConfig;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ export type OpenAIConfiguration = OpenAIClientOptions['configuration'];
|
|||
|
||||
export type OAIClientOptions = OpenAIClientOptions & {
|
||||
include_reasoning?: boolean;
|
||||
_lc_stream_delay?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -100,10 +101,3 @@ export interface InitializeOpenAIOptionsParams {
|
|||
getUserKeyValues: GetUserKeyValuesFunction;
|
||||
checkUserKeyExpiry: CheckUserKeyExpiryFunction;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extended LLM config result with stream rate handling
|
||||
*/
|
||||
export interface OpenAIOptionsResult extends LLMConfigResult {
|
||||
streamRate?: number;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,6 +27,12 @@ export const genAzureEndpoint = ({
|
|||
azureOpenAIApiInstanceName: string;
|
||||
azureOpenAIApiDeploymentName: string;
|
||||
}): string => {
|
||||
// Support both old (.openai.azure.com) and new (.cognitiveservices.azure.com) endpoint formats
|
||||
// If instanceName already includes a full domain, use it as-is
|
||||
if (azureOpenAIApiInstanceName.includes('.azure.com')) {
|
||||
return `https://${azureOpenAIApiInstanceName}/openai/deployments/${azureOpenAIApiDeploymentName}`;
|
||||
}
|
||||
// Legacy format for backward compatibility
|
||||
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import { Providers } from '@librechat/agents';
|
||||
import { AuthType } from 'librechat-data-provider';
|
||||
|
||||
/**
|
||||
|
|
@ -49,11 +48,3 @@ export function optionalChainWithEmptyCheck(
|
|||
}
|
||||
return values[values.length - 1];
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize the endpoint name to system-expected value.
|
||||
* @param name
|
||||
*/
|
||||
export function normalizeEndpointName(name = ''): string {
|
||||
return name.toLowerCase() === Providers.OLLAMA ? Providers.OLLAMA : name;
|
||||
}
|
||||
|
|
|
|||
200
packages/api/src/utils/content.spec.ts
Normal file
200
packages/api/src/utils/content.spec.ts
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
import { ContentTypes, ToolCallTypes } from 'librechat-data-provider';
|
||||
import type { Agents, PartMetadata, TMessageContentParts } from 'librechat-data-provider';
|
||||
import type { ToolCall } from '@langchain/core/messages/tool';
|
||||
import { filterMalformedContentParts } from './content';
|
||||
|
||||
describe('filterMalformedContentParts', () => {
|
||||
describe('basic filtering', () => {
|
||||
it('should keep valid tool_call content parts', () => {
|
||||
const parts: TMessageContentParts[] = [
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: {
|
||||
id: 'test-id',
|
||||
name: 'test_function',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
args: '{}',
|
||||
progress: 1,
|
||||
output: 'result',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toEqual(parts[0]);
|
||||
});
|
||||
|
||||
it('should filter out malformed tool_call content parts without tool_call property', () => {
|
||||
const parts: TMessageContentParts[] = [
|
||||
{ type: ContentTypes.TOOL_CALL } as TMessageContentParts,
|
||||
];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should keep other content types unchanged', () => {
|
||||
const parts: TMessageContentParts[] = [
|
||||
{ type: ContentTypes.TEXT, text: 'Hello world' },
|
||||
{ type: ContentTypes.THINK, think: 'Thinking...' },
|
||||
];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result).toEqual(parts);
|
||||
});
|
||||
|
||||
it('should filter out null or undefined parts', () => {
|
||||
const parts = [
|
||||
{ type: ContentTypes.TEXT, text: 'Valid' },
|
||||
null,
|
||||
undefined,
|
||||
{ type: ContentTypes.TEXT, text: 'Also valid' },
|
||||
] as TMessageContentParts[];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result[0]).toHaveProperty('text', 'Valid');
|
||||
expect(result[1]).toHaveProperty('text', 'Also valid');
|
||||
});
|
||||
|
||||
it('should return non-array input unchanged', () => {
|
||||
const notAnArray = { some: 'object' };
|
||||
const result = filterMalformedContentParts(notAnArray);
|
||||
expect(result).toBe(notAnArray);
|
||||
});
|
||||
});
|
||||
|
||||
describe('real-life example with multiple tool calls', () => {
|
||||
it('should filter out malformed tool_call entries from actual MCP response', () => {
|
||||
const parts: TMessageContentParts[] = [
|
||||
{
|
||||
type: ContentTypes.THINK,
|
||||
think:
|
||||
'The user is asking for 10 different time zones, similar to what would be displayed in a stock trading room floor.',
|
||||
},
|
||||
{
|
||||
type: ContentTypes.TEXT,
|
||||
text: '# Global Market Times\n\nShowing current time in 10 major financial centers:',
|
||||
tool_call_ids: ['tooluse_Yjfib8PoRXCeCcHRH0JqCw'],
|
||||
},
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: {
|
||||
id: 'tooluse_Yjfib8PoRXCeCcHRH0JqCw',
|
||||
name: 'get_current_time_mcp_time',
|
||||
args: '{"timezone":"America/New_York"}',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
progress: 1,
|
||||
output: '{"timezone":"America/New_York","datetime":"2025-11-13T13:43:17-05:00"}',
|
||||
},
|
||||
},
|
||||
{ type: ContentTypes.TOOL_CALL } as TMessageContentParts,
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: {
|
||||
id: 'tooluse_CPsGv9kXTrewVkcO7BEYIg',
|
||||
name: 'get_current_time_mcp_time',
|
||||
args: '{"timezone":"Europe/London"}',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
progress: 1,
|
||||
output: '{"timezone":"Europe/London","datetime":"2025-11-13T18:43:19+00:00"}',
|
||||
},
|
||||
},
|
||||
{ type: ContentTypes.TOOL_CALL } as TMessageContentParts,
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: {
|
||||
id: 'tooluse_5jihRbd4TDWCGebwmAUlfQ',
|
||||
name: 'get_current_time_mcp_time',
|
||||
args: '{"timezone":"Asia/Tokyo"}',
|
||||
type: ToolCallTypes.TOOL_CALL,
|
||||
progress: 1,
|
||||
output: '{"timezone":"Asia/Tokyo","datetime":"2025-11-14T03:43:21+09:00"}',
|
||||
},
|
||||
},
|
||||
{ type: ContentTypes.TOOL_CALL } as TMessageContentParts,
|
||||
{ type: ContentTypes.TOOL_CALL } as TMessageContentParts,
|
||||
{ type: ContentTypes.TOOL_CALL } as TMessageContentParts,
|
||||
{
|
||||
type: ContentTypes.TEXT,
|
||||
text: '## Major Financial Markets Clock:\n\n| Market | Local Time | Day |',
|
||||
},
|
||||
];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
|
||||
expect(result).toHaveLength(6);
|
||||
|
||||
expect(result[0].type).toBe(ContentTypes.THINK);
|
||||
expect(result[1].type).toBe(ContentTypes.TEXT);
|
||||
expect(result[2].type).toBe(ContentTypes.TOOL_CALL);
|
||||
expect(result[3].type).toBe(ContentTypes.TOOL_CALL);
|
||||
expect(result[4].type).toBe(ContentTypes.TOOL_CALL);
|
||||
expect(result[5].type).toBe(ContentTypes.TEXT);
|
||||
|
||||
const toolCalls = result.filter((part) => part.type === ContentTypes.TOOL_CALL);
|
||||
expect(toolCalls).toHaveLength(3);
|
||||
|
||||
toolCalls.forEach((toolCall) => {
|
||||
if (toolCall.type === ContentTypes.TOOL_CALL) {
|
||||
expect(toolCall.tool_call).toBeDefined();
|
||||
expect(toolCall.tool_call).toHaveProperty('id');
|
||||
expect(toolCall.tool_call).toHaveProperty('name');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty array', () => {
|
||||
const result = filterMalformedContentParts([]);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle array with only malformed tool calls', () => {
|
||||
const parts = [
|
||||
{ type: ContentTypes.TOOL_CALL },
|
||||
{ type: ContentTypes.TOOL_CALL },
|
||||
{ type: ContentTypes.TOOL_CALL },
|
||||
] as TMessageContentParts[];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should filter out tool_call with null tool_call property', () => {
|
||||
const parts = [
|
||||
{ type: ContentTypes.TOOL_CALL, tool_call: null as unknown as ToolCall },
|
||||
] as TMessageContentParts[];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should filter out tool_call with non-object tool_call property', () => {
|
||||
const parts = [
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: 'not an object' as unknown as ToolCall & PartMetadata,
|
||||
},
|
||||
] as TMessageContentParts[];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should keep tool_call with empty object as tool_call', () => {
|
||||
const parts: TMessageContentParts[] = [
|
||||
{
|
||||
type: ContentTypes.TOOL_CALL,
|
||||
tool_call: {} as unknown as Agents.ToolCall & PartMetadata,
|
||||
},
|
||||
];
|
||||
|
||||
const result = filterMalformedContentParts(parts);
|
||||
expect(result).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
46
packages/api/src/utils/content.ts
Normal file
46
packages/api/src/utils/content.ts
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import { ContentTypes } from 'librechat-data-provider';
|
||||
import type { TMessageContentParts } from 'librechat-data-provider';
|
||||
|
||||
/**
|
||||
* Filters out malformed tool call content parts that don't have the required tool_call property.
|
||||
* This handles edge cases where tool_call content parts may be created with only a type property
|
||||
* but missing the actual tool_call data.
|
||||
*
|
||||
* @param contentParts - Array of content parts to filter
|
||||
* @returns Filtered array with malformed tool calls removed
|
||||
*
|
||||
* @example
|
||||
* // Removes malformed tool_call without the tool_call property
|
||||
* const parts = [
|
||||
* { type: 'tool_call', tool_call: { id: '123', name: 'test' } }, // valid - kept
|
||||
* { type: 'tool_call' }, // invalid - filtered out
|
||||
* { type: 'text', text: 'Hello' }, // valid - kept (other types pass through)
|
||||
* ];
|
||||
* const filtered = filterMalformedContentParts(parts);
|
||||
* // Returns all parts except the malformed tool_call
|
||||
*/
|
||||
export function filterMalformedContentParts(
|
||||
contentParts: TMessageContentParts[],
|
||||
): TMessageContentParts[];
|
||||
export function filterMalformedContentParts<T>(contentParts: T): T;
|
||||
export function filterMalformedContentParts<T>(
|
||||
contentParts: T | TMessageContentParts[],
|
||||
): T | TMessageContentParts[] {
|
||||
if (!Array.isArray(contentParts)) {
|
||||
return contentParts;
|
||||
}
|
||||
|
||||
return contentParts.filter((part) => {
|
||||
if (!part || typeof part !== 'object') {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { type } = part;
|
||||
|
||||
if (type === ContentTypes.TOOL_CALL) {
|
||||
return 'tool_call' in part && part.tool_call != null && typeof part.tool_call === 'object';
|
||||
}
|
||||
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
|
@ -1,15 +1,18 @@
|
|||
export * from './axios';
|
||||
export * from './azure';
|
||||
export * from './common';
|
||||
export * from './content';
|
||||
export * from './email';
|
||||
export * from './env';
|
||||
export * from './events';
|
||||
export * from './files';
|
||||
export * from './generators';
|
||||
export * from './key';
|
||||
export * from './latex';
|
||||
export * from './llm';
|
||||
export * from './math';
|
||||
export * from './openid';
|
||||
export * from './promise';
|
||||
export * from './sanitizeTitle';
|
||||
export * from './tempChatRetention';
|
||||
export * from './text';
|
||||
|
|
|
|||
122
packages/api/src/utils/latex.spec.ts
Normal file
122
packages/api/src/utils/latex.spec.ts
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import { unescapeLaTeX } from './latex';
|
||||
|
||||
describe('unescapeLaTeX', () => {
|
||||
describe('currency dollar signs', () => {
|
||||
it('should unescape single backslash dollar signs', () => {
|
||||
const input = 'Price: \\$14';
|
||||
const expected = 'Price: $14';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
|
||||
it('should unescape double backslash dollar signs', () => {
|
||||
const input = 'Price: \\\\$14';
|
||||
const expected = 'Price: $14';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
|
||||
it('should unescape multiple currency values', () => {
|
||||
const input = '**Crispy Calamari** - *\\\\$14*\n**Truffle Fries** - *\\\\$12*';
|
||||
const expected = '**Crispy Calamari** - *$14*\n**Truffle Fries** - *$12*';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
|
||||
it('should handle currency with commas and decimals', () => {
|
||||
const input = 'Total: \\\\$1,234.56';
|
||||
const expected = 'Total: $1,234.56';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('mhchem notation', () => {
|
||||
it('should unescape mhchem ce notation', () => {
|
||||
const input = '$$\\\\ce{H2O}$$';
|
||||
const expected = '$\\ce{H2O}$';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
|
||||
it('should unescape mhchem pu notation', () => {
|
||||
const input = '$$\\\\pu{123 kJ/mol}$$';
|
||||
const expected = '$\\pu{123 kJ/mol}$';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
|
||||
it('should handle multiple mhchem expressions', () => {
|
||||
const input = '$$\\\\ce{H2O}$$ and $$\\\\ce{CO2}$$';
|
||||
const expected = '$\\ce{H2O}$ and $\\ce{CO2}$';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle empty string', () => {
|
||||
expect(unescapeLaTeX('')).toBe('');
|
||||
});
|
||||
|
||||
it('should handle null', () => {
|
||||
expect(unescapeLaTeX(null)).toBe(null);
|
||||
});
|
||||
|
||||
it('should handle undefined', () => {
|
||||
expect(unescapeLaTeX(undefined)).toBe(undefined);
|
||||
});
|
||||
|
||||
it('should handle string with no dollar signs', () => {
|
||||
const input = 'Hello world';
|
||||
expect(unescapeLaTeX(input)).toBe(input);
|
||||
});
|
||||
|
||||
it('should handle mixed escaped and unescaped content', () => {
|
||||
const input = 'Price \\\\$14 and some text';
|
||||
const expected = 'Price $14 and some text';
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('real-world example from bug report', () => {
|
||||
it('should correctly unescape restaurant menu content', () => {
|
||||
const input = `# The Golden Spoon
|
||||
## *Contemporary American Cuisine*
|
||||
|
||||
---
|
||||
|
||||
### STARTERS
|
||||
|
||||
**Crispy Calamari** - *\\\\$14*
|
||||
Lightly fried, served with marinara & lemon aioli
|
||||
|
||||
**Truffle Fries** - *\\\\$12*
|
||||
Hand-cut fries, parmesan, truffle oil, fresh herbs
|
||||
|
||||
**Burrata & Heirloom Tomatoes** - *\\\\$16*
|
||||
Fresh burrata, basil pesto, balsamic reduction, grilled sourdough
|
||||
|
||||
**Thai Chicken Lettuce Wraps** - *\\\\$13*
|
||||
Spicy ground chicken, water chestnuts, ginger-soy glaze
|
||||
|
||||
**Soup of the Day** - *\\\\$9`;
|
||||
|
||||
const expected = `# The Golden Spoon
|
||||
## *Contemporary American Cuisine*
|
||||
|
||||
---
|
||||
|
||||
### STARTERS
|
||||
|
||||
**Crispy Calamari** - *$14*
|
||||
Lightly fried, served with marinara & lemon aioli
|
||||
|
||||
**Truffle Fries** - *$12*
|
||||
Hand-cut fries, parmesan, truffle oil, fresh herbs
|
||||
|
||||
**Burrata & Heirloom Tomatoes** - *$16*
|
||||
Fresh burrata, basil pesto, balsamic reduction, grilled sourdough
|
||||
|
||||
**Thai Chicken Lettuce Wraps** - *$13*
|
||||
Spicy ground chicken, water chestnuts, ginger-soy glaze
|
||||
|
||||
**Soup of the Day** - *$9`;
|
||||
|
||||
expect(unescapeLaTeX(input)).toBe(expected);
|
||||
});
|
||||
});
|
||||
});
|
||||
27
packages/api/src/utils/latex.ts
Normal file
27
packages/api/src/utils/latex.ts
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Unescapes LaTeX preprocessing done by the frontend preprocessLaTeX function.
|
||||
* This reverses the escaping of currency dollar signs and other LaTeX transformations.
|
||||
*
|
||||
* The frontend escapes dollar signs for proper LaTeX rendering (e.g., $14 → \\$14),
|
||||
* but the database stores the original unescaped versions. This function reverses
|
||||
* that transformation to match database content.
|
||||
*
|
||||
* @param text - The escaped text from the frontend
|
||||
* @returns The unescaped text matching the database format
|
||||
*/
|
||||
export function unescapeLaTeX(text: string | null | undefined): string | null | undefined {
|
||||
if (!text || typeof text !== 'string') {
|
||||
return text;
|
||||
}
|
||||
|
||||
// Unescape currency dollar signs (\\$ or \$ → $)
|
||||
// This is the main transformation done by preprocessLaTeX for currency
|
||||
let result = text.replace(/\\\\?\$/g, '$');
|
||||
|
||||
// Unescape mhchem notation if present
|
||||
// Convert $$\\ce{...}$$ back to $\ce{...}$
|
||||
result = result.replace(/\$\$\\\\ce\{([^}]*)\}\$\$/g, '$\\ce{$1}$');
|
||||
result = result.replace(/\$\$\\\\pu\{([^}]*)\}\$\$/g, '$\\pu{$1}$');
|
||||
|
||||
return result;
|
||||
}
|
||||
115
packages/api/src/utils/promise.spec.ts
Normal file
115
packages/api/src/utils/promise.spec.ts
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
import { withTimeout } from './promise';
|
||||
|
||||
describe('withTimeout', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllTimers();
|
||||
});
|
||||
|
||||
it('should resolve when promise completes before timeout', async () => {
|
||||
const promise = Promise.resolve('success');
|
||||
const result = await withTimeout(promise, 1000);
|
||||
expect(result).toBe('success');
|
||||
});
|
||||
|
||||
it('should reject when promise rejects before timeout', async () => {
|
||||
const promise = Promise.reject(new Error('test error'));
|
||||
await expect(withTimeout(promise, 1000)).rejects.toThrow('test error');
|
||||
});
|
||||
|
||||
it('should timeout when promise takes too long', async () => {
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
await expect(withTimeout(promise, 100, 'Custom timeout message')).rejects.toThrow(
|
||||
'Custom timeout message',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default error message when none provided', async () => {
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
await expect(withTimeout(promise, 100)).rejects.toThrow('Operation timed out after 100ms');
|
||||
});
|
||||
|
||||
it('should clear timeout when promise resolves', async () => {
|
||||
const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout');
|
||||
const promise = Promise.resolve('fast');
|
||||
|
||||
await withTimeout(promise, 1000);
|
||||
|
||||
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||
clearTimeoutSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should clear timeout when promise rejects', async () => {
|
||||
const clearTimeoutSpy = jest.spyOn(global, 'clearTimeout');
|
||||
const promise = Promise.reject(new Error('fail'));
|
||||
|
||||
await expect(withTimeout(promise, 1000)).rejects.toThrow('fail');
|
||||
|
||||
expect(clearTimeoutSpy).toHaveBeenCalled();
|
||||
clearTimeoutSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should handle multiple concurrent timeouts', async () => {
|
||||
const promise1 = Promise.resolve('first');
|
||||
const promise2 = new Promise((resolve) => setTimeout(() => resolve('second'), 50));
|
||||
const promise3 = new Promise((resolve) => setTimeout(() => resolve('third'), 2000));
|
||||
|
||||
const [result1, result2] = await Promise.all([
|
||||
withTimeout(promise1, 1000),
|
||||
withTimeout(promise2, 1000),
|
||||
]);
|
||||
|
||||
expect(result1).toBe('first');
|
||||
expect(result2).toBe('second');
|
||||
|
||||
await expect(withTimeout(promise3, 100)).rejects.toThrow('Operation timed out after 100ms');
|
||||
});
|
||||
|
||||
it('should work with async functions', async () => {
|
||||
const asyncFunction = async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return 'async result';
|
||||
};
|
||||
|
||||
const result = await withTimeout(asyncFunction(), 1000);
|
||||
expect(result).toBe('async result');
|
||||
});
|
||||
|
||||
it('should work with any return type', async () => {
|
||||
const numberPromise = Promise.resolve(42);
|
||||
const objectPromise = Promise.resolve({ key: 'value' });
|
||||
const arrayPromise = Promise.resolve([1, 2, 3]);
|
||||
|
||||
expect(await withTimeout(numberPromise, 1000)).toBe(42);
|
||||
expect(await withTimeout(objectPromise, 1000)).toEqual({ key: 'value' });
|
||||
expect(await withTimeout(arrayPromise, 1000)).toEqual([1, 2, 3]);
|
||||
});
|
||||
|
||||
it('should call logger when timeout occurs', async () => {
|
||||
const loggerMock = jest.fn();
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
const errorMessage = 'Custom timeout with logger';
|
||||
|
||||
await expect(withTimeout(promise, 100, errorMessage, loggerMock)).rejects.toThrow(errorMessage);
|
||||
|
||||
expect(loggerMock).toHaveBeenCalledTimes(1);
|
||||
expect(loggerMock).toHaveBeenCalledWith(errorMessage, expect.any(Error));
|
||||
});
|
||||
|
||||
it('should not call logger when promise resolves', async () => {
|
||||
const loggerMock = jest.fn();
|
||||
const promise = Promise.resolve('success');
|
||||
|
||||
const result = await withTimeout(promise, 1000, 'Should not timeout', loggerMock);
|
||||
|
||||
expect(result).toBe('success');
|
||||
expect(loggerMock).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should work without logger parameter', async () => {
|
||||
const promise = new Promise((resolve) => setTimeout(() => resolve('late'), 2000));
|
||||
|
||||
await expect(withTimeout(promise, 100, 'No logger provided')).rejects.toThrow(
|
||||
'No logger provided',
|
||||
);
|
||||
});
|
||||
});
|
||||
42
packages/api/src/utils/promise.ts
Normal file
42
packages/api/src/utils/promise.ts
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Wraps a promise with a timeout. If the promise doesn't resolve/reject within
|
||||
* the specified time, it will be rejected with a timeout error.
|
||||
*
|
||||
* @param promise - The promise to wrap with a timeout
|
||||
* @param timeoutMs - Timeout duration in milliseconds
|
||||
* @param errorMessage - Custom error message for timeout (optional)
|
||||
* @param logger - Optional logger function to log timeout errors (e.g., console.warn, logger.warn)
|
||||
* @returns Promise that resolves/rejects with the original promise or times out
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const result = await withTimeout(
|
||||
* fetchData(),
|
||||
* 5000,
|
||||
* 'Failed to fetch data within 5 seconds',
|
||||
* console.warn
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
export async function withTimeout<T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
errorMessage?: string,
|
||||
logger?: (message: string, error: Error) => void,
|
||||
): Promise<T> {
|
||||
let timeoutId: NodeJS.Timeout;
|
||||
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
timeoutId = setTimeout(() => {
|
||||
const error = new Error(errorMessage ?? `Operation timed out after ${timeoutMs}ms`);
|
||||
if (logger) logger(error.message, error);
|
||||
reject(error);
|
||||
}, timeoutMs);
|
||||
});
|
||||
|
||||
try {
|
||||
return await Promise.race([promise, timeoutPromise]);
|
||||
} finally {
|
||||
clearTimeout(timeoutId!);
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue