🧠 feat: User Memories for Conversational Context (#7760)

* 🧠 feat: User Memories for Conversational Context

chore: mcp typing, use `t`

WIP: first pass, Memories UI

- Added MemoryViewer component for displaying, editing, and deleting user memories.
- Integrated data provider hooks for fetching, updating, and deleting memories.
- Implemented pagination and loading states for better user experience.
- Created unit tests for MemoryViewer to ensure functionality and interaction with data provider.
- Updated translation files to include new UI strings related to memories.

chore: move mcp-related files to own directory

chore: rename librechat-mcp to librechat-api

WIP: first pass, memory processing and data schemas

chore: linting in fileSearch.js query description

chore: rename librechat-api to @librechat/api across the project

WIP: first pass, functional memory agent

feat: add MemoryEditDialog and MemoryViewer components for managing user memories

- Introduced MemoryEditDialog for editing memory entries with validation and toast notifications.
- Updated MemoryViewer to support editing and deleting memories, including pagination and loading states.
- Enhanced data provider to handle memory updates with optional original key for better management.
- Added new localization strings for memory-related UI elements.

feat: add memory permissions management

- Implemented memory permissions in the backend, allowing roles to have specific permissions for using, creating, updating, and reading memories.
- Added new API endpoints for updating memory permissions associated with roles.
- Created a new AdminSettings component for managing memory permissions in the frontend.
- Integrated memory permissions into the existing roles and permissions schemas.
- Updated the interface to include memory settings and permissions.
- Enhanced the MemoryViewer component to conditionally render admin settings based on user roles.
- Added localization support for memory permissions in the translation files.

feat: move AdminSettings component to a new position in MemoryViewer for better visibility

refactor: clean up commented code in MemoryViewer component

feat: enhance MemoryViewer with search functionality and improve MemoryEditDialog integration

- Added a search input to filter memories in the MemoryViewer component.
- Refactored MemoryEditDialog to accept children for better customization.
- Updated MemoryViewer to utilize the new EditMemoryButton and DeleteMemoryButton components for editing and deleting memories.
- Improved localization support by adding new strings for memory filtering and deletion confirmation.

refactor: optimize memory filtering in MemoryViewer using match-sorter

- Replaced manual filtering logic with match-sorter for improved search functionality.
- Enhanced performance and readability of the filteredMemories computation.
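
The ranking behavior this refactor relies on can be illustrated with a self-contained sketch. This is not the match-sorter library itself (which has more ranking tiers); it only approximates prefix-then-substring ranking, and the MemoryEntry shape (`key`/`value`) is assumed from the UI described above:

```typescript
interface MemoryEntry {
  key: string;
  value: string;
}

// Approximation of match-sorter for this use case: case-insensitive
// matching across the fields we care about, with prefix matches ranked
// ahead of plain substring matches. An empty query returns everything.
function filterMemories(memories: MemoryEntry[], query: string): MemoryEntry[] {
  if (!query) {
    return memories;
  }
  const q = query.toLowerCase();
  const rank = (m: MemoryEntry): number => {
    const text = `${m.key} ${m.value}`.toLowerCase();
    if (text.startsWith(q)) return 0; // best: prefix match
    if (text.includes(q)) return 1; // next: substring match
    return 2; // no match at all
  };
  return memories
    .map((m) => ({ m, r: rank(m) }))
    .filter(({ r }) => r < 2)
    .sort((a, b) => a.r - b.r)
    .map(({ m }) => m);
}
```

In the actual component this logic collapses to a single `matchSorter` call, which is why the refactor improves readability.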

feat: enhance MemoryEditDialog with triggerRef and improve updateMemory mutation handling

feat: implement access control for MemoryEditDialog and MemoryViewer components

refactor: remove commented out code and create runMemory method

refactor: rename role-based files

feat: implement access control for memory usage in AgentClient

refactor: simplify checkVisionRequest method in AgentClient by removing commented-out code

refactor: make `agents` dir in api package

refactor: migrate Azure utilities to TypeScript and consolidate imports

refactor: move sanitizeFilename function to a new file and update imports, add related tests

refactor: update LLM configuration types and consolidate Azure options in the API package

chore: linting

chore: import order

refactor: replace getLLMConfig with getOpenAIConfig and remove unused LLM configuration file

chore: update winston-daily-rotate-file to version 5.0.0 and add object-hash dependency in package-lock.json

refactor: move primeResources and optionalChainWithEmptyCheck functions to resources.ts and update imports

refactor: move createRun function to a new run.ts file and update related imports

fix: ensure safeAttachments is correctly typed as an array of TFile

chore: add node-fetch dependency and refactor fetch-related functions into packages/api/utils, removing the old generators file

refactor: enhance TEndpointOption type by using Pick to streamline endpoint fields and add new properties for model parameters and client options

feat: implement initializeOpenAIOptions function and update OpenAI types for enhanced configuration handling

fix: update types due to new TEndpointOption typing

fix: ensure safe access to group parameters in initializeOpenAIOptions function

fix: remove redundant API key validation comment in initializeOpenAIOptions function

refactor: rename initializeOpenAIOptions to initializeOpenAI for consistency and update related documentation

refactor: decouple req.body fields and tool loading from initializeAgentOptions

chore: linting

refactor: adjust column widths in MemoryViewer for improved layout

refactor: simplify agent initialization by creating loadAgent function and removing unused code

feat: add memory configuration loading and validation functions

WIP: first pass, memory processing with config

feat: implement memory callback and artifact handling

feat: implement memory artifacts display and processing updates

feat: add memory configuration options and schema validation for validKeys

fix: update MemoryEditDialog and MemoryViewer to handle memory state and display improvements

refactor: remove padding from BookmarkTable and MemoryViewer headers for consistent styling

WIP: initial tokenLimit config and move Tokenizer to @librechat/api

refactor: update mongoMeili plugin methods to use callback for better error handling

feat: enhance memory management with token tracking and usage metrics

- Added token counting for memory entries to enforce limits and provide usage statistics.
- Updated memory retrieval and update routes to include total token usage and limit.
- Enhanced MemoryEditDialog and MemoryViewer components to display memory usage and token information.
- Refactored memory processing functions to handle token limits and provide feedback on memory capacity.
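
The usage metrics above amount to summing per-entry token counts against the configured limit. A minimal sketch, assuming each stored memory carries the `tokenCount` recorded when it was saved (names are illustrative, not the actual route code):

```typescript
interface StoredMemory {
  tokenCount: number;
}

// Totals current memory usage and how much capacity remains under the limit.
function memoryUsage(entries: StoredMemory[], tokenLimit: number) {
  const totalTokens = entries.reduce((sum, e) => sum + e.tokenCount, 0);
  return {
    totalTokens,
    tokenLimit,
    remaining: Math.max(0, tokenLimit - totalTokens),
  };
}
```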

feat: implement memory artifact handling in attachment handler

- Enhanced useAttachmentHandler to process memory artifacts when receiving updates.
- Introduced handleMemoryArtifact utility to manage memory updates and deletions.
- Updated query client to reflect changes in memory state based on incoming data.

refactor: restructure web search key extraction logic

- Moved the logic for extracting API keys from the webSearchAuth configuration into a dedicated function, getWebSearchKeys.
- Updated webSearchKeys to utilize the new function for improved clarity and maintainability.
- Prevents build-time errors.
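
A minimal sketch of the dedicated extraction function. The real `webSearchAuth` shape lives in librechat-data-provider and may differ; a nested service-to-key-name record is assumed here for illustration. Deriving the keys inside a function call, rather than at module load, is the design choice that avoids the build-time errors noted:

```typescript
type WebSearchAuth = Record<string, Record<string, unknown>>;

// Collects every API key name declared under each web search service.
function getWebSearchKeys(auth: WebSearchAuth): string[] {
  const keys: string[] = [];
  for (const service of Object.values(auth)) {
    keys.push(...Object.keys(service));
  }
  return keys;
}
```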

feat: add personalization settings and memory preferences management

- Introduced a new Personalization tab in settings to manage user memory preferences.
- Implemented API endpoints and client-side logic for updating memory preferences.
- Enhanced user interface components to reflect personalization options and memory usage.
- Updated permissions to allow users to opt out of memory features.
- Added localization support for new settings and messages related to personalization.

style: personalization switch class

feat: add PersonalizationIcon and align Side Panel UI

feat: implement memory creation functionality

- Added a new API endpoint for creating memory entries, including validation for key and value.
- Introduced MemoryCreateDialog component for user interface to facilitate memory creation.
- Integrated token limit checks to prevent exceeding user memory capacity.
- Updated MemoryViewer to include a button for opening the memory creation dialog.
- Enhanced localization support for new messages related to memory creation.
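
The validation described above can be sketched as a pure function: require a non-empty key and value, then apply the same capacity guard the memory tools use. Names and the return shape are illustrative, not the actual route handler:

```typescript
// Checks a prospective memory entry before it is persisted.
function validateNewMemory(
  key: string,
  value: string,
  newTokens: number,
  usedTokens: number,
  tokenLimit?: number,
): { ok: boolean; reason?: string } {
  if (!key.trim() || !value.trim()) {
    return { ok: false, reason: 'key and value are required' };
  }
  if (tokenLimit != null && usedTokens + newTokens > tokenLimit) {
    return { ok: false, reason: 'token limit exceeded' };
  }
  return { ok: true };
}
```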

feat: enhance message processing with configurable window size

- Updated AgentClient to use a configurable message window size for processing messages.
- Introduced messageWindowSize option in memory configuration schema with a default value of 5.
- Improved logic for selecting messages to process based on the configured window size.
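
The window selection reduces to taking the last N messages. A minimal sketch using the default of 5 from the schema above (the actual AgentClient logic may differ):

```typescript
// Returns the most recent `windowSize` messages; shorter histories are
// returned whole, and a non-positive window yields nothing.
function selectMessageWindow<T>(messages: T[], windowSize = 5): T[] {
  if (windowSize <= 0) {
    return [];
  }
  return messages.slice(-windowSize);
}
```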

chore: update librechat-data-provider version to 0.7.87 in package.json and package-lock.json

chore: remove OpenAPIPlugin and its associated tests

chore: remove MIGRATION_README.md as migration tasks are completed

ci: fix backend tests

chore: remove unused translation keys from localization file

chore: remove problematic test file and unused var in AgentClient

chore: remove unused import and import directly for JSDoc

* feat: add api package build stage in Dockerfile for improved modularity

* docs: reorder build steps in contributing guide for clarity

Commit 29ef91b4dd (parent cd7dd576c1) by Danny Avila, 2025-06-07 18:52:22 -04:00, committed by GitHub
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
170 changed files with 5700 additions and 3632 deletions

packages/api/.gitignore (new file, vendored)

@@ -0,0 +1,2 @@
node_modules/
test_bundle/

(new file: Babel configuration for the api package)

@@ -0,0 +1,4 @@
module.exports = {
presets: [['@babel/preset-env', { targets: { node: 'current' } }], '@babel/preset-typescript'],
plugins: ['babel-plugin-replace-ts-export-assignment'],
};

(new file: Jest configuration for the api package)

@@ -0,0 +1,19 @@
export default {
collectCoverageFrom: ['src/**/*.{js,jsx,ts,tsx}', '!<rootDir>/node_modules/'],
coveragePathIgnorePatterns: ['/node_modules/', '/dist/'],
coverageReporters: ['text', 'cobertura'],
testResultsProcessor: 'jest-junit',
moduleNameMapper: {
'^@src/(.*)$': '<rootDir>/src/$1',
},
// coverageThreshold: {
// global: {
// statements: 58,
// branches: 49,
// functions: 50,
// lines: 57,
// },
// },
restoreMocks: true,
testTimeout: 15000,
};

packages/api/package.json (new file)

@@ -0,0 +1,83 @@
{
"name": "@librechat/api",
"version": "1.2.2",
"type": "commonjs",
"description": "MCP services for LibreChat",
"main": "dist/index.js",
"module": "dist/index.es.js",
"types": "./dist/types/index.d.ts",
"exports": {
".": {
"require": "./dist/index.js",
"types": "./dist/types/index.d.ts"
}
},
"scripts": {
"clean": "rimraf dist",
"build": "npm run clean && rollup -c --bundleConfigAsCjs",
"build:watch": "rollup -c -w --bundleConfigAsCjs",
"test": "jest --coverage --watch",
"test:ci": "jest --coverage --ci",
"verify": "npm run test:ci",
"b:clean": "bun run rimraf dist",
"b:build": "bun run b:clean && bun run rollup -c --silent --bundleConfigAsCjs",
"start:everything-sse": "node -r dotenv/config --loader ./tsconfig-paths-bootstrap.mjs --experimental-specifier-resolution=node ./src/examples/everything/sse.ts",
"start:everything": "node -r dotenv/config --loader ./tsconfig-paths-bootstrap.mjs --experimental-specifier-resolution=node ./src/demo/everything.ts",
"start:filesystem": "node -r dotenv/config --loader ./tsconfig-paths-bootstrap.mjs --experimental-specifier-resolution=node ./src/demo/filesystem.ts",
"start:servers": "node -r dotenv/config --loader ./tsconfig-paths-bootstrap.mjs --experimental-specifier-resolution=node ./src/demo/servers.ts"
},
"repository": {
"type": "git",
"url": "git+https://github.com/danny-avila/LibreChat.git"
},
"author": "",
"license": "ISC",
"bugs": {
"url": "https://github.com/danny-avila/LibreChat/issues"
},
"homepage": "https://librechat.ai",
"devDependencies": {
"@babel/preset-env": "^7.21.5",
"@babel/preset-react": "^7.18.6",
"@babel/preset-typescript": "^7.21.0",
"@rollup/plugin-alias": "^5.1.0",
"@rollup/plugin-commonjs": "^25.0.2",
"@rollup/plugin-json": "^6.1.0",
"@rollup/plugin-node-resolve": "^15.1.0",
"@rollup/plugin-replace": "^5.0.5",
"@rollup/plugin-terser": "^0.4.4",
"@rollup/plugin-typescript": "^12.1.2",
"@types/bun": "^1.2.15",
"@types/diff": "^6.0.0",
"@types/express": "^5.0.0",
"@types/jest": "^29.5.2",
"@types/node": "^20.3.0",
"@types/react": "^18.2.18",
"@types/winston": "^2.4.4",
"jest": "^29.5.0",
"jest-junit": "^16.0.0",
"librechat-data-provider": "*",
"rimraf": "^5.0.1",
"rollup": "^4.22.4",
"rollup-plugin-generate-package-json": "^3.2.0",
"rollup-plugin-peer-deps-external": "^2.2.4",
"ts-node": "^10.9.2",
"typescript": "^5.0.4"
},
"publishConfig": {
"registry": "https://registry.npmjs.org/"
},
"peerDependencies": {
"@librechat/agents": "^2.4.37",
"@librechat/data-schemas": "*",
"librechat-data-provider": "*",
"@modelcontextprotocol/sdk": "^1.11.2",
"diff": "^7.0.0",
"eventsource": "^3.0.2",
"express": "^4.21.2",
"node-fetch": "2.7.0",
"keyv": "^5.3.2",
"zod": "^3.22.4",
"tiktoken": "^1.0.15"
}
}

(new file: rollup.config.js)

@@ -0,0 +1,47 @@
// rollup.config.js
import { readFileSync } from 'fs';
import terser from '@rollup/plugin-terser';
import replace from '@rollup/plugin-replace';
import commonjs from '@rollup/plugin-commonjs';
import resolve from '@rollup/plugin-node-resolve';
import typescript from '@rollup/plugin-typescript';
import peerDepsExternal from 'rollup-plugin-peer-deps-external';
const pkg = JSON.parse(readFileSync(new URL('./package.json', import.meta.url), 'utf8'));
const plugins = [
peerDepsExternal(),
resolve({
preferBuiltins: true,
}),
replace({
__IS_DEV__: process.env.NODE_ENV === 'development',
preventAssignment: true,
}),
commonjs({
transformMixedEsModules: true,
requireReturnsDefault: 'auto',
}),
typescript({
tsconfig: './tsconfig.json',
outDir: './dist',
sourceMap: true,
inlineSourceMap: true,
}),
terser(),
];
const cjsBuild = {
input: 'src/index.ts',
output: {
file: pkg.main,
format: 'cjs',
sourcemap: true,
exports: 'named',
},
external: [...Object.keys(pkg.dependencies || {}), ...Object.keys(pkg.devDependencies || {})],
preserveSymlinks: true,
plugins,
};
export default cjsBuild;

(new file: package entry re-exports)

@@ -0,0 +1,3 @@
export * from './memory';
export * from './resources';
export * from './run';

(new file: memory processing module)

@@ -0,0 +1,468 @@
/** Memories */
import { z } from 'zod';
import { tool } from '@langchain/core/tools';
import { Tools } from 'librechat-data-provider';
import { logger } from '@librechat/data-schemas';
import { Run, Providers, GraphEvents } from '@librechat/agents';
import type {
StreamEventData,
ToolEndCallback,
EventHandler,
ToolEndData,
LLMConfig,
} from '@librechat/agents';
import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
import type { ObjectId, MemoryMethods } from '@librechat/data-schemas';
import type { BaseMessage } from '@langchain/core/messages';
import type { Response as ServerResponse } from 'express';
import { Tokenizer } from '~/utils';
type RequiredMemoryMethods = Pick<
MemoryMethods,
'setMemory' | 'deleteMemory' | 'getFormattedMemories'
>;
type ToolEndMetadata = Record<string, unknown> & {
run_id?: string;
thread_id?: string;
};
export interface MemoryConfig {
validKeys?: string[];
instructions?: string;
llmConfig?: Partial<LLMConfig>;
tokenLimit?: number;
}
export const memoryInstructions =
'The system automatically stores important user information and can update or delete memories based on user requests, enabling dynamic memory management.';
const getDefaultInstructions = (
validKeys?: string[],
tokenLimit?: number,
) => `Use the \`set_memory\` tool to save important information about the user, but ONLY when the user has explicitly provided this information. If there is nothing to note about the user specifically, END THE TURN IMMEDIATELY.
The \`delete_memory\` tool should only be used in two scenarios:
1. When the user explicitly asks to forget or remove specific information
2. When updating existing memories, use the \`set_memory\` tool instead of deleting and re-adding the memory.
${
validKeys && validKeys.length > 0
? `CRITICAL INSTRUCTION: Only the following keys are valid for storing memories:
${validKeys.map((key) => `- ${key}`).join('\n ')}`
: 'You can use any appropriate key to store memories about the user.'
}
${
tokenLimit
? `⚠️ TOKEN LIMIT: Each memory value must not exceed ${tokenLimit} tokens. Be concise and store only essential information.`
: ''
}
WARNING
DO NOT STORE ANY INFORMATION UNLESS THE USER HAS EXPLICITLY PROVIDED IT.
ONLY store information the user has EXPLICITLY shared.
NEVER guess or assume user information.
ALL memory values must be factual statements about THIS specific user.
If nothing needs to be stored, DO NOT CALL any memory tools.
If you're unsure whether to store something, DO NOT store it.
If nothing needs to be stored, END THE TURN IMMEDIATELY.`;
/**
* Creates a memory tool instance with user context
*/
const createMemoryTool = ({
userId,
setMemory,
validKeys,
tokenLimit,
totalTokens = 0,
}: {
userId: string | ObjectId;
setMemory: MemoryMethods['setMemory'];
validKeys?: string[];
tokenLimit?: number;
totalTokens?: number;
}) => {
return tool(
async ({ key, value }) => {
try {
if (validKeys && validKeys.length > 0 && !validKeys.includes(key)) {
logger.warn(
`Memory Agent failed to set memory: Invalid key "${key}". Must be one of: ${validKeys.join(
', ',
)}`,
);
return `Invalid key "${key}". Must be one of: ${validKeys.join(', ')}`;
}
const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');
if (tokenLimit && tokenCount > tokenLimit) {
logger.warn(
`Memory Agent failed to set memory: Value exceeds token limit. Value has ${tokenCount} tokens, but limit is ${tokenLimit}`,
);
return `Memory value too large: ${tokenCount} tokens exceeds limit of ${tokenLimit}`;
}
if (tokenLimit && totalTokens + tokenCount > tokenLimit) {
const remainingCapacity = tokenLimit - totalTokens;
logger.warn(
`Memory Agent failed to set memory: Would exceed total token limit. Current usage: ${totalTokens}, new memory: ${tokenCount} tokens, limit: ${tokenLimit}`,
);
return `Cannot add memory: would exceed token limit. Current usage: ${totalTokens}/${tokenLimit} tokens. This memory requires ${tokenCount} tokens, but only ${remainingCapacity} tokens available.`;
}
const artifact: Record<Tools.memory, MemoryArtifact> = {
[Tools.memory]: {
key,
value,
tokenCount,
type: 'update',
},
};
const result = await setMemory({ userId, key, value, tokenCount });
if (result.ok) {
logger.debug(`Memory set for key "${key}" (${tokenCount} tokens) for user "${userId}"`);
return [`Memory set for key "${key}" (${tokenCount} tokens)`, artifact];
}
logger.warn(`Failed to set memory for key "${key}" for user "${userId}"`);
return [`Failed to set memory for key "${key}"`, undefined];
} catch (error) {
logger.error('Memory Agent failed to set memory', error);
return [`Error setting memory for key "${key}"`, undefined];
}
},
{
name: 'set_memory',
description: 'Saves important information about the user into memory.',
responseFormat: 'content_and_artifact',
schema: z.object({
key: z
.string()
.describe(
validKeys && validKeys.length > 0
? `The key of the memory value. Must be one of: ${validKeys.join(', ')}`
: 'The key identifier for this memory',
),
value: z
.string()
.describe(
'Value MUST be a complete sentence that fully describes relevant user information.',
),
}),
},
);
};
/**
* Creates a delete memory tool instance with user context
*/
const createDeleteMemoryTool = ({
userId,
deleteMemory,
validKeys,
}: {
userId: string | ObjectId;
deleteMemory: MemoryMethods['deleteMemory'];
validKeys?: string[];
}) => {
return tool(
async ({ key }) => {
try {
if (validKeys && validKeys.length > 0 && !validKeys.includes(key)) {
logger.warn(
`Memory Agent failed to delete memory: Invalid key "${key}". Must be one of: ${validKeys.join(
', ',
)}`,
);
return `Invalid key "${key}". Must be one of: ${validKeys.join(', ')}`;
}
const artifact: Record<Tools.memory, MemoryArtifact> = {
[Tools.memory]: {
key,
type: 'delete',
},
};
const result = await deleteMemory({ userId, key });
if (result.ok) {
logger.debug(`Memory deleted for key "${key}" for user "${userId}"`);
return [`Memory deleted for key "${key}"`, artifact];
}
logger.warn(`Failed to delete memory for key "${key}" for user "${userId}"`);
return [`Failed to delete memory for key "${key}"`, undefined];
} catch (error) {
logger.error('Memory Agent failed to delete memory', error);
return [`Error deleting memory for key "${key}"`, undefined];
}
},
{
name: 'delete_memory',
description:
'Deletes specific memory data about the user using the provided key. For updating existing memories, use the `set_memory` tool instead',
responseFormat: 'content_and_artifact',
schema: z.object({
key: z
.string()
.describe(
validKeys && validKeys.length > 0
? `The key of the memory to delete. Must be one of: ${validKeys.join(', ')}`
: 'The key identifier of the memory to delete',
),
}),
},
);
};
export class BasicToolEndHandler implements EventHandler {
private callback?: ToolEndCallback;
constructor(callback?: ToolEndCallback) {
this.callback = callback;
}
handle(
event: string,
data: StreamEventData | undefined,
metadata?: Record<string, unknown>,
): void {
if (!metadata) {
console.warn(`Graph or metadata not found in ${event} event`);
return;
}
const toolEndData = data as ToolEndData | undefined;
if (!toolEndData?.output) {
console.warn('No output found in tool_end event');
return;
}
this.callback?.(toolEndData, metadata);
}
}
export async function processMemory({
res,
userId,
setMemory,
deleteMemory,
messages,
memory,
messageId,
conversationId,
validKeys,
instructions,
llmConfig,
tokenLimit,
totalTokens = 0,
}: {
res: ServerResponse;
setMemory: MemoryMethods['setMemory'];
deleteMemory: MemoryMethods['deleteMemory'];
userId: string | ObjectId;
memory: string;
messageId: string;
conversationId: string;
messages: BaseMessage[];
validKeys?: string[];
instructions: string;
tokenLimit?: number;
totalTokens?: number;
llmConfig?: Partial<LLMConfig>;
}): Promise<(TAttachment | null)[] | undefined> {
try {
const memoryTool = createMemoryTool({ userId, tokenLimit, setMemory, validKeys, totalTokens });
const deleteMemoryTool = createDeleteMemoryTool({
userId,
validKeys,
deleteMemory,
});
const currentMemoryTokens = totalTokens;
let memoryStatus = `# Existing memory:\n${memory ?? 'No existing memories'}`;
if (tokenLimit) {
const remainingTokens = tokenLimit - currentMemoryTokens;
memoryStatus = `# Memory Status:
Current memory usage: ${currentMemoryTokens} tokens
Token limit: ${tokenLimit} tokens
Remaining capacity: ${remainingTokens} tokens
# Existing memory:
${memory ?? 'No existing memories'}`;
}
const defaultLLMConfig: LLMConfig = {
provider: Providers.OPENAI,
model: 'gpt-4.1-mini',
temperature: 0.4,
streaming: false,
disableStreaming: true,
};
const finalLLMConfig = {
...defaultLLMConfig,
...llmConfig,
/**
* Ensure streaming is always disabled for memory processing
*/
streaming: false,
disableStreaming: true,
};
const artifactPromises: Promise<TAttachment | null>[] = [];
const memoryCallback = createMemoryCallback({ res, artifactPromises });
const customHandlers = {
[GraphEvents.TOOL_END]: new BasicToolEndHandler(memoryCallback),
};
const run = await Run.create({
runId: messageId,
graphConfig: {
type: 'standard',
llmConfig: finalLLMConfig,
tools: [memoryTool, deleteMemoryTool],
instructions,
additional_instructions: memoryStatus,
toolEnd: true,
},
customHandlers,
returnContent: true,
});
const config = {
configurable: {
provider: llmConfig?.provider,
thread_id: `memory-run-${conversationId}`,
},
streamMode: 'values',
version: 'v2',
} as const;
const inputs = {
messages,
};
const content = await run.processStream(inputs, config);
if (content) {
logger.debug('Memory Agent processed memory successfully', content);
} else {
logger.warn('Memory Agent processed memory but returned no content');
}
return await Promise.all(artifactPromises);
} catch (error) {
logger.error('Memory Agent failed to process memory', error);
}
}
export async function createMemoryProcessor({
res,
userId,
messageId,
memoryMethods,
conversationId,
config = {},
}: {
res: ServerResponse;
messageId: string;
conversationId: string;
userId: string | ObjectId;
memoryMethods: RequiredMemoryMethods;
config?: MemoryConfig;
}): Promise<[string, (messages: BaseMessage[]) => Promise<(TAttachment | null)[] | undefined>]> {
const { validKeys, instructions, llmConfig, tokenLimit } = config;
const finalInstructions = instructions || getDefaultInstructions(validKeys, tokenLimit);
const { withKeys, withoutKeys, totalTokens } = await memoryMethods.getFormattedMemories({
userId,
});
return [
withoutKeys,
async function (messages: BaseMessage[]): Promise<(TAttachment | null)[] | undefined> {
try {
return await processMemory({
res,
userId,
messages,
validKeys,
llmConfig,
messageId,
tokenLimit,
conversationId,
memory: withKeys,
totalTokens: totalTokens || 0,
instructions: finalInstructions,
setMemory: memoryMethods.setMemory,
deleteMemory: memoryMethods.deleteMemory,
});
} catch (error) {
logger.error('Memory Agent failed to process memory', error);
}
},
];
}
async function handleMemoryArtifact({
res,
data,
metadata,
}: {
res: ServerResponse;
data: ToolEndData;
metadata?: ToolEndMetadata;
}) {
const output = data?.output;
if (!output) {
return null;
}
if (!output.artifact) {
return null;
}
const memoryArtifact = output.artifact[Tools.memory] as MemoryArtifact | undefined;
if (!memoryArtifact) {
return null;
}
const attachment: Partial<TAttachment> = {
type: Tools.memory,
toolCallId: output.tool_call_id,
messageId: metadata?.run_id ?? '',
conversationId: metadata?.thread_id ?? '',
[Tools.memory]: memoryArtifact,
};
if (!res.headersSent) {
return attachment;
}
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
return attachment;
}
/**
* Creates a memory callback for handling memory artifacts
* @param params - The parameters object
* @param params.res - The server response object
* @param params.artifactPromises - Array to collect artifact promises
* @returns The memory callback function
*/
export function createMemoryCallback({
res,
artifactPromises,
}: {
res: ServerResponse;
artifactPromises: Promise<Partial<TAttachment> | null>[];
}): ToolEndCallback {
return async (data: ToolEndData, metadata?: Record<string, unknown>) => {
const output = data?.output;
const memoryArtifact = output?.artifact?.[Tools.memory] as MemoryArtifact;
if (memoryArtifact == null) {
return;
}
artifactPromises.push(
handleMemoryArtifact({ res, data, metadata }).catch((error) => {
logger.error('Error processing memory artifact content:', error);
return null;
}),
);
};
}

(new file: primeResources unit tests)

@@ -0,0 +1,543 @@
import { primeResources } from './resources';
import { logger } from '@librechat/data-schemas';
import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-data-provider';
import type { Request as ServerRequest } from 'express';
import type { TFile } from 'librechat-data-provider';
import type { TGetFiles } from './resources';
// Mock logger
jest.mock('@librechat/data-schemas', () => ({
logger: {
error: jest.fn(),
},
}));
describe('primeResources', () => {
let mockReq: ServerRequest;
let mockGetFiles: jest.MockedFunction<TGetFiles>;
let requestFileSet: Set<string>;
beforeEach(() => {
// Reset mocks
jest.clearAllMocks();
// Setup mock request
mockReq = {
app: {
locals: {
[EModelEndpoint.agents]: {
capabilities: [AgentCapabilities.ocr],
},
},
},
} as unknown as ServerRequest;
// Setup mock getFiles function
mockGetFiles = jest.fn();
// Setup request file set
requestFileSet = new Set(['file1', 'file2', 'file3']);
});
describe('when OCR is enabled and tool_resources has OCR file_ids', () => {
it('should fetch OCR files and include them in attachments', async () => {
const mockOcrFiles: TFile[] = [
{
user: 'user1',
file_id: 'ocr-file-1',
filename: 'document.pdf',
filepath: '/uploads/document.pdf',
object: 'file',
type: 'application/pdf',
bytes: 1024,
embedded: false,
usage: 0,
},
];
mockGetFiles.mockResolvedValue(mockOcrFiles);
const tool_resources = {
[EToolResources.ocr]: {
file_ids: ['ocr-file-1'],
},
};
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments: undefined,
tool_resources,
});
expect(mockGetFiles).toHaveBeenCalledWith({ file_id: { $in: ['ocr-file-1'] } }, {}, {});
expect(result.attachments).toEqual(mockOcrFiles);
expect(result.tool_resources).toEqual(tool_resources);
});
});
describe('when OCR is disabled', () => {
it('should not fetch OCR files even if tool_resources has OCR file_ids', async () => {
(mockReq.app as ServerRequest['app']).locals[EModelEndpoint.agents].capabilities = [];
const tool_resources = {
[EToolResources.ocr]: {
file_ids: ['ocr-file-1'],
},
};
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments: undefined,
tool_resources,
});
expect(mockGetFiles).not.toHaveBeenCalled();
expect(result.attachments).toBeUndefined();
expect(result.tool_resources).toEqual(tool_resources);
});
});
describe('when attachments are provided', () => {
it('should process files with fileIdentifier as execute_code resources', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file1',
filename: 'script.py',
filepath: '/uploads/script.py',
object: 'file',
type: 'text/x-python',
bytes: 512,
embedded: false,
usage: 0,
metadata: {
fileIdentifier: 'python-script',
},
},
];
const attachments = Promise.resolve(mockFiles);
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: {},
});
expect(result.attachments).toEqual(mockFiles);
expect(result.tool_resources?.[EToolResources.execute_code]?.files).toEqual(mockFiles);
});
it('should process embedded files as file_search resources', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file2',
filename: 'document.txt',
filepath: '/uploads/document.txt',
object: 'file',
type: 'text/plain',
bytes: 256,
embedded: true,
usage: 0,
},
];
const attachments = Promise.resolve(mockFiles);
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: {},
});
expect(result.attachments).toEqual(mockFiles);
expect(result.tool_resources?.[EToolResources.file_search]?.files).toEqual(mockFiles);
});
it('should process image files in requestFileSet as image_edit resources', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file1',
filename: 'image.png',
filepath: '/uploads/image.png',
object: 'file',
type: 'image/png',
bytes: 2048,
embedded: false,
usage: 0,
height: 800,
width: 600,
},
];
const attachments = Promise.resolve(mockFiles);
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: {},
});
expect(result.attachments).toEqual(mockFiles);
expect(result.tool_resources?.[EToolResources.image_edit]?.files).toEqual(mockFiles);
});
it('should not process image files not in requestFileSet', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file-not-in-set',
filename: 'image.png',
filepath: '/uploads/image.png',
object: 'file',
type: 'image/png',
bytes: 2048,
embedded: false,
usage: 0,
height: 800,
width: 600,
},
];
const attachments = Promise.resolve(mockFiles);
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: {},
});
expect(result.attachments).toEqual(mockFiles);
expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
});
it('should not process image files without height and width', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file1',
filename: 'image.png',
filepath: '/uploads/image.png',
object: 'file',
type: 'image/png',
bytes: 2048,
embedded: false,
usage: 0,
// Missing height and width
},
];
const attachments = Promise.resolve(mockFiles);
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: {},
});
expect(result.attachments).toEqual(mockFiles);
expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
});
it('should filter out null files from attachments', async () => {
const mockFiles: Array<TFile | null> = [
{
user: 'user1',
file_id: 'file1',
filename: 'valid.txt',
filepath: '/uploads/valid.txt',
object: 'file',
type: 'text/plain',
bytes: 256,
embedded: false,
usage: 0,
},
null,
{
user: 'user1',
file_id: 'file2',
filename: 'valid2.txt',
filepath: '/uploads/valid2.txt',
object: 'file',
type: 'text/plain',
bytes: 128,
embedded: false,
usage: 0,
},
];
const attachments = Promise.resolve(mockFiles);
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: {},
});
expect(result.attachments).toHaveLength(2);
expect(result.attachments?.[0]?.file_id).toBe('file1');
expect(result.attachments?.[1]?.file_id).toBe('file2');
});
it('should merge existing tool_resources with new files', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file1',
filename: 'script.py',
filepath: '/uploads/script.py',
object: 'file',
type: 'text/x-python',
bytes: 512,
embedded: false,
usage: 0,
metadata: {
fileIdentifier: 'python-script',
},
},
];
const existingToolResources = {
[EToolResources.execute_code]: {
files: [
{
user: 'user1',
file_id: 'existing-file',
filename: 'existing.py',
filepath: '/uploads/existing.py',
object: 'file' as const,
type: 'text/x-python',
bytes: 256,
embedded: false,
usage: 0,
},
],
},
};
const attachments = Promise.resolve(mockFiles);
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: existingToolResources,
});
expect(result.tool_resources?.[EToolResources.execute_code]?.files).toHaveLength(2);
expect(result.tool_resources?.[EToolResources.execute_code]?.files?.[0]?.file_id).toBe(
'existing-file',
);
expect(result.tool_resources?.[EToolResources.execute_code]?.files?.[1]?.file_id).toBe(
'file1',
);
});
});
describe('when both OCR and attachments are provided', () => {
it('should include both OCR files and attachment files', async () => {
const mockOcrFiles: TFile[] = [
{
user: 'user1',
file_id: 'ocr-file-1',
filename: 'document.pdf',
filepath: '/uploads/document.pdf',
object: 'file',
type: 'application/pdf',
bytes: 1024,
embedded: false,
usage: 0,
},
];
const mockAttachmentFiles: TFile[] = [
{
user: 'user1',
file_id: 'file1',
filename: 'attachment.txt',
filepath: '/uploads/attachment.txt',
object: 'file',
type: 'text/plain',
bytes: 256,
embedded: false,
usage: 0,
},
];
mockGetFiles.mockResolvedValue(mockOcrFiles);
const attachments = Promise.resolve(mockAttachmentFiles);
const tool_resources = {
[EToolResources.ocr]: {
file_ids: ['ocr-file-1'],
},
};
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources,
});
expect(result.attachments).toHaveLength(2);
expect(result.attachments?.[0]?.file_id).toBe('ocr-file-1');
expect(result.attachments?.[1]?.file_id).toBe('file1');
});
});
describe('error handling', () => {
it('should handle errors gracefully and log them', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file1',
filename: 'test.txt',
filepath: '/uploads/test.txt',
object: 'file',
type: 'text/plain',
bytes: 256,
embedded: false,
usage: 0,
},
];
const attachments = Promise.resolve(mockFiles);
const error = new Error('Test error');
// Mock getFiles to throw an error when called for OCR
mockGetFiles.mockRejectedValue(error);
const tool_resources = {
[EToolResources.ocr]: {
file_ids: ['ocr-file-1'],
},
};
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources,
});
expect(logger.error).toHaveBeenCalledWith('Error priming resources', error);
expect(result.attachments).toEqual(mockFiles);
expect(result.tool_resources).toEqual(tool_resources);
});
it('should handle promise rejection in attachments', async () => {
const error = new Error('Attachment error');
const attachments = Promise.reject(error);
// The function should now handle rejected attachment promises gracefully
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments,
tool_resources: {},
});
// Should log both the main error and the attachment error
expect(logger.error).toHaveBeenCalledWith('Error priming resources', error);
expect(logger.error).toHaveBeenCalledWith(
'Error resolving attachments in catch block',
error,
);
// Should return empty array when attachments promise is rejected
expect(result.attachments).toEqual([]);
expect(result.tool_resources).toEqual({});
});
});
describe('edge cases', () => {
it('should handle missing app.locals gracefully', async () => {
const reqWithoutLocals = {} as ServerRequest;
const result = await primeResources({
req: reqWithoutLocals,
getFiles: mockGetFiles,
requestFileSet,
attachments: undefined,
tool_resources: {
[EToolResources.ocr]: {
file_ids: ['ocr-file-1'],
},
},
});
expect(mockGetFiles).not.toHaveBeenCalled();
// When app.locals is missing and there's an error accessing properties,
// the function falls back to the catch block which returns an empty array
expect(result.attachments).toEqual([]);
});
it('should handle undefined tool_resources', async () => {
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet,
attachments: undefined,
tool_resources: undefined,
});
expect(result.tool_resources).toEqual({});
expect(result.attachments).toBeUndefined();
});
it('should handle empty requestFileSet', async () => {
const mockFiles: TFile[] = [
{
user: 'user1',
file_id: 'file1',
filename: 'image.png',
filepath: '/uploads/image.png',
object: 'file',
type: 'image/png',
bytes: 2048,
embedded: false,
usage: 0,
height: 800,
width: 600,
},
];
const attachments = Promise.resolve(mockFiles);
const emptyRequestFileSet = new Set<string>();
const result = await primeResources({
req: mockReq,
getFiles: mockGetFiles,
requestFileSet: emptyRequestFileSet,
attachments,
tool_resources: {},
});
expect(result.attachments).toEqual(mockFiles);
expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
});
});
});


@@ -0,0 +1,114 @@
import { logger } from '@librechat/data-schemas';
import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-data-provider';
import type { FilterQuery, QueryOptions, ProjectionType } from 'mongoose';
import type { AgentToolResources, TFile } from 'librechat-data-provider';
import type { IMongoFile } from '@librechat/data-schemas';
import type { Request as ServerRequest } from 'express';
export type TGetFiles = (
filter: FilterQuery<IMongoFile>,
_sortOptions: ProjectionType<IMongoFile> | null | undefined,
selectFields: QueryOptions<IMongoFile> | null | undefined,
) => Promise<Array<TFile>>;
/**
 * Primes file resources for an agent run: loads OCR context files and routes
 * attachments into the appropriate `tool_resources` buckets.
 * @param params
 * @param params.req - The server request, used to read agent capabilities from `app.locals`
 * @param params.getFiles - Function to fetch file documents matching a query
 * @param params.requestFileSet - Set of file IDs attached to the current request
 * @param params.attachments - Promise resolving to attachment files, if any
 * @param params.tool_resources - Existing tool resources to merge new files into
 */
export const primeResources = async ({
req,
getFiles,
requestFileSet,
attachments: _attachments,
tool_resources: _tool_resources,
}: {
req: ServerRequest;
requestFileSet: Set<string>;
attachments: Promise<Array<TFile | null>> | undefined;
tool_resources: AgentToolResources | undefined;
getFiles: TGetFiles;
}): Promise<{
attachments: Array<TFile | undefined> | undefined;
tool_resources: AgentToolResources | undefined;
}> => {
try {
let attachments: Array<TFile | undefined> | undefined;
const tool_resources = _tool_resources ?? {};
const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
AgentCapabilities.ocr,
);
if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
const context = await getFiles(
{
file_id: { $in: tool_resources.ocr.file_ids },
},
{},
{},
);
attachments = (attachments ?? []).concat(context);
}
if (!_attachments) {
return { attachments, tool_resources };
}
const files = await _attachments;
if (!attachments) {
attachments = [];
}
for (const file of files) {
if (!file) {
continue;
}
if (file.metadata?.fileIdentifier) {
const execute_code = tool_resources[EToolResources.execute_code] ?? {};
if (!execute_code.files) {
tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] };
}
tool_resources[EToolResources.execute_code]?.files?.push(file);
} else if (file.embedded === true) {
const file_search = tool_resources[EToolResources.file_search] ?? {};
if (!file_search.files) {
tool_resources[EToolResources.file_search] = { ...file_search, files: [] };
}
tool_resources[EToolResources.file_search]?.files?.push(file);
} else if (
requestFileSet.has(file.file_id) &&
file.type.startsWith('image') &&
file.height &&
file.width
) {
const image_edit = tool_resources[EToolResources.image_edit] ?? {};
if (!image_edit.files) {
tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] };
}
tool_resources[EToolResources.image_edit]?.files?.push(file);
}
attachments.push(file);
}
return { attachments, tool_resources };
} catch (error) {
logger.error('Error priming resources', error);
// Safely try to get attachments without rethrowing
let safeAttachments: Array<TFile | undefined> = [];
if (_attachments) {
try {
const attachmentFiles = await _attachments;
safeAttachments = (attachmentFiles?.filter((file) => !!file) ?? []) as Array<TFile>;
} catch (attachmentError) {
// If attachments promise is also rejected, just use empty array
logger.error('Error resolving attachments in catch block', attachmentError);
safeAttachments = [];
}
}
return {
attachments: safeAttachments,
tool_resources: _tool_resources,
};
}
};
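The routing loop above can be condensed into a standalone sketch. The `SketchFile` shape and the string bucket names below are simplified stand-ins for `TFile` and `EToolResources`, not the actual types:

```typescript
// Minimal sketch of primeResources' file routing, with local stand-in types.
interface SketchFile {
  file_id: string;
  type: string;
  embedded?: boolean;
  height?: number;
  width?: number;
  metadata?: { fileIdentifier?: string };
}

type Buckets = Record<string, { files: SketchFile[] }>;

function routeFiles(files: Array<SketchFile | null>, requestFileSet: Set<string>): Buckets {
  const buckets: Buckets = {};
  const push = (name: string, file: SketchFile) => {
    (buckets[name] ??= { files: [] }).files.push(file);
  };
  for (const file of files) {
    if (!file) continue; // null entries are skipped, as in the source
    if (file.metadata?.fileIdentifier) {
      push('execute_code', file); // code files carry a fileIdentifier
    } else if (file.embedded === true) {
      push('file_search', file); // embedded files go to file search
    } else if (
      requestFileSet.has(file.file_id) &&
      file.type.startsWith('image') &&
      file.height &&
      file.width
    ) {
      push('image_edit', file); // request-scoped images with dimensions
    }
  }
  return buckets;
}
```

Note that an image missing `height` or `width`, or not present in `requestFileSet`, falls through all three branches and is only kept in the flat attachments list, which matches the test expectations above.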


@@ -0,0 +1,90 @@
import { Run, Providers } from '@librechat/agents';
import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider';
import type { StandardGraphConfig, EventHandler, GraphEvents, IState } from '@librechat/agents';
import type { Agent } from 'librechat-data-provider';
import type * as t from '~/types';
const customProviders = new Set([
Providers.XAI,
Providers.OLLAMA,
Providers.DEEPSEEK,
Providers.OPENROUTER,
]);
/**
* Creates a new Run instance with custom handlers and configuration.
*
* @param options - The options for creating the Run instance.
* @param options.agent - The agent for this run.
* @param options.signal - The signal for this run.
* @param options.runId - Optional run ID; otherwise, a new run ID will be generated.
* @param options.customHandlers - Custom event handlers.
* @param options.streaming - Whether to use streaming.
* @param options.streamUsage - Whether to stream usage information.
* @returns {Promise<Run<IState>>} A promise that resolves to a new Run instance.
*/
export async function createRun({
runId,
agent,
signal,
customHandlers,
streaming = true,
streamUsage = true,
}: {
agent: Agent;
signal: AbortSignal;
runId?: string;
streaming?: boolean;
streamUsage?: boolean;
customHandlers?: Record<GraphEvents, EventHandler>;
}): Promise<Run<IState>> {
const provider =
providerEndpointMap[agent.provider as keyof typeof providerEndpointMap] ?? agent.provider;
const llmConfig: t.RunLLMConfig = Object.assign(
{
provider,
streaming,
streamUsage,
},
agent.model_parameters,
);
/** Resolves issues with new OpenAI usage field */
if (
customProviders.has(agent.provider) ||
(agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
) {
llmConfig.streamUsage = false;
llmConfig.usage = true;
}
let reasoningKey: 'reasoning_content' | 'reasoning' | undefined;
if (
llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
(agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
) {
reasoningKey = 'reasoning';
}
const graphConfig: StandardGraphConfig = {
signal,
llmConfig,
reasoningKey,
tools: agent.tools,
instructions: agent.instructions,
additional_instructions: agent.additional_instructions,
// toolEnd: agent.end_after_tools,
};
// TEMPORARY FOR TESTING
if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
graphConfig.streamBuffer = 2000;
}
return Run.create({
runId,
graphConfig,
customHandlers,
});
}
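The OpenRouter detection above can be sketched in isolation: the reasoning key switches to `'reasoning'` when either the configured base URL or the endpoint name mentions OpenRouter. The literal `'openrouter'` below stands in for `KnownEndpoints.openrouter`:

```typescript
// Sketch of the reasoningKey selection in createRun.
type ReasoningKey = 'reasoning_content' | 'reasoning' | undefined;

function pickReasoningKey(baseURL?: string, endpoint?: string): ReasoningKey {
  const marker = 'openrouter'; // stand-in for KnownEndpoints.openrouter
  if (baseURL?.includes(marker) || endpoint?.toLowerCase().includes(marker)) {
    return 'reasoning';
  }
  return undefined; // graph falls back to its default reasoning field
}
```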


@@ -0,0 +1 @@
export * from './openai';


@@ -0,0 +1,2 @@
export * from './llm';
export * from './initialize';


@@ -0,0 +1,176 @@
import {
ErrorTypes,
EModelEndpoint,
resolveHeaders,
mapModelToAzureConfig,
} from 'librechat-data-provider';
import type {
LLMConfigOptions,
UserKeyValues,
InitializeOpenAIOptionsParams,
OpenAIOptionsResult,
} from '~/types';
import { createHandleLLMNewToken } from '~/utils/generators';
import { getAzureCredentials } from '~/utils/azure';
import { isUserProvided } from '~/utils/common';
import { getOpenAIConfig } from './llm';
/**
* Initializes OpenAI options for agent usage. This function always returns configuration
* options and never creates a client instance (equivalent to optionsOnly=true behavior).
*
* @param params - Configuration parameters
* @returns Promise resolving to OpenAI configuration options
* @throws Error if API key is missing or user key has expired
*/
export const initializeOpenAI = async ({
req,
overrideModel,
endpointOption,
overrideEndpoint,
getUserKeyValues,
checkUserKeyExpiry,
}: InitializeOpenAIOptionsParams): Promise<OpenAIOptionsResult> => {
const { PROXY, OPENAI_API_KEY, AZURE_API_KEY, OPENAI_REVERSE_PROXY, AZURE_OPENAI_BASEURL } =
process.env;
const { key: expiresAt } = req.body;
const modelName = overrideModel ?? req.body.model;
const endpoint = overrideEndpoint ?? req.body.endpoint;
if (!endpoint) {
throw new Error('Endpoint is required');
}
const credentials = {
[EModelEndpoint.openAI]: OPENAI_API_KEY,
[EModelEndpoint.azureOpenAI]: AZURE_API_KEY,
};
const baseURLOptions = {
[EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY,
[EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL,
};
const userProvidesKey = isUserProvided(credentials[endpoint as keyof typeof credentials]);
const userProvidesURL = isUserProvided(baseURLOptions[endpoint as keyof typeof baseURLOptions]);
let userValues: UserKeyValues | null = null;
if (expiresAt && (userProvidesKey || userProvidesURL)) {
checkUserKeyExpiry(expiresAt, endpoint);
userValues = await getUserKeyValues({ userId: req.user.id, name: endpoint });
}
let apiKey = userProvidesKey
? userValues?.apiKey
: credentials[endpoint as keyof typeof credentials];
const baseURL = userProvidesURL
? userValues?.baseURL
: baseURLOptions[endpoint as keyof typeof baseURLOptions];
const clientOptions: LLMConfigOptions = {
proxy: PROXY ?? undefined,
reverseProxyUrl: baseURL || undefined,
streaming: true,
};
const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI;
const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI];
if (isAzureOpenAI && azureConfig) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL: configBaseURL,
headers = {},
serverless,
} = mapModelToAzureConfig({
modelName: modelName || '',
modelGroupMap,
groupMap,
});
clientOptions.reverseProxyUrl = configBaseURL ?? clientOptions.reverseProxyUrl;
clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) });
const groupName = modelGroupMap[modelName || '']?.group;
if (groupName && groupMap[groupName]) {
clientOptions.addParams = groupMap[groupName]?.addParams;
clientOptions.dropParams = groupMap[groupName]?.dropParams;
}
apiKey = azureOptions.azureOpenAIApiKey;
clientOptions.azure = !serverless ? azureOptions : undefined;
if (serverless === true) {
clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
if (!clientOptions.headers) {
clientOptions.headers = {};
}
clientOptions.headers['api-key'] = apiKey;
}
} else if (isAzureOpenAI) {
clientOptions.azure =
userProvidesKey && userValues?.apiKey ? JSON.parse(userValues.apiKey) : getAzureCredentials();
apiKey = clientOptions.azure?.azureOpenAIApiKey;
}
if (userProvidesKey && !apiKey) {
throw new Error(
JSON.stringify({
type: ErrorTypes.NO_USER_KEY,
}),
);
}
if (!apiKey) {
throw new Error(`${endpoint} API Key not provided.`);
}
const modelOptions = {
...endpointOption.model_parameters,
model: modelName,
user: req.user.id,
};
const finalClientOptions: LLMConfigOptions = {
...clientOptions,
modelOptions,
};
const options = getOpenAIConfig(apiKey, finalClientOptions, endpoint);
const openAIConfig = req.app.locals[EModelEndpoint.openAI];
const allConfig = req.app.locals.all;
const azureRate = modelName?.includes('gpt-4') ? 30 : 17;
let streamRate: number | undefined;
if (isAzureOpenAI && azureConfig) {
streamRate = azureConfig.streamRate ?? azureRate;
} else if (!isAzureOpenAI && openAIConfig) {
streamRate = openAIConfig.streamRate;
}
if (allConfig?.streamRate) {
streamRate = allConfig.streamRate;
}
if (streamRate) {
options.llmConfig.callbacks = [
{
handleLLMNewToken: createHandleLLMNewToken(streamRate),
},
];
}
const result: OpenAIOptionsResult = {
...options,
streamRate,
};
return result;
};
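The stream-rate precedence at the end of `initializeOpenAI` can be isolated as a small sketch: the endpoint-specific config wins over the Azure model-based default (30 for `gpt-4` models, 17 otherwise, mirroring the source), and an `all` config overrides everything. The config shapes are simplified here:

```typescript
// Sketch of the streamRate precedence in initializeOpenAI.
interface RateConfig {
  streamRate?: number;
}

function resolveStreamRate(
  isAzure: boolean,
  modelName: string | undefined,
  azureConfig?: RateConfig,
  openAIConfig?: RateConfig,
  allConfig?: RateConfig,
): number | undefined {
  const azureDefault = modelName?.includes('gpt-4') ? 30 : 17;
  let streamRate: number | undefined;
  if (isAzure && azureConfig) {
    streamRate = azureConfig.streamRate ?? azureDefault;
  } else if (!isAzure && openAIConfig) {
    streamRate = openAIConfig.streamRate;
  }
  if (allConfig?.streamRate) {
    streamRate = allConfig.streamRate; // `all` config takes final precedence
  }
  return streamRate;
}
```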


@@ -0,0 +1,156 @@
import { HttpsProxyAgent } from 'https-proxy-agent';
import { KnownEndpoints } from 'librechat-data-provider';
import type * as t from '~/types';
import { sanitizeModelName, constructAzureURL } from '~/utils/azure';
import { isEnabled } from '~/utils/common';
/**
* Generates configuration options for creating a language model (LLM) instance.
* @param apiKey - The API key for authentication.
* @param options - Additional options for configuring the LLM.
* @param endpoint - The endpoint name
* @returns Configuration options for creating an LLM instance.
*/
export function getOpenAIConfig(
apiKey: string,
options: t.LLMConfigOptions = {},
endpoint?: string | null,
): t.LLMConfigResult {
const {
modelOptions = {},
reverseProxyUrl,
defaultQuery,
headers,
proxy,
azure,
streaming = true,
addParams,
dropParams,
} = options;
const llmConfig: Partial<t.ClientOptions> & Partial<t.OpenAIParameters> = Object.assign(
{
streaming,
model: modelOptions.model ?? '',
},
modelOptions,
);
if (addParams && typeof addParams === 'object') {
Object.assign(llmConfig, addParams);
}
// Note: OpenAI Web Search models do not support any known parameters besides `max_tokens`
if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
const searchExcludeParams = [
'frequency_penalty',
'presence_penalty',
'temperature',
'top_p',
'top_k',
'stop',
'logit_bias',
'seed',
'response_format',
'n',
'logprobs',
'user',
];
const updatedDropParams = dropParams || [];
const combinedDropParams = [...new Set([...updatedDropParams, ...searchExcludeParams])];
combinedDropParams.forEach((param) => {
if (param in llmConfig) {
delete llmConfig[param as keyof t.ClientOptions];
}
});
} else if (dropParams && Array.isArray(dropParams)) {
dropParams.forEach((param) => {
if (param in llmConfig) {
delete llmConfig[param as keyof t.ClientOptions];
}
});
}
let useOpenRouter = false;
const configOptions: t.OpenAIConfiguration = {};
if (
(reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
(endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
) {
useOpenRouter = true;
llmConfig.include_reasoning = true;
configOptions.baseURL = reverseProxyUrl;
configOptions.defaultHeaders = Object.assign(
{
'HTTP-Referer': 'https://librechat.ai',
'X-Title': 'LibreChat',
},
headers,
);
} else if (reverseProxyUrl) {
configOptions.baseURL = reverseProxyUrl;
if (headers) {
configOptions.defaultHeaders = headers;
}
}
if (defaultQuery) {
configOptions.defaultQuery = defaultQuery;
}
if (proxy) {
const proxyAgent = new HttpsProxyAgent(proxy);
configOptions.httpAgent = proxyAgent;
}
if (azure) {
const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME);
const updatedAzure = { ...azure };
updatedAzure.azureOpenAIApiDeploymentName = useModelName
? sanitizeModelName(llmConfig.model || '')
: azure.azureOpenAIApiDeploymentName;
if (process.env.AZURE_OPENAI_DEFAULT_MODEL) {
llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
}
if (configOptions.baseURL) {
const azureURL = constructAzureURL({
baseURL: configOptions.baseURL,
azureOptions: updatedAzure,
});
updatedAzure.azureOpenAIBasePath = azureURL.split(
`/${updatedAzure.azureOpenAIApiDeploymentName}`,
)[0];
}
Object.assign(llmConfig, updatedAzure);
llmConfig.model = updatedAzure.azureOpenAIApiDeploymentName;
} else {
llmConfig.apiKey = apiKey;
}
if (process.env.OPENAI_ORGANIZATION && azure) {
configOptions.organization = process.env.OPENAI_ORGANIZATION;
}
if (useOpenRouter && llmConfig.reasoning_effort != null) {
llmConfig.reasoning = {
effort: llmConfig.reasoning_effort,
};
delete llmConfig.reasoning_effort;
}
if (llmConfig.max_tokens != null) {
llmConfig.maxTokens = llmConfig.max_tokens;
delete llmConfig.max_tokens;
}
return {
llmConfig,
configOptions,
};
}
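The search-model parameter filtering above follows a simple rule: for `gpt-4o*search` models, caller-supplied `dropParams` are merged with a fixed exclusion list before deletion; otherwise only `dropParams` apply. A sketch with a shortened exclusion list (the source's full list covers more parameters):

```typescript
// Sketch of the dropParams handling in getOpenAIConfig.
// Abbreviated exclusion list; the source names several more parameters.
const SEARCH_EXCLUDE = ['temperature', 'top_p', 'frequency_penalty', 'presence_penalty'];

function filterParams(
  config: Record<string, unknown>,
  model: string,
  dropParams: string[] = [],
): Record<string, unknown> {
  const isSearchModel = /gpt-4o.*search/.test(model);
  const toDrop = isSearchModel
    ? [...new Set([...dropParams, ...SEARCH_EXCLUDE])] // merged, deduplicated
    : dropParams;
  const out = { ...config };
  for (const param of toDrop) {
    if (param in out) {
      delete out[param];
    }
  }
  return out;
}
```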


@@ -0,0 +1,152 @@
import { FlowStateManager } from './manager';
import { Keyv } from 'keyv';
import type { FlowState } from './types';
// Create a mock class without extending Keyv
class MockKeyv {
private store: Map<string, FlowState<string>>;
constructor() {
this.store = new Map();
}
async get(key: string): Promise<FlowState<string> | undefined> {
return this.store.get(key);
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
async set(key: string, value: FlowState<string>, _ttl?: number): Promise<true> {
this.store.set(key, value);
return true;
}
async delete(key: string): Promise<boolean> {
return this.store.delete(key);
}
}
describe('FlowStateManager', () => {
let flowManager: FlowStateManager<string>;
let store: MockKeyv;
beforeEach(() => {
store = new MockKeyv();
// Type assertion here since we know our mock implements the necessary methods
flowManager = new FlowStateManager(store as unknown as Keyv, { ttl: 30000, ci: true });
});
afterEach(() => {
jest.clearAllMocks();
});
describe('Concurrency Tests', () => {
it('should handle concurrent flow creation and return same result', async () => {
const flowId = 'test-flow';
const type = 'test-type';
// Start two concurrent flow creations
const flow1Promise = flowManager.createFlowWithHandler(flowId, type, async () => {
await new Promise((resolve) => setTimeout(resolve, 100));
return 'result';
});
const flow2Promise = flowManager.createFlowWithHandler(flowId, type, async () => {
await new Promise((resolve) => setTimeout(resolve, 50));
return 'different-result';
});
// Both should resolve to the same result from the first handler
const [result1, result2] = await Promise.all([flow1Promise, flow2Promise]);
expect(result1).toBe('result');
expect(result2).toBe('result');
});
it('should handle flow timeout correctly', async () => {
const flowId = 'timeout-flow';
const type = 'test-type';
// Create flow with very short TTL
const shortTtlManager = new FlowStateManager(store as unknown as Keyv, {
ttl: 100,
ci: true,
});
const flowPromise = shortTtlManager.createFlow(flowId, type);
await expect(flowPromise).rejects.toThrow('test-type flow timed out');
});
it('should maintain flow state consistency under high concurrency', async () => {
const flowId = 'concurrent-flow';
const type = 'test-type';
// Create multiple concurrent operations
const operations = [];
for (let i = 0; i < 10; i++) {
operations.push(
flowManager.createFlowWithHandler(flowId, type, async () => {
await new Promise((resolve) => setTimeout(resolve, Math.random() * 50));
return `result-${i}`;
}),
);
}
// All operations should resolve to the same result
const results = await Promise.all(operations);
const firstResult = results[0];
results.forEach((result: string) => {
expect(result).toBe(firstResult);
});
});
it('should handle race conditions in flow completion', async () => {
const flowId = 'test-flow';
const type = 'test-type';
// Create initial flow
const flowPromise = flowManager.createFlow(flowId, type);
// Increase delay to ensure flow is properly created
await new Promise((resolve) => setTimeout(resolve, 500));
// Complete the flow
await flowManager.completeFlow(flowId, type, 'result1');
const result = await flowPromise;
expect(result).toBe('result1');
}, 15000);
it('should handle concurrent flow monitoring', async () => {
const flowId = 'test-flow';
const type = 'test-type';
// Create initial flow
const flowPromise = flowManager.createFlow(flowId, type);
// Increase delay
await new Promise((resolve) => setTimeout(resolve, 500));
// Complete the flow
await flowManager.completeFlow(flowId, type, 'success');
const result = await flowPromise;
expect(result).toBe('success');
}, 15000);
it('should handle concurrent success and failure attempts', async () => {
const flowId = 'race-flow';
const type = 'test-type';
const flowPromise = flowManager.createFlow(flowId, type);
// Increase delay
await new Promise((resolve) => setTimeout(resolve, 500));
// Fail the flow
await flowManager.failFlow(flowId, type, new Error('failure'));
await expect(flowPromise).rejects.toThrow('failure');
}, 15000);
});
});


@@ -0,0 +1,257 @@
import { Keyv } from 'keyv';
import type { StoredDataNoRaw } from 'keyv';
import type { Logger } from 'winston';
import type { FlowState, FlowMetadata, FlowManagerOptions } from './types';
export class FlowStateManager<T = unknown> {
private keyv: Keyv;
private ttl: number;
private logger: Logger;
private intervals: Set<NodeJS.Timeout>;
private static getDefaultLogger(): Logger {
return {
error: console.error,
warn: console.warn,
info: console.info,
debug: console.debug,
} as Logger;
}
constructor(store: Keyv, options?: FlowManagerOptions) {
if (!options) {
options = { ttl: 60000 * 3 };
}
const { ci = false, ttl, logger } = options;
if (!ci && !(store instanceof Keyv)) {
throw new Error('Invalid store provided to FlowStateManager');
}
this.ttl = ttl;
this.keyv = store;
this.logger = logger || FlowStateManager.getDefaultLogger();
this.intervals = new Set();
this.setupCleanupHandlers();
}
private setupCleanupHandlers() {
const cleanup = () => {
this.logger.info('Cleaning up FlowStateManager intervals...');
this.intervals.forEach((interval) => clearInterval(interval));
this.intervals.clear();
process.exit(0);
};
process.on('SIGTERM', cleanup);
process.on('SIGINT', cleanup);
process.on('SIGQUIT', cleanup);
process.on('SIGHUP', cleanup);
}
private getFlowKey(flowId: string, type: string): string {
return `${type}:${flowId}`;
}
/**
* Creates a new flow and waits for its completion
*/
async createFlow(
flowId: string,
type: string,
metadata: FlowMetadata = {},
signal?: AbortSignal,
): Promise<T> {
const flowKey = this.getFlowKey(flowId, type);
let existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (existingState) {
this.logger.debug(`[${flowKey}] Flow already exists`);
return this.monitorFlow(flowKey, type, signal);
}
await new Promise((resolve) => setTimeout(resolve, 250));
existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (existingState) {
this.logger.debug(`[${flowKey}] Flow exists on 2nd check`);
return this.monitorFlow(flowKey, type, signal);
}
const initialState: FlowState = {
type,
status: 'PENDING',
metadata,
createdAt: Date.now(),
};
this.logger.debug('Creating initial flow state:', flowKey);
await this.keyv.set(flowKey, initialState, this.ttl);
return this.monitorFlow(flowKey, type, signal);
}
private monitorFlow(flowKey: string, type: string, signal?: AbortSignal): Promise<T> {
return new Promise<T>((resolve, reject) => {
const checkInterval = 2000;
let elapsedTime = 0;
const intervalId = setInterval(async () => {
try {
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (!flowState) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
this.logger.error(`[${flowKey}] Flow state not found`);
reject(new Error(`${type} Flow state not found`));
return;
}
if (signal?.aborted) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
this.logger.warn(`[${flowKey}] Flow aborted`);
const message = `${type} flow aborted`;
await this.keyv.delete(flowKey);
reject(new Error(message));
return;
}
if (flowState.status !== 'PENDING') {
clearInterval(intervalId);
this.intervals.delete(intervalId);
this.logger.debug(`[${flowKey}] Flow completed`);
if (flowState.status === 'COMPLETED' && flowState.result !== undefined) {
resolve(flowState.result);
} else if (flowState.status === 'FAILED') {
await this.keyv.delete(flowKey);
reject(new Error(flowState.error ?? `${type} flow failed`));
}
return;
}
elapsedTime += checkInterval;
if (elapsedTime >= this.ttl) {
clearInterval(intervalId);
this.intervals.delete(intervalId);
this.logger.error(
`[${flowKey}] Flow timed out | Elapsed time: ${elapsedTime} | TTL: ${this.ttl}`,
);
await this.keyv.delete(flowKey);
reject(new Error(`${type} flow timed out`));
}
this.logger.debug(
`[${flowKey}] Flow state elapsed time: ${elapsedTime}, checking again...`,
);
} catch (error) {
this.logger.error(`[${flowKey}] Error checking flow state:`, error);
clearInterval(intervalId);
this.intervals.delete(intervalId);
reject(error);
}
}, checkInterval);
this.intervals.add(intervalId);
});
}
/**
* Completes a flow successfully
*/
async completeFlow(flowId: string, type: string, result: T): Promise<boolean> {
const flowKey = this.getFlowKey(flowId, type);
const flowState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (!flowState) {
return false;
}
const updatedState: FlowState<T> = {
...flowState,
status: 'COMPLETED',
result,
completedAt: Date.now(),
};
await this.keyv.set(flowKey, updatedState, this.ttl);
return true;
}
/**
* Marks a flow as failed
*/
async failFlow(flowId: string, type: string, error: Error | string): Promise<boolean> {
const flowKey = this.getFlowKey(flowId, type);
const flowState = (await this.keyv.get(flowKey)) as FlowState | undefined;
if (!flowState) {
return false;
}
const updatedState: FlowState = {
...flowState,
status: 'FAILED',
error: error instanceof Error ? error.message : error,
failedAt: Date.now(),
};
await this.keyv.set(flowKey, updatedState, this.ttl);
return true;
}
/**
* Gets current flow state
*/
async getFlowState(flowId: string, type: string): Promise<StoredDataNoRaw<FlowState<T>> | null> {
const flowKey = this.getFlowKey(flowId, type);
return this.keyv.get(flowKey);
}
/**
* Creates a new flow and waits for its completion, only executing the handler if no existing flow is found
* @param flowId - The ID of the flow
* @param type - The type of flow
* @param handler - Async function to execute if no existing flow is found
* @param signal - Optional AbortSignal to cancel the flow
*/
async createFlowWithHandler(
flowId: string,
type: string,
handler: () => Promise<T>,
signal?: AbortSignal,
): Promise<T> {
const flowKey = this.getFlowKey(flowId, type);
let existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (existingState) {
this.logger.debug(`[${flowKey}] Flow already exists`);
return this.monitorFlow(flowKey, type, signal);
}
await new Promise((resolve) => setTimeout(resolve, 250));
existingState = (await this.keyv.get(flowKey)) as FlowState<T> | undefined;
if (existingState) {
this.logger.debug(`[${flowKey}] Flow exists on 2nd check`);
return this.monitorFlow(flowKey, type, signal);
}
const initialState: FlowState = {
type,
status: 'PENDING',
metadata: {},
createdAt: Date.now(),
};
this.logger.debug(`[${flowKey}] Creating initial flow state`);
await this.keyv.set(flowKey, initialState, this.ttl);
try {
const result = await handler();
await this.completeFlow(flowId, type, result);
return result;
} catch (error) {
await this.failFlow(flowId, type, error instanceof Error ? error : new Error(String(error)));
throw error;
}
}
}
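The core contract of `createFlowWithHandler` — concurrent callers with the same flow key share the first caller's result — can be illustrated with a minimal in-memory version. This is a sketch of the deduplication idea only; the class above additionally persists state in Keyv, double-checks after a short delay, and polls with a TTL:

```typescript
// Minimal in-memory illustration of the createFlowWithHandler contract.
class InMemoryFlowDedup<T> {
  private pending = new Map<string, Promise<T>>();

  createFlowWithHandler(flowId: string, type: string, handler: () => Promise<T>): Promise<T> {
    const key = `${type}:${flowId}`;
    const existing = this.pending.get(key);
    if (existing) {
      return existing; // later callers monitor the first caller's flow
    }
    const promise = handler().finally(() => this.pending.delete(key));
    this.pending.set(key, promise);
    return promise;
  }
}
```

Under this contract the second handler never executes, which is exactly what the concurrency tests above assert.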


@@ -0,0 +1,23 @@
import type { Logger } from 'winston';
export type FlowStatus = 'PENDING' | 'COMPLETED' | 'FAILED';
export interface FlowMetadata {
[key: string]: unknown;
}
export interface FlowState<T = unknown> {
type: string;
status: FlowStatus;
metadata: FlowMetadata;
createdAt: number;
result?: T;
error?: string;
completedAt?: number;
failedAt?: number;
}
export interface FlowManagerOptions {
ttl: number;
ci?: boolean;
logger?: Logger;
}

packages/api/src/index.ts Normal file

@@ -0,0 +1,14 @@
/* MCP */
export * from './mcp/manager';
/* Utilities */
export * from './mcp/utils';
export * from './utils';
/* Flow */
export * from './flow/manager';
/* Agents */
export * from './agents';
/* Endpoints */
export * from './endpoints';
/* types */
export type * from './mcp/types';
export type * from './flow/types';


@@ -0,0 +1,583 @@
import { EventEmitter } from 'events';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import {
StdioClientTransport,
getDefaultEnvironment,
} from '@modelcontextprotocol/sdk/client/stdio.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type { Logger } from 'winston';
import type * as t from './types';
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
return 'command' in options;
}
function isWebSocketOptions(options: t.MCPOptions): options is t.WebSocketOptions {
if ('url' in options) {
const protocol = new URL(options.url).protocol;
return protocol === 'ws:' || protocol === 'wss:';
}
return false;
}
function isSSEOptions(options: t.MCPOptions): options is t.SSEOptions {
if ('url' in options) {
const protocol = new URL(options.url).protocol;
return protocol !== 'ws:' && protocol !== 'wss:';
}
return false;
}
/**
* Checks if the provided options are for a Streamable HTTP transport.
*
* Streamable HTTP is an MCP transport that uses HTTP POST for sending messages
* and supports streaming responses. It provides better performance than
* SSE transport while maintaining compatibility with most network environments.
*
* @param options MCP connection options to check
* @returns True if options are for a streamable HTTP transport
*/
function isStreamableHTTPOptions(options: t.MCPOptions): options is t.StreamableHTTPOptions {
if ('url' in options && options.type === 'streamable-http') {
const protocol = new URL(options.url).protocol;
return protocol !== 'ws:' && protocol !== 'wss:';
}
return false;
}
const FIVE_MINUTES = 5 * 60 * 1000;
export class MCPConnection extends EventEmitter {
private static instance: MCPConnection | null = null;
public client: Client;
private transport: Transport | null = null; // created lazily on connect
private connectionState: t.ConnectionState = 'disconnected';
private connectPromise: Promise<void> | null = null;
private lastError: Error | null = null;
private lastConfigUpdate = 0;
private readonly CONFIG_TTL = 5 * 60 * 1000; // 5 minutes
private readonly MAX_RECONNECT_ATTEMPTS = 3;
public readonly serverName: string;
private shouldStopReconnecting = false;
private isReconnecting = false;
private isInitializing = false;
private reconnectAttempts = 0;
iconPath?: string;
timeout?: number;
private readonly userId?: string;
private lastPingTime: number;
constructor(
serverName: string,
private readonly options: t.MCPOptions,
private logger?: Logger,
userId?: string,
) {
super();
this.serverName = serverName;
this.logger = logger;
this.userId = userId;
this.iconPath = options.iconPath;
this.timeout = options.timeout;
this.lastPingTime = Date.now();
this.client = new Client(
{
name: '@librechat/api-client',
version: '1.2.2',
},
{
capabilities: {},
},
);
this.setupEventListeners();
}
/** Helper to generate consistent log prefixes */
private getLogPrefix(): string {
const userPart = this.userId ? `[User: ${this.userId}]` : '';
return `[MCP]${userPart}[${this.serverName}]`;
}
public static getInstance(
serverName: string,
options: t.MCPOptions,
logger?: Logger,
userId?: string,
): MCPConnection {
if (!MCPConnection.instance) {
MCPConnection.instance = new MCPConnection(serverName, options, logger, userId);
}
return MCPConnection.instance;
}
public static getExistingInstance(): MCPConnection | null {
return MCPConnection.instance;
}
public static async destroyInstance(): Promise<void> {
if (MCPConnection.instance) {
await MCPConnection.instance.disconnect();
MCPConnection.instance = null;
}
}
private emitError(error: unknown, errorContext: string): void {
const errorMessage = error instanceof Error ? error.message : String(error);
this.logger?.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
this.emit('error', new Error(`${errorContext}: ${errorMessage}`));
}
private constructTransport(options: t.MCPOptions): Transport {
try {
let type: t.MCPOptions['type'];
if (isStdioOptions(options)) {
type = 'stdio';
} else if (isWebSocketOptions(options)) {
type = 'websocket';
} else if (isStreamableHTTPOptions(options)) {
type = 'streamable-http';
} else if (isSSEOptions(options)) {
type = 'sse';
} else {
throw new Error(
'Cannot infer transport type: options.type is not provided and cannot be inferred from other properties.',
);
}
switch (type) {
case 'stdio':
if (!isStdioOptions(options)) {
throw new Error('Invalid options for stdio transport.');
}
return new StdioClientTransport({
command: options.command,
args: options.args,
// Workaround for an MCP SDK bug where env vars aren't passed through:
// https://github.com/modelcontextprotocol/typescript-sdk/issues/216
env: { ...getDefaultEnvironment(), ...(options.env ?? {}) },
});
case 'websocket':
if (!isWebSocketOptions(options)) {
throw new Error('Invalid options for websocket transport.');
}
return new WebSocketClientTransport(new URL(options.url));
case 'sse': {
if (!isSSEOptions(options)) {
throw new Error('Invalid options for sse transport.');
}
const url = new URL(options.url);
this.logger?.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
const abortController = new AbortController();
const transport = new SSEClientTransport(url, {
requestInit: {
headers: options.headers,
signal: abortController.signal,
},
eventSourceInit: {
fetch: (url, init) => {
const headers = new Headers(Object.assign({}, init?.headers, options.headers));
return fetch(url, {
...init,
headers,
});
},
},
});
transport.onclose = () => {
this.logger?.info(`${this.getLogPrefix()} SSE transport closed`);
this.emit('connectionChange', 'disconnected');
};
transport.onerror = (error) => {
this.logger?.error(`${this.getLogPrefix()} SSE transport error:`, error);
this.emitError(error, 'SSE transport error:');
};
transport.onmessage = (message) => {
this.logger?.info(
`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`,
);
};
this.setupTransportErrorHandlers(transport);
return transport;
}
case 'streamable-http': {
if (!isStreamableHTTPOptions(options)) {
throw new Error('Invalid options for streamable-http transport.');
}
const url = new URL(options.url);
this.logger?.info(
`${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
);
const abortController = new AbortController();
const transport = new StreamableHTTPClientTransport(url, {
requestInit: {
headers: options.headers,
signal: abortController.signal,
},
});
transport.onclose = () => {
this.logger?.info(`${this.getLogPrefix()} Streamable-http transport closed`);
this.emit('connectionChange', 'disconnected');
};
transport.onerror = (error: Error | unknown) => {
this.logger?.error(`${this.getLogPrefix()} Streamable-http transport error:`, error);
this.emitError(error, 'Streamable-http transport error:');
};
transport.onmessage = (message: JSONRPCMessage) => {
this.logger?.info(
`${this.getLogPrefix()} Message received: ${JSON.stringify(message)}`,
);
};
this.setupTransportErrorHandlers(transport);
return transport;
}
default: {
throw new Error(`Unsupported transport type: ${type}`);
}
}
} catch (error) {
this.emitError(error, 'Failed to construct transport:');
throw error;
}
}
private setupEventListeners(): void {
this.isInitializing = true;
this.on('connectionChange', (state: t.ConnectionState) => {
this.connectionState = state;
if (state === 'connected') {
this.isReconnecting = false;
this.isInitializing = false;
this.shouldStopReconnecting = false;
this.reconnectAttempts = 0;
/**
* // FOR DEBUGGING
* // this.client.setRequestHandler(PingRequestSchema, async (request, extra) => {
* // this.logger?.info(`[MCP][${this.serverName}] PingRequest: ${JSON.stringify(request)}`);
* // if (getEventListeners && extra.signal) {
* // const listenerCount = getEventListeners(extra.signal, 'abort').length;
* // this.logger?.debug(`Signal has ${listenerCount} abort listeners`);
* // }
* // return {};
* // });
*/
} else if (state === 'error' && !this.isReconnecting && !this.isInitializing) {
this.handleReconnection().catch((error) => {
this.logger?.error(`${this.getLogPrefix()} Reconnection handler failed:`, error);
});
}
});
this.subscribeToResources();
}
private async handleReconnection(): Promise<void> {
if (this.isReconnecting || this.shouldStopReconnecting || this.isInitializing) {
return;
}
this.isReconnecting = true;
const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);
try {
while (
this.reconnectAttempts < this.MAX_RECONNECT_ATTEMPTS &&
!(this.shouldStopReconnecting as boolean)
) {
this.reconnectAttempts++;
const delay = backoffDelay(this.reconnectAttempts);
this.logger?.info(
`${this.getLogPrefix()} Reconnecting ${this.reconnectAttempts}/${this.MAX_RECONNECT_ATTEMPTS} (delay: ${delay}ms)`,
);
await new Promise((resolve) => setTimeout(resolve, delay));
try {
await this.connect();
this.reconnectAttempts = 0;
return;
} catch (error) {
this.logger?.error(`${this.getLogPrefix()} Reconnection attempt failed:`, error);
if (
this.reconnectAttempts === this.MAX_RECONNECT_ATTEMPTS ||
(this.shouldStopReconnecting as boolean)
) {
this.logger?.error(`${this.getLogPrefix()} Stopping reconnection attempts`);
return;
}
}
}
} finally {
this.isReconnecting = false;
}
}
private subscribeToResources(): void {
this.client.setNotificationHandler(ResourceListChangedNotificationSchema, async () => {
this.invalidateCache();
this.emit('resourcesChanged');
});
}
private invalidateCache(): void {
// this.cachedConfig = null;
this.lastConfigUpdate = 0;
}
async connectClient(): Promise<void> {
if (this.connectionState === 'connected') {
return;
}
if (this.connectPromise) {
return this.connectPromise;
}
if (this.shouldStopReconnecting) {
return;
}
this.emit('connectionChange', 'connecting');
this.connectPromise = (async () => {
try {
if (this.transport) {
try {
await this.client.close();
this.transport = null;
} catch (error) {
this.logger?.warn(`${this.getLogPrefix()} Error closing connection:`, error);
}
}
this.transport = this.constructTransport(this.options);
this.setupTransportDebugHandlers();
const connectTimeout = this.options.initTimeout ?? 10000;
await Promise.race([
this.client.connect(this.transport),
new Promise((_resolve, reject) =>
setTimeout(() => reject(new Error('Connection timeout')), connectTimeout),
),
]);
this.connectionState = 'connected';
this.emit('connectionChange', 'connected');
this.reconnectAttempts = 0;
} catch (error) {
this.connectionState = 'error';
this.emit('connectionChange', 'error');
this.lastError = error instanceof Error ? error : new Error(String(error));
throw error;
} finally {
this.connectPromise = null;
}
})();
return this.connectPromise;
}
private setupTransportDebugHandlers(): void {
if (!this.transport) {
return;
}
this.transport.onmessage = (msg) => {
this.logger?.debug(`${this.getLogPrefix()} Transport received: ${JSON.stringify(msg)}`);
};
const originalSend = this.transport.send.bind(this.transport);
this.transport.send = async (msg) => {
if ('result' in msg && !('method' in msg) && Object.keys(msg.result ?? {}).length === 0) {
if (Date.now() - this.lastPingTime < FIVE_MINUTES) {
throw new Error('Empty result');
}
this.lastPingTime = Date.now();
}
this.logger?.debug(`${this.getLogPrefix()} Transport sending: ${JSON.stringify(msg)}`);
return originalSend(msg);
};
}
async connect(): Promise<void> {
try {
await this.disconnect();
await this.connectClient();
if (!(await this.isConnected())) {
throw new Error('Connection not established');
}
} catch (error) {
this.logger?.error(`${this.getLogPrefix()} Connection failed:`, error);
throw error;
}
}
private setupTransportErrorHandlers(transport: Transport): void {
transport.onerror = (error) => {
this.logger?.error(`${this.getLogPrefix()} Transport error:`, error);
this.emit('connectionChange', 'error');
};
}
public async disconnect(): Promise<void> {
try {
if (this.transport) {
await this.client.close();
this.transport = null;
}
if (this.connectionState === 'disconnected') {
return;
}
this.connectionState = 'disconnected';
this.emit('connectionChange', 'disconnected');
} catch (error) {
this.emit('error', error);
throw error;
} finally {
this.invalidateCache();
this.connectPromise = null;
}
}
async fetchResources(): Promise<t.MCPResource[]> {
try {
const { resources } = await this.client.listResources();
return resources;
} catch (error) {
this.emitError(error, 'Failed to fetch resources:');
return [];
}
}
async fetchTools() {
try {
const { tools } = await this.client.listTools();
return tools;
} catch (error) {
this.emitError(error, 'Failed to fetch tools:');
return [];
}
}
async fetchPrompts(): Promise<t.MCPPrompt[]> {
try {
const { prompts } = await this.client.listPrompts();
return prompts;
} catch (error) {
this.emitError(error, 'Failed to fetch prompts:');
return [];
}
}
// public async modifyConfig(config: ContinueConfig): Promise<ContinueConfig> {
// try {
// // Check cache
// if (this.cachedConfig && Date.now() - this.lastConfigUpdate < this.CONFIG_TTL) {
// return this.cachedConfig;
// }
// await this.connectClient();
// // Fetch and process resources
// const resources = await this.fetchResources();
// const submenuItems = resources.map(resource => ({
// title: resource.name,
// description: resource.description,
// id: resource.uri,
// }));
// if (!config.contextProviders) {
// config.contextProviders = [];
// }
// config.contextProviders.push(
// new MCPContextProvider({
// submenuItems,
// client: this.client,
// }),
// );
// // Fetch and process tools
// const tools = await this.fetchTools();
// const continueTools: Tool[] = tools.map(tool => ({
// displayTitle: tool.name,
// function: {
// description: tool.description,
// name: tool.name,
// parameters: tool.inputSchema,
// },
// readonly: false,
// type: 'function',
// wouldLikeTo: `use the ${tool.name} tool`,
// uri: `mcp://${tool.name}`,
// }));
// config.tools = [...(config.tools || []), ...continueTools];
// // Fetch and process prompts
// const prompts = await this.fetchPrompts();
// if (!config.slashCommands) {
// config.slashCommands = [];
// }
// const slashCommands: SlashCommand[] = prompts.map(prompt =>
// constructMcpSlashCommand(
// this.client,
// prompt.name,
// prompt.description,
// prompt.arguments?.map(a => a.name),
// ),
// );
// config.slashCommands.push(...slashCommands);
// // Update cache
// this.cachedConfig = config;
// this.lastConfigUpdate = Date.now();
// return config;
// } catch (error) {
// this.emit('error', error);
// // Return original config if modification fails
// return config;
// }
// }
// Public getters for state information
public getConnectionState(): t.ConnectionState {
return this.connectionState;
}
public async isConnected(): Promise<boolean> {
try {
await this.client.ping();
return this.connectionState === 'connected';
} catch (error) {
this.logger?.error(`${this.getLogPrefix()} Ping failed:`, error);
return false;
}
}
public getLastError(): Error | null {
return this.lastError;
}
}
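The option guards near the top of this file infer the transport from the options' shape and URL scheme: a `command` means stdio, a `ws:`/`wss:` URL means WebSocket, an explicit `type: 'streamable-http'` on an HTTP URL means Streamable HTTP, and any remaining HTTP URL falls back to SSE. A self-contained sketch of that inference, with the option types deliberately simplified (the real `t.MCPOptions` union carries more fields):

```typescript
// Simplified option shapes for illustration only.
type Options =
  | { command: string; args?: string[] }
  | { url: string; type?: 'streamable-http' };

function inferTransportType(
  options: Options,
): 'stdio' | 'websocket' | 'streamable-http' | 'sse' {
  if ('command' in options) {
    return 'stdio';
  }
  const protocol = new URL(options.url).protocol;
  if (protocol === 'ws:' || protocol === 'wss:') {
    return 'websocket';
  }
  if (options.type === 'streamable-http') {
    return 'streamable-http';
  }
  // Any remaining http(s) URL falls back to SSE, matching isSSEOptions.
  return 'sse';
}
```

Note the ordering matters: `constructTransport` checks stdio and WebSocket before Streamable HTTP, and SSE last, so an HTTP URL without an explicit `type` keeps its historical SSE default.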


@@ -0,0 +1,3 @@
export enum CONSTANTS {
mcp_delimiter = '_mcp_',
}
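`mcp_delimiter` joins a tool name and its server into a single namespaced key — the `pluginKey` the manager later builds as `` `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}` ``. A small sketch of building and splitting such keys; the `fromPluginKey` helper is illustrative (the codebase's own parsing may differ):

```typescript
const MCP_DELIMITER = '_mcp_'; // mirrors CONSTANTS.mcp_delimiter

// Namespace a tool under its server, e.g. 'search' + 'github' -> 'search_mcp_github'.
function toPluginKey(toolName: string, serverName: string): string {
  return `${toolName}${MCP_DELIMITER}${serverName}`;
}

// Split on the last delimiter so tool names containing '_mcp_' still parse.
function fromPluginKey(pluginKey: string): { toolName: string; serverName: string } {
  const idx = pluginKey.lastIndexOf(MCP_DELIMITER);
  return {
    toolName: pluginKey.slice(0, idx),
    serverName: pluginKey.slice(idx + MCP_DELIMITER.length),
  };
}
```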


@@ -0,0 +1,617 @@
import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol/sdk/types.js';
import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
import type { JsonSchemaType, MCPOptions } from 'librechat-data-provider';
import type { Logger } from 'winston';
import type * as t from './types';
import { formatToolContent } from './parsers';
import { MCPConnection } from './connection';
import { CONSTANTS } from './enum';
export interface CallToolOptions extends RequestOptions {
userId?: string;
}
export class MCPManager {
private static instance: MCPManager | null = null;
/** App-level connections initialized at startup */
private connections: Map<string, MCPConnection> = new Map();
/** User-specific connections initialized on demand */
private userConnections: Map<string, Map<string, MCPConnection>> = new Map();
/** Last activity timestamp for users (not per server) */
private userLastActivity: Map<string, number> = new Map();
private readonly USER_CONNECTION_IDLE_TIMEOUT = 15 * 60 * 1000; // 15 minutes (TODO: make configurable)
private mcpConfigs: t.MCPServers = {};
private processMCPEnv?: (obj: MCPOptions, userId?: string) => MCPOptions; // Store the processing function
/** Store MCP server instructions */
private serverInstructions: Map<string, string> = new Map();
private logger: Logger;
private static getDefaultLogger(): Logger {
return {
error: console.error,
warn: console.warn,
info: console.info,
debug: console.debug,
} as Logger;
}
private constructor(logger?: Logger) {
this.logger = logger || MCPManager.getDefaultLogger();
}
public static getInstance(logger?: Logger): MCPManager {
if (!MCPManager.instance) {
MCPManager.instance = new MCPManager(logger);
}
// Check for idle connections when getInstance is called
MCPManager.instance.checkIdleConnections();
return MCPManager.instance;
}
/** Stores configs and initializes app-level connections */
public async initializeMCP(
mcpServers: t.MCPServers,
processMCPEnv?: (obj: MCPOptions) => MCPOptions,
): Promise<void> {
this.logger.info('[MCP] Initializing app-level servers');
this.processMCPEnv = processMCPEnv; // Store the function
this.mcpConfigs = mcpServers;
const entries = Object.entries(mcpServers);
const initializedServers = new Set<number>();
const connectionResults = await Promise.allSettled(
entries.map(async ([serverName, _config], i) => {
/** Process env for app-level connections */
const config = this.processMCPEnv ? this.processMCPEnv(_config) : _config;
const connection = new MCPConnection(serverName, config, this.logger);
try {
const connectionTimeout = new Promise<void>((_, reject) =>
setTimeout(() => reject(new Error('Connection timeout')), 30000),
);
const connectionAttempt = this.initializeServer(connection, `[MCP][${serverName}]`);
await Promise.race([connectionAttempt, connectionTimeout]);
if (await connection.isConnected()) {
initializedServers.add(i);
this.connections.set(serverName, connection); // Store in app-level map
// Handle unified serverInstructions configuration
const configInstructions = config.serverInstructions;
if (configInstructions !== undefined) {
if (typeof configInstructions === 'string') {
// Custom instructions provided
this.serverInstructions.set(serverName, configInstructions);
this.logger.info(
`[MCP][${serverName}] Custom instructions stored for context inclusion: ${configInstructions}`,
);
} else if (configInstructions === true) {
// Use server-provided instructions
const serverInstructions = connection.client.getInstructions();
if (serverInstructions) {
this.serverInstructions.set(serverName, serverInstructions);
this.logger.info(
`[MCP][${serverName}] Server instructions stored for context inclusion: ${serverInstructions}`,
);
} else {
this.logger.info(
`[MCP][${serverName}] serverInstructions=true but no server instructions available`,
);
}
} else {
// configInstructions is false - explicitly disabled
this.logger.info(
`[MCP][${serverName}] Instructions explicitly disabled (serverInstructions=false)`,
);
}
} else {
this.logger.info(
`[MCP][${serverName}] Instructions not included (serverInstructions not configured)`,
);
}
const serverCapabilities = connection.client.getServerCapabilities();
this.logger.info(
`[MCP][${serverName}] Capabilities: ${JSON.stringify(serverCapabilities)}`,
);
if (serverCapabilities?.tools) {
const tools = await connection.client.listTools();
if (tools.tools.length) {
this.logger.info(
`[MCP][${serverName}] Available tools: ${tools.tools
.map((tool) => tool.name)
.join(', ')}`,
);
}
}
}
} catch (error) {
this.logger.error(`[MCP][${serverName}] Initialization failed`, error);
throw error;
}
}),
);
const failedConnections = connectionResults.filter(
(result): result is PromiseRejectedResult => result.status === 'rejected',
);
this.logger.info(
`[MCP] Initialized ${initializedServers.size}/${entries.length} app-level server(s)`,
);
if (failedConnections.length > 0) {
this.logger.warn(
`[MCP] ${failedConnections.length}/${entries.length} app-level server(s) failed to initialize`,
);
}
entries.forEach(([serverName], index) => {
if (initializedServers.has(index)) {
this.logger.info(`[MCP][${serverName}] ✓ Initialized`);
} else {
this.logger.info(`[MCP][${serverName}] ✗ Failed`);
}
});
if (initializedServers.size === entries.length) {
this.logger.info('[MCP] All app-level servers initialized successfully');
} else if (initializedServers.size === 0) {
this.logger.warn('[MCP] No app-level servers initialized');
}
}
/** Generic server initialization logic */
private async initializeServer(connection: MCPConnection, logPrefix: string): Promise<void> {
const maxAttempts = 3;
let attempts = 0;
while (attempts < maxAttempts) {
try {
await connection.connect();
if (await connection.isConnected()) {
return;
}
throw new Error('Connection attempt succeeded but status is not connected');
} catch (error) {
attempts++;
if (attempts === maxAttempts) {
this.logger.error(`${logPrefix} Failed to connect after ${maxAttempts} attempts`, error);
throw error; // Re-throw the last error
}
await new Promise((resolve) => setTimeout(resolve, 2000 * attempts));
}
}
}
/** Check for and disconnect idle connections */
private checkIdleConnections(currentUserId?: string): void {
const now = Date.now();
// Iterate through all users to check for idle ones
for (const [userId, lastActivity] of this.userLastActivity.entries()) {
if (currentUserId && currentUserId === userId) {
continue;
}
if (now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
this.logger.info(
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections...`,
);
// Disconnect all user connections asynchronously (fire and forget)
this.disconnectUserConnections(userId).catch((err) =>
this.logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err),
);
}
}
}
/** Updates the last activity timestamp for a user */
private updateUserLastActivity(userId: string): void {
const now = Date.now();
this.userLastActivity.set(userId, now);
this.logger.debug(
`[MCP][User: ${userId}] Updated last activity timestamp: ${new Date(now).toISOString()}`,
);
}
/** Gets or creates a connection for a specific user */
public async getUserConnection(userId: string, serverName: string): Promise<MCPConnection> {
const userServerMap = this.userConnections.get(userId);
let connection = userServerMap?.get(serverName);
const now = Date.now();
// Check if user is idle
const lastActivity = this.userLastActivity.get(userId);
if (lastActivity && now - lastActivity > this.USER_CONNECTION_IDLE_TIMEOUT) {
this.logger.info(
`[MCP][User: ${userId}] User idle for too long. Disconnecting all connections.`,
);
// Disconnect all user connections
try {
await this.disconnectUserConnections(userId);
} catch (err) {
this.logger.error(`[MCP][User: ${userId}] Error disconnecting idle connections:`, err);
}
connection = undefined; // Force creation of a new connection
} else if (connection) {
if (await connection.isConnected()) {
this.logger.debug(`[MCP][User: ${userId}][${serverName}] Reusing active connection`);
// Update timestamp on reuse
this.updateUserLastActivity(userId);
return connection;
} else {
// Connection exists but is not connected, attempt to remove potentially stale entry
this.logger.warn(
`[MCP][User: ${userId}][${serverName}] Found existing but disconnected connection object. Cleaning up.`,
);
this.removeUserConnection(userId, serverName); // Clean up maps
connection = undefined;
}
}
// If no valid connection exists, create a new one
if (!connection) {
this.logger.info(`[MCP][User: ${userId}][${serverName}] Establishing new connection`);
}
let config = this.mcpConfigs[serverName];
if (!config) {
throw new McpError(
ErrorCode.InvalidRequest,
`[MCP][User: ${userId}] Configuration for server "${serverName}" not found.`,
);
}
if (this.processMCPEnv) {
config = { ...(this.processMCPEnv(config, userId) ?? {}) };
}
connection = new MCPConnection(serverName, config, this.logger, userId);
try {
const connectionTimeout = new Promise<void>((_, reject) =>
setTimeout(() => reject(new Error('Connection timeout')), 30000),
);
const connectionAttempt = this.initializeServer(
connection,
`[MCP][User: ${userId}][${serverName}]`,
);
await Promise.race([connectionAttempt, connectionTimeout]);
if (!(await connection.isConnected())) {
throw new Error('Failed to establish connection after initialization attempt.');
}
if (!this.userConnections.has(userId)) {
this.userConnections.set(userId, new Map());
}
this.userConnections.get(userId)?.set(serverName, connection);
this.logger.info(`[MCP][User: ${userId}][${serverName}] Connection successfully established`);
// Update timestamp on creation
this.updateUserLastActivity(userId);
return connection;
} catch (error) {
this.logger.error(
`[MCP][User: ${userId}][${serverName}] Failed to establish connection`,
error,
);
// Ensure partial connection state is cleaned up if initialization fails
await connection.disconnect().catch((disconnectError) => {
this.logger.error(
`[MCP][User: ${userId}][${serverName}] Error during cleanup after failed connection`,
disconnectError,
);
});
// Ensure cleanup even if connection attempt fails
this.removeUserConnection(userId, serverName);
throw error; // Re-throw the error to the caller
}
}
/** Removes a specific user connection entry */
private removeUserConnection(userId: string, serverName: string): void {
// Remove connection object
const userMap = this.userConnections.get(userId);
if (userMap) {
userMap.delete(serverName);
if (userMap.size === 0) {
this.userConnections.delete(userId);
// Only remove user activity timestamp if all connections are gone
this.userLastActivity.delete(userId);
}
}
this.logger.debug(`[MCP][User: ${userId}][${serverName}] Removed connection entry.`);
}
/** Disconnects and removes a specific user connection */
public async disconnectUserConnection(userId: string, serverName: string): Promise<void> {
const userMap = this.userConnections.get(userId);
const connection = userMap?.get(serverName);
if (connection) {
this.logger.info(`[MCP][User: ${userId}][${serverName}] Disconnecting...`);
await connection.disconnect();
this.removeUserConnection(userId, serverName);
}
}
/** Disconnects and removes all connections for a specific user */
public async disconnectUserConnections(userId: string): Promise<void> {
const userMap = this.userConnections.get(userId);
const disconnectPromises: Promise<void>[] = [];
if (userMap) {
this.logger.info(`[MCP][User: ${userId}] Disconnecting all servers...`);
const userServers = Array.from(userMap.keys());
for (const serverName of userServers) {
disconnectPromises.push(
this.disconnectUserConnection(userId, serverName).catch((error) => {
this.logger.error(
`[MCP][User: ${userId}][${serverName}] Error during disconnection:`,
error,
);
}),
);
}
await Promise.allSettled(disconnectPromises);
// Ensure user activity timestamp is removed
this.userLastActivity.delete(userId);
this.logger.info(`[MCP][User: ${userId}] All connections processed for disconnection.`);
}
}
/** Returns the app-level connection (used for mapping tools, etc.) */
public getConnection(serverName: string): MCPConnection | undefined {
return this.connections.get(serverName);
}
/** Returns all app-level connections */
public getAllConnections(): Map<string, MCPConnection> {
return this.connections;
}
/**
* Maps available tools from all app-level connections into the provided object.
* The object is modified in place.
*/
public async mapAvailableTools(availableTools: t.LCAvailableTools): Promise<void> {
for (const [serverName, connection] of this.connections.entries()) {
try {
if ((await connection.isConnected()) !== true) {
this.logger.warn(
`[MCP][${serverName}] Connection not established. Skipping tool mapping.`,
);
continue;
}
const tools = await connection.fetchTools();
for (const tool of tools) {
const name = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
availableTools[name] = {
type: 'function',
['function']: {
name,
description: tool.description,
parameters: tool.inputSchema as JsonSchemaType,
},
};
}
} catch (error) {
this.logger.warn(`[MCP][${serverName}] Error fetching tools for mapping:`, error);
}
}
}
/**
* Loads tools from all app-level connections into the manifest.
*/
public async loadManifestTools(manifestTools: t.LCToolManifest): Promise<t.LCToolManifest> {
const mcpTools: t.LCManifestTool[] = [];
for (const [serverName, connection] of this.connections.entries()) {
try {
if ((await connection.isConnected()) !== true) {
this.logger.warn(
`[MCP][${serverName}] Connection not established. Skipping manifest loading.`,
);
continue;
}
const tools = await connection.fetchTools();
for (const tool of tools) {
const pluginKey = `${tool.name}${CONSTANTS.mcp_delimiter}${serverName}`;
const manifestTool: t.LCManifestTool = {
name: tool.name,
pluginKey,
description: tool.description ?? '',
icon: connection.iconPath,
};
const config = this.mcpConfigs[serverName];
if (config?.chatMenu === false) {
manifestTool.chatMenu = false;
}
mcpTools.push(manifestTool);
}
} catch (error) {
this.logger.error(`[MCP][${serverName}] Error fetching tools for manifest:`, error);
}
}
return [...mcpTools, ...manifestTools];
}
/**
* Calls a tool on an MCP server, using either a user-specific connection
* (if userId is provided) or an app-level connection. Updates the last activity timestamp
* for user-specific connections upon successful call initiation.
*/
async callTool({
serverName,
toolName,
provider,
toolArguments,
options,
}: {
serverName: string;
toolName: string;
provider: t.Provider;
toolArguments?: Record<string, unknown>;
options?: CallToolOptions;
}): Promise<t.FormattedToolResponse> {
let connection: MCPConnection | undefined;
const { userId, ...callOptions } = options ?? {};
const logPrefix = userId ? `[MCP][User: ${userId}][${serverName}]` : `[MCP][${serverName}]`;
try {
if (userId) {
this.updateUserLastActivity(userId);
// Get or create user-specific connection
connection = await this.getUserConnection(userId, serverName);
} else {
// Use app-level connection
connection = this.connections.get(serverName);
if (!connection) {
throw new McpError(
ErrorCode.InvalidRequest,
`${logPrefix} No app-level connection found. Cannot execute tool ${toolName}.`,
);
}
}
if (!(await connection.isConnected())) {
// This might happen if getUserConnection failed silently or app connection dropped
throw new McpError(
ErrorCode.InternalError, // Use InternalError for connection issues
`${logPrefix} Connection is not active. Cannot execute tool ${toolName}.`,
);
}
const result = await connection.client.request(
{
method: 'tools/call',
params: {
name: toolName,
arguments: toolArguments,
},
},
CallToolResultSchema,
{
timeout: connection.timeout,
...callOptions,
},
);
if (userId) {
this.updateUserLastActivity(userId);
}
this.checkIdleConnections();
return formatToolContent(result, provider);
} catch (error) {
// Log with context and re-throw or handle as needed
this.logger.error(`${logPrefix}[${toolName}] Tool call failed`, error);
// Rethrowing allows the caller (createMCPTool) to handle the final user message
throw error;
}
}
/** Disconnects a specific app-level server */
public async disconnectServer(serverName: string): Promise<void> {
const connection = this.connections.get(serverName);
if (connection) {
this.logger.info(`[MCP][${serverName}] Disconnecting...`);
await connection.disconnect();
this.connections.delete(serverName);
}
}
/** Disconnects all app-level and user-level connections */
public async disconnectAll(): Promise<void> {
this.logger.info('[MCP] Disconnecting all app-level and user-level connections...');
const userDisconnectPromises = Array.from(this.userConnections.keys()).map((userId) =>
this.disconnectUserConnections(userId),
);
await Promise.allSettled(userDisconnectPromises);
this.userLastActivity.clear();
// Disconnect all app-level connections
const appDisconnectPromises = Array.from(this.connections.values()).map((connection) =>
connection.disconnect().catch((error) => {
this.logger.error(`[MCP][${connection.serverName}] Error during disconnectAll:`, error);
}),
);
await Promise.allSettled(appDisconnectPromises);
this.connections.clear();
this.logger.info('[MCP] All connections processed for disconnection.');
}
/** Destroys the singleton instance and disconnects all connections */
public static async destroyInstance(): Promise<void> {
if (MCPManager.instance) {
await MCPManager.instance.disconnectAll();
MCPManager.instance = null;
const logger = MCPManager.getDefaultLogger();
logger.info('[MCP] Manager instance destroyed.');
}
}
/**
* Get instructions for MCP servers
* @param serverNames Optional array of server names. If not provided or empty, returns all servers.
* @returns Object mapping server names to their instructions
*/
public getInstructions(serverNames?: string[]): Record<string, string> {
const instructions: Record<string, string> = {};
if (!serverNames || serverNames.length === 0) {
// Return all instructions if no specific servers requested
for (const [serverName, serverInstructions] of this.serverInstructions.entries()) {
instructions[serverName] = serverInstructions;
}
} else {
// Return instructions for specific servers
for (const serverName of serverNames) {
const serverInstructions = this.serverInstructions.get(serverName);
if (serverInstructions) {
instructions[serverName] = serverInstructions;
}
}
}
return instructions;
}
/**
* Format MCP server instructions for injection into context
* @param serverNames Optional array of server names to include. If not provided, includes all servers.
* @returns Formatted instructions string ready for context injection
*/
public formatInstructionsForContext(serverNames?: string[]): string {
/** Instructions for specified servers or all stored instructions */
const instructionsToInclude = this.getInstructions(serverNames);
if (Object.keys(instructionsToInclude).length === 0) {
return '';
}
// Format instructions for context injection
const formattedInstructions = Object.entries(instructionsToInclude)
.map(([serverName, instructions]) => {
return `## ${serverName} MCP Server Instructions
${instructions}`;
})
.join('\n\n');
return `# MCP Server Instructions
The following MCP servers are available with their specific instructions:
${formattedInstructions}
Please follow these instructions when using tools from the respective MCP servers.`;
}
}
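For illustration, here is a standalone sketch of the context string assembled by `formatInstructionsForContext`; the `formatSketch` helper and the `weather` server name are hypothetical, not part of the library:

```typescript
// Hypothetical standalone sketch of the context-injection format
// produced by formatInstructionsForContext above.
function formatSketch(instructions: Record<string, string>): string {
  const entries = Object.entries(instructions);
  if (entries.length === 0) {
    return '';
  }
  // One "## <server> MCP Server Instructions" section per server
  const formatted = entries
    .map(([serverName, text]) => `## ${serverName} MCP Server Instructions\n${text}`)
    .join('\n\n');
  return `# MCP Server Instructions\nThe following MCP servers are available with their specific instructions:\n${formatted}\nPlease follow these instructions when using tools from the respective MCP servers.`;
}

const context = formatSketch({ weather: 'Always include units in responses.' });
console.log(context.split('\n')[0]); // → # MCP Server Instructions
```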


@@ -0,0 +1,183 @@
import type * as t from './types';
const RECOGNIZED_PROVIDERS = new Set([
'google',
'anthropic',
'openai',
'openrouter',
'xai',
'deepseek',
'ollama',
]);
const CONTENT_ARRAY_PROVIDERS = new Set(['google', 'anthropic', 'openai']);
const imageFormatters: Record<string, undefined | t.ImageFormatter> = {
// google: (item) => ({
// type: 'image',
// inlineData: {
// mimeType: item.mimeType,
// data: item.data,
// },
// }),
// anthropic: (item) => ({
// type: 'image',
// source: {
// type: 'base64',
// media_type: item.mimeType,
// data: item.data,
// },
// }),
default: (item) => ({
type: 'image_url',
image_url: {
url: item.data.startsWith('http') ? item.data : `data:${item.mimeType};base64,${item.data}`,
},
}),
};
function isImageContent(item: t.ToolContentPart): item is t.ImageContent {
return item.type === 'image';
}
function parseAsString(result: t.MCPToolCallResponse): string {
const content = result?.content ?? [];
if (!content.length) {
return '(No response)';
}
const text = content
.map((item) => {
if (item.type === 'text') {
return item.text;
}
if (item.type === 'resource') {
const resourceText = [];
if (item.resource.text) {
resourceText.push(item.resource.text);
}
if (item.resource.uri) {
resourceText.push(`Resource URI: ${item.resource.uri}`);
}
if (item.resource.name) {
resourceText.push(`Resource: ${item.resource.name}`);
}
if (item.resource.description) {
resourceText.push(`Description: ${item.resource.description}`);
}
if (item.resource.mimeType) {
resourceText.push(`Type: ${item.resource.mimeType}`);
}
return resourceText.join('\n');
}
return JSON.stringify(item, null, 2);
})
.filter(Boolean)
.join('\n\n');
return text;
}
/**
 * Converts MCPToolCallResponse content into recognized content block types.
 * Recognized types: "image", "image_url", "text", "json".
 * Returns a tuple: the first element holds the string or formatted content
 * (excluding image_url blocks); the second holds image_url content, if any.
 *
 * @param {t.MCPToolCallResponse} result - The MCPToolCallResponse object
 * @param {t.Provider} provider - The provider name (e.g. google, anthropic, openai)
 * @returns {t.FormattedContentResult} Tuple of content and image_urls
 */
export function formatToolContent(
result: t.MCPToolCallResponse,
provider: t.Provider,
): t.FormattedContentResult {
if (!RECOGNIZED_PROVIDERS.has(provider)) {
return [parseAsString(result), undefined];
}
const content = result?.content ?? [];
if (!content.length) {
return [[{ type: 'text', text: '(No response)' }], undefined];
}
const formattedContent: t.FormattedContent[] = [];
const imageUrls: t.FormattedContent[] = [];
let currentTextBlock = '';
type ContentHandler = undefined | ((item: t.ToolContentPart) => void);
const contentHandlers: {
text: (item: Extract<t.ToolContentPart, { type: 'text' }>) => void;
image: (item: t.ToolContentPart) => void;
resource: (item: Extract<t.ToolContentPart, { type: 'resource' }>) => void;
} = {
text: (item) => {
currentTextBlock += (currentTextBlock ? '\n\n' : '') + item.text;
},
image: (item) => {
if (!isImageContent(item)) {
return;
}
if (CONTENT_ARRAY_PROVIDERS.has(provider) && currentTextBlock) {
formattedContent.push({ type: 'text', text: currentTextBlock });
currentTextBlock = '';
}
const formatter = imageFormatters.default as t.ImageFormatter;
const formattedImage = formatter(item);
if (formattedImage.type === 'image_url') {
imageUrls.push(formattedImage);
} else {
formattedContent.push(formattedImage);
}
},
resource: (item) => {
const resourceText = [];
if (item.resource.text) {
resourceText.push(item.resource.text);
}
if (item.resource.uri) {
resourceText.push(`Resource URI: ${item.resource.uri}`);
}
if (item.resource.name) {
resourceText.push(`Resource: ${item.resource.name}`);
}
if (item.resource.description) {
resourceText.push(`Description: ${item.resource.description}`);
}
if (item.resource.mimeType) {
resourceText.push(`Type: ${item.resource.mimeType}`);
}
currentTextBlock += (currentTextBlock ? '\n\n' : '') + resourceText.join('\n');
},
};
for (const item of content) {
const handler = contentHandlers[item.type as keyof typeof contentHandlers] as ContentHandler;
if (handler) {
handler(item as never);
} else {
const stringified = JSON.stringify(item, null, 2);
currentTextBlock += (currentTextBlock ? '\n\n' : '') + stringified;
}
}
if (CONTENT_ARRAY_PROVIDERS.has(provider) && currentTextBlock) {
formattedContent.push({ type: 'text', text: currentTextBlock });
}
const artifacts = imageUrls.length ? { content: imageUrls } : undefined;
if (CONTENT_ARRAY_PROVIDERS.has(provider)) {
return [formattedContent, artifacts];
}
return [currentTextBlock, artifacts];
}
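The `FormattedContentResult` tuple returned above separates model-visible content from image artifacts. A hypothetical consumer might destructure it like this (the local `Block`/`ContentResult` types and the sample data are illustrative, not the library's API):

```typescript
// Illustrative shapes mirroring FormattedContent / FormattedContentResult.
type Block =
  | { type: 'text'; text: string }
  | { type: 'image_url'; image_url: { url: string } };
type ContentResult = [string | Block[], undefined | { content: Block[] }];

// Sample result, as formatToolContent might produce for a content-array provider:
// text blocks in the first slot, image_url artifacts in the second.
const result: ContentResult = [
  [{ type: 'text', text: 'Tool output' }],
  { content: [{ type: 'image_url', image_url: { url: 'data:image/png;base64,AAAA' } }] },
];

const [content, artifacts] = result;
const imageCount = artifacts?.content.length ?? 0;
console.log(imageCount); // → 1
```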


@@ -0,0 +1,98 @@
import { z } from 'zod';
import {
SSEOptionsSchema,
MCPOptionsSchema,
MCPServersSchema,
StdioOptionsSchema,
WebSocketOptionsSchema,
StreamableHTTPOptionsSchema,
} from 'librechat-data-provider';
import type { JsonSchemaType, TPlugin } from 'librechat-data-provider';
import type * as t from '@modelcontextprotocol/sdk/types.js';
export type StdioOptions = z.infer<typeof StdioOptionsSchema>;
export type WebSocketOptions = z.infer<typeof WebSocketOptionsSchema>;
export type SSEOptions = z.infer<typeof SSEOptionsSchema>;
export type StreamableHTTPOptions = z.infer<typeof StreamableHTTPOptionsSchema>;
export type MCPOptions = z.infer<typeof MCPOptionsSchema>;
export type MCPServers = z.infer<typeof MCPServersSchema>;
export interface MCPResource {
uri: string;
name: string;
description?: string;
mimeType?: string;
}
export interface LCTool {
name: string;
description?: string;
parameters: JsonSchemaType;
}
export interface LCFunctionTool {
type: 'function';
['function']: LCTool;
}
export type LCAvailableTools = Record<string, LCFunctionTool>;
export type LCManifestTool = TPlugin;
export type LCToolManifest = TPlugin[];
export interface MCPPrompt {
name: string;
description?: string;
arguments?: Array<{ name: string }>;
}
export type ConnectionState = 'disconnected' | 'connecting' | 'connected' | 'error';
export type MCPTool = z.infer<typeof t.ToolSchema>;
export type MCPToolListResponse = z.infer<typeof t.ListToolsResultSchema>;
export type ToolContentPart = t.TextContent | t.ImageContent | t.EmbeddedResource | t.AudioContent;
export type ImageContent = Extract<ToolContentPart, { type: 'image' }>;
export type MCPToolCallResponse =
| undefined
| {
_meta?: Record<string, unknown>;
content?: Array<ToolContentPart>;
isError?: boolean;
};
export type Provider =
  | 'google'
  | 'anthropic'
  | 'openai'
  | 'openrouter'
  | 'xai'
  | 'deepseek'
  | 'ollama';
export type FormattedContent =
| {
type: 'text';
text: string;
}
| {
type: 'image';
inlineData: {
mimeType: string;
data: string;
};
}
| {
type: 'image';
source: {
type: 'base64';
media_type: string;
data: string;
};
}
| {
type: 'image_url';
image_url: {
url: string;
};
};
export type FormattedContentResult = [
string | FormattedContent[],
undefined | { content: FormattedContent[] },
];
export type ImageFormatter = (item: ImageContent) => FormattedContent;
export type FormattedToolResponse = FormattedContentResult;


@@ -0,0 +1,28 @@
import { normalizeServerName } from './utils';
describe('normalizeServerName', () => {
it('should not modify server names that already match the pattern', () => {
const result = normalizeServerName('valid-server_name.123');
expect(result).toBe('valid-server_name.123');
});
it('should normalize server names with non-ASCII characters', () => {
const result = normalizeServerName('我的服务');
// Should generate a fallback name with a hash
expect(result).toMatch(/^server_\d+$/);
expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/);
});
it('should normalize server names with special characters', () => {
const result = normalizeServerName('server@name!');
// 'server@name!' becomes 'server_name_' before trailing underscores are trimmed
expect(result).toBe('server_name');
expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/);
});
it('should trim leading and trailing underscores', () => {
const result = normalizeServerName('!server-name!');
expect(result).toBe('server-name');
expect(result).toMatch(/^[a-zA-Z0-9_.-]+$/);
});
});


@@ -0,0 +1,30 @@
/**
* Normalizes a server name to match the pattern ^[a-zA-Z0-9_.-]+$
* This is required for Azure OpenAI models with Tool Calling
*/
export function normalizeServerName(serverName: string): string {
// Check if the server name already matches the pattern
if (/^[a-zA-Z0-9_.-]+$/.test(serverName)) {
return serverName;
}
/**
 * Replace non-matching characters with underscores.
 * This preserves the general structure while ensuring compatibility,
 * then trims leading/trailing underscores.
 */
const normalized = serverName.replace(/[^a-zA-Z0-9_.-]/g, '_').replace(/^_+|_+$/g, '');
// If the result is empty (e.g., all characters were non-ASCII and got trimmed),
// generate a fallback name to ensure we always have a valid function name
if (!normalized) {
/** Hash of the original name to ensure uniqueness */
let hash = 0;
for (let i = 0; i < serverName.length; i++) {
hash = (hash << 5) - hash + serverName.charCodeAt(i);
hash |= 0; // Convert to 32bit integer
}
return `server_${Math.abs(hash)}`;
}
return normalized;
}
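To illustrate the behavior, here is a self-contained copy of the same logic (renamed `normalizeSketch` so the snippet stands alone):

```typescript
// Standalone sketch of normalizeServerName: keep names that already match
// ^[a-zA-Z0-9_.-]+$, otherwise replace disallowed characters with underscores
// and trim; fall back to a hash-derived name when nothing survives.
function normalizeSketch(serverName: string): string {
  if (/^[a-zA-Z0-9_.-]+$/.test(serverName)) {
    return serverName;
  }
  const normalized = serverName.replace(/[^a-zA-Z0-9_.-]/g, '_').replace(/^_+|_+$/g, '');
  if (!normalized) {
    // Simple 32-bit hash of the original name for a stable fallback
    let hash = 0;
    for (let i = 0; i < serverName.length; i++) {
      hash = (hash << 5) - hash + serverName.charCodeAt(i);
      hash |= 0;
    }
    return `server_${Math.abs(hash)}`;
  }
  return normalized;
}

console.log(normalizeSketch('server@name!')); // → server_name
console.log(normalizeSketch('valid-server_name.123')); // → valid-server_name.123
```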


@@ -0,0 +1,19 @@
/**
* Azure OpenAI configuration interface
*/
export interface AzureOptions {
azureOpenAIApiKey?: string;
azureOpenAIApiInstanceName?: string;
azureOpenAIApiDeploymentName?: string;
azureOpenAIApiVersion?: string;
azureOpenAIBasePath?: string;
}
/**
* Client with azure property for setting deployment name
*/
export interface GenericClient {
azure: {
azureOpenAIApiDeploymentName?: string;
};
}


@@ -0,0 +1,4 @@
export type ServerSentEvent = {
data: string | Record<string, unknown>;
event?: string;
};


@@ -0,0 +1,4 @@
export * from './azure';
export * from './events';
export * from './openai';
export * from './run';


@@ -0,0 +1,97 @@
import { z } from 'zod';
import { openAISchema, EModelEndpoint } from 'librechat-data-provider';
import type { TEndpointOption, TAzureConfig, TEndpoint } from 'librechat-data-provider';
import type { OpenAIClientOptions } from '@librechat/agents';
import type { AzureOptions } from './azure';
export type OpenAIParameters = z.infer<typeof openAISchema>;
/**
* Configuration options for the getLLMConfig function
*/
export interface LLMConfigOptions {
modelOptions?: Partial<OpenAIParameters>;
reverseProxyUrl?: string;
defaultQuery?: Record<string, string | undefined>;
headers?: Record<string, string>;
proxy?: string;
azure?: AzureOptions;
streaming?: boolean;
addParams?: Record<string, unknown>;
dropParams?: string[];
}
export type OpenAIConfiguration = OpenAIClientOptions['configuration'];
export type ClientOptions = OpenAIClientOptions & {
include_reasoning?: boolean;
};
/**
* Return type for getLLMConfig function
*/
export interface LLMConfigResult {
llmConfig: ClientOptions;
configOptions: OpenAIConfiguration;
}
/**
* Interface for user values retrieved from the database
*/
export interface UserKeyValues {
apiKey?: string;
baseURL?: string;
}
/**
* Request interface with only the properties we need (avoids Express typing conflicts)
*/
export interface RequestData {
user: {
id: string;
};
body: {
model?: string;
endpoint?: string;
key?: string;
};
app: {
locals: {
[EModelEndpoint.azureOpenAI]?: TAzureConfig;
[EModelEndpoint.openAI]?: TEndpoint;
all?: TEndpoint;
};
};
}
/**
* Function type for getting user key values
*/
export type GetUserKeyValuesFunction = (params: {
userId: string;
name: string;
}) => Promise<UserKeyValues>;
/**
* Function type for checking user key expiry
*/
export type CheckUserKeyExpiryFunction = (expiresAt: string, endpoint: string) => void;
/**
* Parameters for the initializeOpenAI function
*/
export interface InitializeOpenAIOptionsParams {
req: RequestData;
overrideModel?: string;
overrideEndpoint?: string;
endpointOption: Partial<TEndpointOption>;
getUserKeyValues: GetUserKeyValuesFunction;
checkUserKeyExpiry: CheckUserKeyExpiryFunction;
}
/**
* Extended LLM config result with stream rate handling
*/
export interface OpenAIOptionsResult extends LLMConfigResult {
streamRate?: number;
}


@@ -0,0 +1,10 @@
import type { AgentModelParameters, EModelEndpoint } from 'librechat-data-provider';
import type { OpenAIConfiguration } from './openai';
export type RunLLMConfig = {
provider: EModelEndpoint;
streaming: boolean;
streamUsage: boolean;
usage?: boolean;
configuration?: OpenAIConfiguration;
} & AgentModelParameters;


@@ -0,0 +1,269 @@
import {
genAzureChatCompletion,
getAzureCredentials,
constructAzureURL,
sanitizeModelName,
genAzureEndpoint,
} from './azure';
import type { GenericClient } from '~/types';
describe('sanitizeModelName', () => {
test('removes periods from the model name', () => {
const sanitized = sanitizeModelName('model.name');
expect(sanitized).toBe('modelname');
});
test('leaves model name unchanged if no periods are present', () => {
const sanitized = sanitizeModelName('modelname');
expect(sanitized).toBe('modelname');
});
});
describe('genAzureEndpoint', () => {
test('generates correct endpoint URL', () => {
const url = genAzureEndpoint({
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiDeploymentName: 'deploymentName',
});
expect(url).toBe('https://instanceName.openai.azure.com/openai/deployments/deploymentName');
});
});
describe('genAzureChatCompletion', () => {
// Test with both deployment name and model name provided
test('prefers model name over deployment name when both are provided and feature enabled', () => {
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
const url = genAzureChatCompletion(
{
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiDeploymentName: 'deploymentName',
azureOpenAIApiVersion: 'v1',
},
'modelName',
);
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
);
});
// Test with only deployment name provided
test('uses deployment name when model name is not provided', () => {
const url = genAzureChatCompletion({
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiDeploymentName: 'deploymentName',
azureOpenAIApiVersion: 'v1',
});
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
);
});
// Test with only model name provided
test('uses model name when deployment name is not provided and feature enabled', () => {
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
const url = genAzureChatCompletion(
{
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiVersion: 'v1',
},
'modelName',
);
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
);
});
// Test with neither deployment name nor model name provided
test('throws error if neither deployment name nor model name is provided', () => {
expect(() => {
genAzureChatCompletion({
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiVersion: 'v1',
});
}).toThrow(
'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
);
});
// Test with feature disabled but model name provided
test('ignores model name and uses deployment name when feature is disabled', () => {
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
const url = genAzureChatCompletion(
{
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiDeploymentName: 'deploymentName',
azureOpenAIApiVersion: 'v1',
},
'modelName',
);
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
);
});
// Test with sanitized model name
test('sanitizes model name when used in URL', () => {
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
const url = genAzureChatCompletion(
{
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiVersion: 'v1',
},
'model.name',
);
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
);
});
// Test with client parameter and model name
test('updates client with sanitized model name when provided and feature enabled', () => {
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
const clientMock = { azure: {} } as GenericClient;
const url = genAzureChatCompletion(
{
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiVersion: 'v1',
},
'model.name',
clientMock,
);
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
);
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBe('modelname');
});
// Test with client parameter but without model name
test('does not update client when model name is not provided', () => {
const clientMock = { azure: {} } as GenericClient;
const url = genAzureChatCompletion(
{
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiDeploymentName: 'deploymentName',
azureOpenAIApiVersion: 'v1',
},
undefined,
clientMock,
);
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
);
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
});
// Test with client parameter and deployment name when feature is disabled
test('does not update client when feature is disabled', () => {
process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
const clientMock = { azure: {} } as GenericClient;
const url = genAzureChatCompletion(
{
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiDeploymentName: 'deploymentName',
azureOpenAIApiVersion: 'v1',
},
'modelName',
clientMock,
);
expect(url).toBe(
'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
);
expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
});
// Reset environment variable after tests
afterEach(() => {
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
});
});
describe('getAzureCredentials', () => {
beforeEach(() => {
process.env.AZURE_API_KEY = 'testApiKey';
process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'instanceName';
process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'deploymentName';
process.env.AZURE_OPENAI_API_VERSION = 'v1';
});
test('retrieves Azure OpenAI API credentials from environment variables', () => {
const credentials = getAzureCredentials();
expect(credentials).toEqual({
azureOpenAIApiKey: 'testApiKey',
azureOpenAIApiInstanceName: 'instanceName',
azureOpenAIApiDeploymentName: 'deploymentName',
azureOpenAIApiVersion: 'v1',
});
});
});
describe('constructAzureURL', () => {
test('replaces both placeholders when both properties are provided', () => {
const url = constructAzureURL({
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
azureOptions: {
azureOpenAIApiInstanceName: 'instance1',
azureOpenAIApiDeploymentName: 'deployment1',
},
});
expect(url).toBe('https://example.com/instance1/deployment1');
});
test('replaces only INSTANCE_NAME when only azureOpenAIApiInstanceName is provided', () => {
const url = constructAzureURL({
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
azureOptions: {
azureOpenAIApiInstanceName: 'instance2',
},
});
expect(url).toBe('https://example.com/instance2/');
});
test('replaces only DEPLOYMENT_NAME when only azureOpenAIApiDeploymentName is provided', () => {
const url = constructAzureURL({
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
azureOptions: {
azureOpenAIApiDeploymentName: 'deployment2',
},
});
expect(url).toBe('https://example.com//deployment2');
});
test('does not replace any placeholders when azure object is empty', () => {
const url = constructAzureURL({
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
azureOptions: {},
});
expect(url).toBe('https://example.com//');
});
test('returns baseURL as is when `azureOptions` object is not provided', () => {
const url = constructAzureURL({
baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
});
expect(url).toBe('https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}');
});
test('returns baseURL as is when no placeholders are set', () => {
const url = constructAzureURL({
baseURL: 'https://example.com/my_custom_instance/my_deployment',
azureOptions: {
azureOpenAIApiInstanceName: 'instance1',
azureOpenAIApiDeploymentName: 'deployment1',
},
});
expect(url).toBe('https://example.com/my_custom_instance/my_deployment');
});
test('returns regular Azure OpenAI baseURL with placeholders set', () => {
const baseURL =
'https://${INSTANCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_NAME}';
const url = constructAzureURL({
baseURL,
azureOptions: {
azureOpenAIApiInstanceName: 'instance1',
azureOpenAIApiDeploymentName: 'deployment1',
},
});
expect(url).toBe('https://instance1.openai.azure.com/openai/deployments/deployment1');
});
});


@@ -0,0 +1,120 @@
import { isEnabled } from './common';
import type { AzureOptions, GenericClient } from '~/types';
/**
* Sanitizes the model name to be used in the URL by removing or replacing disallowed characters.
* @param modelName - The model name to be sanitized.
* @returns The sanitized model name.
*/
export const sanitizeModelName = (modelName: string): string => {
// Replace periods with empty strings and other disallowed characters as needed.
return modelName.replace(/\./g, '');
};
/**
* Generates the Azure OpenAI API endpoint URL.
* @param params - The parameters object.
* @param params.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
* @param params.azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name.
* @returns The complete endpoint URL for the Azure OpenAI API.
*/
export const genAzureEndpoint = ({
azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName,
}: {
azureOpenAIApiInstanceName: string;
azureOpenAIApiDeploymentName: string;
}): string => {
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`;
};
/**
* Generates the Azure OpenAI API chat completion endpoint URL with the API version.
* If both deploymentName and modelName are provided, modelName takes precedence.
* @param azureConfig - The Azure configuration object.
* @param azureConfig.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
* @param azureConfig.azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name (optional).
* @param azureConfig.azureOpenAIApiVersion - The Azure OpenAI API version.
* @param modelName - The model name to be included in the deployment name (optional).
* @param client - The API Client class for optionally setting properties (optional).
* @returns The complete chat completion endpoint URL for the Azure OpenAI API.
* @throws Error if neither azureOpenAIApiDeploymentName nor modelName is provided.
*/
export const genAzureChatCompletion = (
{
azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName,
azureOpenAIApiVersion,
}: {
azureOpenAIApiInstanceName: string;
azureOpenAIApiDeploymentName?: string;
azureOpenAIApiVersion: string;
},
modelName?: string,
client?: GenericClient,
): string => {
// Determine the deployment segment of the URL based on provided modelName or azureOpenAIApiDeploymentName
let deploymentSegment: string;
if (isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME) && modelName) {
const sanitizedModelName = sanitizeModelName(modelName);
deploymentSegment = sanitizedModelName;
if (client && typeof client === 'object') {
client.azure.azureOpenAIApiDeploymentName = sanitizedModelName;
}
} else if (azureOpenAIApiDeploymentName) {
deploymentSegment = azureOpenAIApiDeploymentName;
} else if (!process.env.AZURE_OPENAI_BASEURL) {
throw new Error(
'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
);
} else {
deploymentSegment = '';
}
return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${deploymentSegment}/chat/completions?api-version=${azureOpenAIApiVersion}`;
};
/**
* Retrieves the Azure OpenAI API credentials from environment variables.
* @returns An object containing the Azure OpenAI API credentials.
*/
export const getAzureCredentials = (): AzureOptions => {
return {
azureOpenAIApiKey: process.env.AZURE_API_KEY ?? process.env.AZURE_OPENAI_API_KEY,
azureOpenAIApiInstanceName: process.env.AZURE_OPENAI_API_INSTANCE_NAME,
azureOpenAIApiDeploymentName: process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME,
azureOpenAIApiVersion: process.env.AZURE_OPENAI_API_VERSION,
};
};
/**
* Constructs a URL by replacing placeholders in the baseURL with values from the azure object.
* It specifically looks for '${INSTANCE_NAME}' and '${DEPLOYMENT_NAME}' within the baseURL and replaces
* them with 'azureOpenAIApiInstanceName' and 'azureOpenAIApiDeploymentName' from the azure object.
* If the respective azure property is not provided, the placeholder is replaced with an empty string.
*
* @param params - The parameters object.
* @param params.baseURL - The baseURL to inspect for replacement placeholders.
* @param params.azureOptions - The azure options object containing the instance and deployment names.
* @returns The complete baseURL with credentials injected for the Azure OpenAI API.
*/
export function constructAzureURL({
baseURL,
azureOptions,
}: {
baseURL: string;
azureOptions?: AzureOptions;
}): string {
let finalURL = baseURL;
// Replace INSTANCE_NAME and DEPLOYMENT_NAME placeholders with actual values if available
if (azureOptions) {
finalURL = finalURL.replace('${INSTANCE_NAME}', azureOptions.azureOpenAIApiInstanceName ?? '');
finalURL = finalURL.replace(
'${DEPLOYMENT_NAME}',
azureOptions.azureOpenAIApiDeploymentName ?? '',
);
}
return finalURL;
}
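A minimal standalone sketch of the same placeholder substitution (the `constructUrlSketch` helper name is hypothetical, and it always substitutes rather than gating on an options object):

```typescript
// Sketch of constructAzureURL's substitution: replace the literal
// ${INSTANCE_NAME} and ${DEPLOYMENT_NAME} tokens, falling back to ''.
function constructUrlSketch(baseURL: string, instance?: string, deployment?: string): string {
  return baseURL
    .replace('${INSTANCE_NAME}', instance ?? '')
    .replace('${DEPLOYMENT_NAME}', deployment ?? '');
}

const url = constructUrlSketch(
  'https://${INSTANCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_NAME}',
  'instance1',
  'deployment1',
);
console.log(url); // → https://instance1.openai.azure.com/openai/deployments/deployment1
```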


@@ -0,0 +1,55 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { isEnabled } from './common';
describe('isEnabled', () => {
test('should return true when input is "true"', () => {
expect(isEnabled('true')).toBe(true);
});
test('should return true when input is "TRUE"', () => {
expect(isEnabled('TRUE')).toBe(true);
});
test('should return true when input is true', () => {
expect(isEnabled(true)).toBe(true);
});
test('should return false when input is "false"', () => {
expect(isEnabled('false')).toBe(false);
});
test('should return false when input is false', () => {
expect(isEnabled(false)).toBe(false);
});
test('should return false when input is null', () => {
expect(isEnabled(null)).toBe(false);
});
test('should return false when input is undefined', () => {
expect(isEnabled()).toBe(false);
});
test('should return false when input is an empty string', () => {
expect(isEnabled('')).toBe(false);
});
test('should return false when input is a whitespace string', () => {
expect(isEnabled(' ')).toBe(false);
});
test('should return false when input is a number', () => {
// @ts-expect-error
expect(isEnabled(123)).toBe(false);
});
test('should return false when input is an object', () => {
// @ts-expect-error
expect(isEnabled({})).toBe(false);
});
test('should return false when input is an array', () => {
// @ts-expect-error
expect(isEnabled([])).toBe(false);
});
});


@@ -0,0 +1,48 @@
/**
* Checks if the given value is truthy by being either the boolean `true` or a string
* that case-insensitively matches 'true'.
*
* @param value - The value to check.
* @returns Returns `true` if the value is the boolean `true` or a case-insensitive
* match for the string 'true', otherwise returns `false`.
* @example
*
* isEnabled("True"); // returns true
* isEnabled("TRUE"); // returns true
* isEnabled(true); // returns true
* isEnabled("false"); // returns false
* isEnabled(false); // returns false
* isEnabled(null); // returns false
* isEnabled(); // returns false
*/
export function isEnabled(value?: string | boolean | null | undefined): boolean {
if (typeof value === 'boolean') {
return value;
}
if (typeof value === 'string') {
return value.toLowerCase().trim() === 'true';
}
return false;
}
/**
* Checks if the provided value is 'user_provided'.
*
* @param value - The value to check.
* @returns - Returns true if the value is 'user_provided', otherwise false.
*/
export const isUserProvided = (value?: string): boolean => value === 'user_provided';
/**
 * Returns the first value that is not undefined, null, or an empty string.
 * Falls back to the last provided value if none qualify.
 *
 * @param values - The candidate values, in priority order.
 * @returns The first non-empty value, or the last value if all are empty.
 */
export function optionalChainWithEmptyCheck(
...values: (string | number | undefined)[]
): string | number | undefined {
for (const value of values) {
if (value !== undefined && value !== null && value !== '') {
return value;
}
}
return values[values.length - 1];
}
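The fallback chain above can be illustrated with a self-contained copy (renamed `firstNonEmpty` here); note that `0` counts as non-empty because only `undefined`, `null`, and `''` are skipped:

```typescript
// Standalone sketch of optionalChainWithEmptyCheck: return the first value
// that is not undefined, null, or '', otherwise the last value provided.
function firstNonEmpty(
  ...values: (string | number | undefined)[]
): string | number | undefined {
  for (const value of values) {
    if (value !== undefined && value !== null && value !== '') {
      return value;
    }
  }
  return values[values.length - 1];
}

console.log(firstNonEmpty('', undefined, 'fallback')); // → fallback
console.log(firstNonEmpty(undefined, 0, 'x')); // → 0
```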


@@ -0,0 +1,16 @@
import type { Response as ServerResponse } from 'express';
import type { ServerSentEvent } from '~/types';
/**
* Sends message data in Server Sent Events format.
* @param res - The server response.
* @param event - The message event.
* @param event.event - The type of event.
* @param event.data - The message to be sent.
*/
export function sendEvent(res: ServerResponse, event: ServerSentEvent): void {
if (typeof event.data === 'string' && event.data.length === 0) {
return;
}
res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
}
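For reference, the wire format `sendEvent` produces is a single Server-Sent Events frame whose `data` field carries the JSON-serialized event. A minimal sketch of the frame construction (the sample payload shape here is illustrative, not a fixed schema from the PR):

```typescript
// Build the same frame string that sendEvent writes to the response stream.
const event = { event: 'on_message_delta', data: { content: 'Hello' } };
const frame = `event: message\ndata: ${JSON.stringify(event)}\n\n`;

console.log(frame);
// event: message
// data: {"event":"on_message_delta","data":{"content":"Hello"}}
```

The trailing blank line (`\n\n`) is what terminates an SSE frame, so omitting it would cause clients to buffer the message indefinitely.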

View file

@@ -0,0 +1,115 @@
import { sanitizeFilename } from './files';
jest.mock('node:crypto', () => {
const actualModule = jest.requireActual('node:crypto');
return {
...actualModule,
randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
};
});
describe('sanitizeFilename', () => {
test('removes directory components (1/2)', () => {
expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt');
});
test('removes directory components (2/2)', () => {
expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt');
});
test('replaces non-alphanumeric characters', () => {
expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt');
});
test('preserves dots and hyphens', () => {
expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt');
});
test('prepends underscore to filenames starting with a dot', () => {
expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile');
});
test('truncates long filenames', () => {
const longName = 'a'.repeat(300) + '.txt';
const result = sanitizeFilename(longName);
expect(result.length).toBe(255);
expect(result).toMatch(/^a+-abc123\.txt$/);
});
test('handles filenames with no extension', () => {
const longName = 'a'.repeat(300);
const result = sanitizeFilename(longName);
expect(result.length).toBe(255);
expect(result).toMatch(/^a+-abc123$/);
});
test('handles empty input', () => {
expect(sanitizeFilename('')).toBe('_');
});
test('handles input with only special characters', () => {
expect(sanitizeFilename('@#$%^&*')).toBe('_______');
});
});
describe('sanitizeFilename with real crypto', () => {
// Temporarily unmock crypto for these tests
beforeAll(() => {
jest.resetModules();
jest.unmock('node:crypto');
});
afterAll(() => {
jest.resetModules();
jest.mock('node:crypto', () => {
const actualModule = jest.requireActual('node:crypto');
return {
...actualModule,
randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
};
});
});
test('truncates long filenames with real crypto', async () => {
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
const longName = 'b'.repeat(300) + '.pdf';
const result = realSanitizeFilename(longName);
expect(result.length).toBe(255);
expect(result).toMatch(/^b+-[a-f0-9]{6}\.pdf$/);
expect(result.endsWith('.pdf')).toBe(true);
});
test('handles filenames with no extension with real crypto', async () => {
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
const longName = 'c'.repeat(300);
const result = realSanitizeFilename(longName);
expect(result.length).toBe(255);
expect(result).toMatch(/^c+-[a-f0-9]{6}$/);
expect(result).not.toContain('.');
});
test('generates unique suffixes for identical long filenames', async () => {
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
const longName = 'd'.repeat(300) + '.doc';
const result1 = realSanitizeFilename(longName);
const result2 = realSanitizeFilename(longName);
expect(result1.length).toBe(255);
expect(result2.length).toBe(255);
expect(result1).not.toBe(result2); // Should be different due to random suffix
expect(result1.endsWith('.doc')).toBe(true);
expect(result2.endsWith('.doc')).toBe(true);
});
test('real crypto produces valid hex strings', async () => {
const { sanitizeFilename: realSanitizeFilename } = await import('./files');
const longName = 'test'.repeat(100) + '.txt';
const result = realSanitizeFilename(longName);
const hexMatch = result.match(/-([a-f0-9]{6})\.txt$/);
expect(hexMatch).toBeTruthy();
expect(hexMatch![1]).toMatch(/^[a-f0-9]{6}$/);
});
});

View file

@@ -0,0 +1,33 @@
import path from 'path';
import crypto from 'node:crypto';
/**
 * Sanitizes a filename by removing any directory components, replacing
 * non-alphanumeric characters (except `.` and `-`) with underscores, and
 * truncating names longer than 255 characters with a random hex suffix.
 * @param inputName - The raw filename to sanitize.
 */
export function sanitizeFilename(inputName: string): string {
// Remove any directory components
let name = path.basename(inputName);
// Replace any non-alphanumeric characters except for '.' and '-'
name = name.replace(/[^a-zA-Z0-9.-]/g, '_');
// Ensure the name doesn't start with a dot (hidden file in Unix-like systems)
if (name.startsWith('.') || name === '') {
name = '_' + name;
}
// Limit the length of the filename
const MAX_LENGTH = 255;
if (name.length > MAX_LENGTH) {
const ext = path.extname(name);
const nameWithoutExt = path.basename(name, ext);
name =
nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) +
'-' +
crypto.randomBytes(3).toString('hex') +
ext;
}
return name;
}
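The truncation arithmetic is easy to check by hand: the kept prefix is `MAX_LENGTH - ext.length - 7` characters, where the 7 reserved characters are the `-` separator plus the 6 hex digits produced by `crypto.randomBytes(3)`. A standalone sketch with a fixed suffix standing in for the random bytes:

```typescript
const MAX_LENGTH = 255;
const ext = '.txt';
const nameWithoutExt = 'a'.repeat(300);
const suffix = 'abc123'; // randomBytes(3).toString('hex') -> 6 hex chars

// prefix + '-' + suffix + ext = (255 - 4 - 7) + 1 + 6 + 4 = 255
const name =
  nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) + '-' + suffix + ext;

console.log(name.length); // 255
```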

View file

@@ -0,0 +1,75 @@
import fetch from 'node-fetch';
import { logger } from '@librechat/data-schemas';
import { GraphEvents, sleep } from '@librechat/agents';
import type { Response as ServerResponse } from 'express';
import type { ServerSentEvent } from '~/types';
import { sendEvent } from './events';
/**
 * Creates a fetch function that makes HTTP requests and logs the process.
 * @param params
 * @param params.directEndpoint - Whether to use a direct endpoint.
 * @param params.reverseProxyUrl - The reverse proxy URL to use for the request.
 * @returns A fetch function that resolves to the response of the request.
 */
export function createFetch({
directEndpoint = false,
reverseProxyUrl = '',
}: {
directEndpoint?: boolean;
reverseProxyUrl?: string;
}) {
/**
* Makes an HTTP request and logs the process.
* @param url - The URL to make the request to. Can be a string or a Request object.
* @param init - Optional init options for the request.
* @returns A promise that resolves to the response of the fetch request.
*/
return async function (
_url: fetch.RequestInfo,
init: fetch.RequestInit,
): Promise<fetch.Response> {
let url = _url;
if (directEndpoint) {
url = reverseProxyUrl;
}
logger.debug(`Making request to ${url}`);
    if (typeof Bun !== 'undefined') {
      // Note: the Bun branch currently issues the same node-fetch call as the
      // default path; the runtime check is kept as a seam for Bun-specific handling.
      return await fetch(url, init);
    }
    return await fetch(url, init);
};
}
/**
* Creates event handlers for stream events that don't capture client references
* @param res - The response object to send events to
* @returns Object containing handler functions
*/
export function createStreamEventHandlers(res: ServerResponse) {
return {
[GraphEvents.ON_RUN_STEP]: function (event: ServerSentEvent) {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_MESSAGE_DELTA]: function (event: ServerSentEvent) {
if (res) {
sendEvent(res, event);
}
},
[GraphEvents.ON_REASONING_DELTA]: function (event: ServerSentEvent) {
if (res) {
sendEvent(res, event);
}
},
};
}
export function createHandleLLMNewToken(streamRate: number) {
return async function () {
if (streamRate) {
await sleep(streamRate);
}
};
}
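`createHandleLLMNewToken` throttles token streaming by sleeping `streamRate` milliseconds per token; a `streamRate` of `0` disables throttling entirely. A standalone sketch with `sleep` inlined (the PR imports it from `@librechat/agents`):

```typescript
const sleep = (ms: number): Promise<void> =>
  new Promise((resolve) => setTimeout(resolve, ms));

// Same shape as the factory above: the returned callback is invoked once
// per new LLM token, pacing the stream for the client.
function createHandleLLMNewToken(streamRate: number) {
  return async function () {
    if (streamRate) {
      await sleep(streamRate);
    }
  };
}

async function demo(): Promise<number> {
  const onToken = createHandleLLMNewToken(10);
  const start = Date.now();
  for (let i = 0; i < 5; i++) {
    await onToken(); // ~10ms pause per "token"
  }
  return Date.now() - start;
}

demo().then((elapsed) => console.log(`elapsed: ${elapsed}ms`));
```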

View file

@@ -0,0 +1,5 @@
export * from './azure';
export * from './common';
export * from './events';
export * from './generators';
export { default as Tokenizer } from './tokenizer';

View file

@@ -0,0 +1,143 @@
/**
* @file tokenizer.spec.ts
*
* Tests the real TokenizerSingleton (no mocking of `tiktoken`).
* Make sure to install `tiktoken` and have it configured properly.
*/
import { logger } from '@librechat/data-schemas';
import type { Tiktoken } from 'tiktoken';
import Tokenizer from './tokenizer';
jest.mock('@librechat/data-schemas', () => ({
logger: {
error: jest.fn(),
},
}));
describe('Tokenizer', () => {
it('should be a singleton (same instance)', async () => {
const AnotherTokenizer = await import('./tokenizer'); // same path
expect(Tokenizer).toBe(AnotherTokenizer.default);
});
describe('getTokenizer', () => {
it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
// The real `encoding_for_model` will be called internally
// as soon as we pass isModelName = true.
const tokenizer = Tokenizer.getTokenizer('gpt-4', true);
// Basic sanity checks
expect(tokenizer).toBeDefined();
// You can optionally check certain properties from `tiktoken` if they exist
// e.g., expect(typeof tokenizer.encode).toBe('function');
});
it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
// The real `get_encoding` will be called internally
// as soon as we pass isModelName = false.
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
expect(tokenizer).toBeDefined();
// e.g., expect(typeof tokenizer.encode).toBe('function');
});
it('should return cached tokenizer if previously fetched', () => {
const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
// Should be the exact same instance from the cache
expect(tokenizer1).toBe(tokenizer2);
});
});
describe('freeAndResetAllEncoders', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should free all encoders and reset tokenizerCallsCount to 1', () => {
// By creating two different encodings, we populate the cache
Tokenizer.getTokenizer('cl100k_base', false);
Tokenizer.getTokenizer('r50k_base', false);
// Now free them
Tokenizer.freeAndResetAllEncoders();
// The internal cache is cleared
expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();
// tokenizerCallsCount is reset to 1
expect(Tokenizer.tokenizerCallsCount).toBe(1);
});
it('should catch and log errors if freeing fails', () => {
// Mock logger.error before the test
const mockLoggerError = jest.spyOn(logger, 'error');
// Set up a problematic tokenizer in the cache
Tokenizer.tokenizersCache['cl100k_base'] = {
free() {
throw new Error('Intentional free error');
},
} as unknown as Tiktoken;
// Should not throw uncaught errors
Tokenizer.freeAndResetAllEncoders();
// Verify logger.error was called with correct arguments
expect(mockLoggerError).toHaveBeenCalledWith(
'[Tokenizer] Free and reset encoders error',
expect.any(Error),
);
// Clean up
mockLoggerError.mockRestore();
Tokenizer.tokenizersCache = {};
});
});
describe('getTokenCount', () => {
beforeEach(() => {
jest.clearAllMocks();
Tokenizer.freeAndResetAllEncoders();
});
it('should return the number of tokens in the given text', () => {
const text = 'Hello, world!';
const count = Tokenizer.getTokenCount(text, 'cl100k_base');
expect(count).toBeGreaterThan(0);
});
it('should reset encoders if an error is thrown', () => {
// We can simulate an error by temporarily overriding the selected tokenizer's `encode` method.
const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
const originalEncode = tokenizer.encode;
tokenizer.encode = () => {
throw new Error('Forced error');
};
// Despite the forced error, the code should catch and reset, then re-encode
const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
expect(count).toBeGreaterThan(0);
// Restore the original encode
tokenizer.encode = originalEncode;
});
it('should reset tokenizers after 25 calls', () => {
// Spy on freeAndResetAllEncoders
const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');
// Make 24 calls; should NOT reset yet
for (let i = 0; i < 24; i++) {
Tokenizer.getTokenCount('test text', 'cl100k_base');
}
expect(resetSpy).not.toHaveBeenCalled();
// 25th call triggers the reset
Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
expect(resetSpy).toHaveBeenCalledTimes(1);
});
});
});

View file

@@ -0,0 +1,78 @@
import { logger } from '@librechat/data-schemas';
import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken';
import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken';
interface TokenizerOptions {
debug?: boolean;
}
class Tokenizer {
tokenizersCache: Record<string, Tiktoken>;
tokenizerCallsCount: number;
private options?: TokenizerOptions;
constructor() {
this.tokenizersCache = {};
this.tokenizerCallsCount = 0;
}
getTokenizer(
encoding: TiktokenModel | TiktokenEncoding,
isModelName = false,
extendSpecialTokens: Record<string, number> = {},
): Tiktoken {
let tokenizer: Tiktoken;
if (this.tokenizersCache[encoding]) {
tokenizer = this.tokenizersCache[encoding];
} else {
if (isModelName) {
tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens);
}
this.tokenizersCache[encoding] = tokenizer;
}
return tokenizer;
}
freeAndResetAllEncoders(): void {
try {
Object.keys(this.tokenizersCache).forEach((key) => {
if (this.tokenizersCache[key]) {
this.tokenizersCache[key].free();
delete this.tokenizersCache[key];
}
});
this.tokenizerCallsCount = 1;
} catch (error) {
logger.error('[Tokenizer] Free and reset encoders error', error);
}
}
resetTokenizersIfNecessary(): void {
if (this.tokenizerCallsCount >= 25) {
if (this.options?.debug) {
logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
}
this.freeAndResetAllEncoders();
}
this.tokenizerCallsCount++;
}
getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 'cl100k_base'): number {
this.resetTokenizersIfNecessary();
try {
const tokenizer = this.getTokenizer(encoding);
return tokenizer.encode(text, 'all').length;
} catch (error) {
logger.error('[Tokenizer] Error getting token count:', error);
this.freeAndResetAllEncoders();
const tokenizer = this.getTokenizer(encoding);
return tokenizer.encode(text, 'all').length;
}
}
}
const TokenizerSingleton = new Tokenizer();
export default TokenizerSingleton;
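The singleton's cache-and-reset policy (free every cached encoder once the call counter reaches 25, to bound the native memory held by tiktoken handles) can be sketched standalone; `FakeEncoder` here stands in for a `Tiktoken` instance and the class is illustrative, not part of the PR:

```typescript
class FakeEncoder {
  freed = false;
  free(): void {
    this.freed = true; // stands in for releasing the native handle
  }
}

class ResettingCache {
  cache: Record<string, FakeEncoder> = {};
  callsCount = 1; // freeAndResetAllEncoders leaves the counter at 1

  get(key: string): FakeEncoder {
    // Mirrors resetTokenizersIfNecessary: check before serving the call
    if (this.callsCount >= 25) {
      this.reset();
    }
    this.callsCount++;
    if (!this.cache[key]) {
      this.cache[key] = new FakeEncoder();
    }
    return this.cache[key];
  }

  reset(): void {
    for (const key of Object.keys(this.cache)) {
      this.cache[key].free();
      delete this.cache[key];
    }
    this.callsCount = 1;
  }
}

const store = new ResettingCache();
const first = store.get('cl100k_base');
for (let i = 0; i < 23; i++) {
  store.get('cl100k_base'); // calls 2..24: no reset yet
}
console.log(first.freed); // false
store.get('cl100k_base'); // the 25th call triggers the reset
console.log(first.freed); // true
```

This matches the spec above: 24 calls leave the cache intact, and the 25th frees and rebuilds it.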

View file

@@ -0,0 +1,23 @@
import path from 'path';
import { pathToFileURL } from 'url';
// @ts-ignore
import { resolve as resolveTs } from 'ts-node/esm';
import * as tsConfigPaths from 'tsconfig-paths';
// @ts-ignore
const { absoluteBaseUrl, paths } = tsConfigPaths.loadConfig(
path.resolve('./tsconfig.json'), // Updated path
);
const matchPath = tsConfigPaths.createMatchPath(absoluteBaseUrl, paths);
export function resolve(specifier, context, defaultResolve) {
const match = matchPath(specifier);
if (match) {
return resolveTs(pathToFileURL(match).href, context, defaultResolve);
}
return resolveTs(specifier, context, defaultResolve);
}
// @ts-ignore
export { load, getFormat, transformSource } from 'ts-node/esm';
// node -r dotenv/config --loader ./tsconfig-paths-bootstrap.mjs --experimental-specifier-resolution=node ../../api/demo/everything.ts

View file

@@ -0,0 +1,33 @@
{
"compilerOptions": {
"declaration": true,
"declarationDir": "./dist/types",
"module": "esnext",
"noImplicitAny": true,
"outDir": "./types",
"target": "es2015",
"moduleResolution": "node",
"allowSyntheticDefaultImports": true,
"lib": ["es2017", "dom", "ES2021.String"],
"allowJs": true,
"skipLibCheck": true,
"esModuleInterop": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
"isolatedModules": true,
"noEmit": true,
"sourceMap": true,
"baseUrl": ".",
"paths": {
"~/*": ["./src/*"]
}
},
"ts-node": {
"experimentalSpecifierResolution": "node",
"transpileOnly": true,
"esm": true
},
"exclude": ["node_modules", "dist", "types"],
"include": ["src/**/*", "types/index.d.ts", "types/react-query/index.d.ts"]
}

View file

@@ -0,0 +1,10 @@
{
"extends": "./tsconfig.json",
"compilerOptions": {
"noEmit": true,
"outDir": "./dist/tests",
"baseUrl": "."
},
"include": ["specs/**/*", "src/**/*"],
"exclude": ["node_modules", "dist"]
}