Mirror of https://github.com/danny-avila/LibreChat.git
Synced 2026-02-20 01:18:10 +01:00
🧠 feat: User Memories for Conversational Context (#7760)
* 🧠 feat: User Memories for Conversational Context
chore: mcp typing, use `t`
WIP: first pass, Memories UI
- Added MemoryViewer component for displaying, editing, and deleting user memories.
- Integrated data provider hooks for fetching, updating, and deleting memories.
- Implemented pagination and loading states for better user experience.
- Created unit tests for MemoryViewer to ensure functionality and interaction with data provider.
- Updated translation files to include new UI strings related to memories.
chore: move mcp-related files to own directory
chore: rename librechat-mcp to librechat-api
WIP: first pass, memory processing and data schemas
chore: linting in fileSearch.js query description
chore: rename librechat-api to @librechat/api across the project
WIP: first pass, functional memory agent
feat: add MemoryEditDialog and MemoryViewer components for managing user memories
- Introduced MemoryEditDialog for editing memory entries with validation and toast notifications.
- Updated MemoryViewer to support editing and deleting memories, including pagination and loading states.
- Enhanced data provider to handle memory updates with optional original key for better management.
- Added new localization strings for memory-related UI elements.
feat: add memory permissions management
- Implemented memory permissions in the backend, allowing roles to have specific permissions for using, creating, updating, and reading memories.
- Added new API endpoints for updating memory permissions associated with roles.
- Created a new AdminSettings component for managing memory permissions in the frontend.
- Integrated memory permissions into the existing roles and permissions schemas.
- Updated the interface to include memory settings and permissions.
- Enhanced the MemoryViewer component to conditionally render admin settings based on user roles.
- Added localization support for memory permissions in the translation files.
feat: move AdminSettings component to a new position in MemoryViewer for better visibility
refactor: clean up commented code in MemoryViewer component
feat: enhance MemoryViewer with search functionality and improve MemoryEditDialog integration
- Added a search input to filter memories in the MemoryViewer component.
- Refactored MemoryEditDialog to accept children for better customization.
- Updated MemoryViewer to utilize the new EditMemoryButton and DeleteMemoryButton components for editing and deleting memories.
- Improved localization support by adding new strings for memory filtering and deletion confirmation.
refactor: optimize memory filtering in MemoryViewer using match-sorter
- Replaced manual filtering logic with match-sorter for improved search functionality.
- Enhanced performance and readability of the filteredMemories computation.
feat: enhance MemoryEditDialog with triggerRef and improve updateMemory mutation handling
feat: implement access control for MemoryEditDialog and MemoryViewer components
refactor: remove commented-out code and create runMemory method
refactor: rename role-based files
feat: implement access control for memory usage in AgentClient
refactor: simplify checkVisionRequest method in AgentClient by removing commented-out code
refactor: make `agents` dir in api package
refactor: migrate Azure utilities to TypeScript and consolidate imports
refactor: move sanitizeFilename function to a new file and update imports, add related tests
refactor: update LLM configuration types and consolidate Azure options in the API package
chore: linting
chore: import order
refactor: replace getLLMConfig with getOpenAIConfig and remove unused LLM configuration file
chore: update winston-daily-rotate-file to version 5.0.0 and add object-hash dependency in package-lock.json
refactor: move primeResources and optionalChainWithEmptyCheck functions to resources.ts and update imports
refactor: move createRun function to a new run.ts file and update related imports
fix: ensure safeAttachments is correctly typed as an array of TFile
chore: add node-fetch dependency and refactor fetch-related functions into packages/api/utils, removing the old generators file
refactor: enhance TEndpointOption type by using Pick to streamline endpoint fields and add new properties for model parameters and client options
feat: implement initializeOpenAIOptions function and update OpenAI types for enhanced configuration handling
fix: update types due to new TEndpointOption typing
fix: ensure safe access to group parameters in initializeOpenAIOptions function
fix: remove redundant API key validation comment in initializeOpenAIOptions function
refactor: rename initializeOpenAIOptions to initializeOpenAI for consistency and update related documentation
refactor: decouple req.body fields and tool loading from initializeAgentOptions
chore: linting
refactor: adjust column widths in MemoryViewer for improved layout
refactor: simplify agent initialization by creating loadAgent function and removing unused code
feat: add memory configuration loading and validation functions
WIP: first pass, memory processing with config
feat: implement memory callback and artifact handling
feat: implement memory artifacts display and processing updates
feat: add memory configuration options and schema validation for validKeys
fix: update MemoryEditDialog and MemoryViewer to handle memory state and display improvements
refactor: remove padding from BookmarkTable and MemoryViewer headers for consistent styling
WIP: initial tokenLimit config and move Tokenizer to @librechat/api
refactor: update mongoMeili plugin methods to use callback for better error handling
feat: enhance memory management with token tracking and usage metrics
- Added token counting for memory entries to enforce limits and provide usage statistics.
- Updated memory retrieval and update routes to include total token usage and limit.
- Enhanced MemoryEditDialog and MemoryViewer components to display memory usage and token information.
- Refactored memory processing functions to handle token limits and provide feedback on memory capacity.
feat: implement memory artifact handling in attachment handler
- Enhanced useAttachmentHandler to process memory artifacts when receiving updates.
- Introduced handleMemoryArtifact utility to manage memory updates and deletions.
- Updated query client to reflect changes in memory state based on incoming data.
refactor: restructure web search key extraction logic
- Moved the logic for extracting API keys from the webSearchAuth configuration into a dedicated function, getWebSearchKeys.
- Updated webSearchKeys to utilize the new function for improved clarity and maintainability.
- Prevents build-time errors
feat: add personalization settings and memory preferences management
- Introduced a new Personalization tab in settings to manage user memory preferences.
- Implemented API endpoints and client-side logic for updating memory preferences.
- Enhanced user interface components to reflect personalization options and memory usage.
- Updated permissions to allow users to opt out of memory features.
- Added localization support for new settings and messages related to personalization.
style: personalization switch class
feat: add PersonalizationIcon and align Side Panel UI
feat: implement memory creation functionality
- Added a new API endpoint for creating memory entries, including validation for key and value.
- Introduced MemoryCreateDialog component for user interface to facilitate memory creation.
- Integrated token limit checks to prevent exceeding user memory capacity.
- Updated MemoryViewer to include a button for opening the memory creation dialog.
- Enhanced localization support for new messages related to memory creation.
feat: enhance message processing with configurable window size
- Updated AgentClient to use a configurable message window size for processing messages.
- Introduced messageWindowSize option in memory configuration schema with a default value of 5.
- Improved logic for selecting messages to process based on the configured window size.
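
A minimal sketch of the windowing behavior described above (illustrative only, not the exact AgentClient code; the function name is hypothetical):

import type { BaseMessage } from '@langchain/core/messages';

/** Keep only the trailing window of messages for memory processing. */
function selectMessageWindow(messages: BaseMessage[], messageWindowSize = 5): BaseMessage[] {
  // slice(-n) returns the whole array when it has fewer than n elements
  return messages.slice(-messageWindowSize);
}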
chore: update librechat-data-provider version to 0.7.87 in package.json and package-lock.json
chore: remove OpenAPIPlugin and its associated tests
chore: remove MIGRATION_README.md as migration tasks are completed
ci: fix backend tests
chore: remove unused translation keys from localization file
chore: remove problematic test file and unused var in AgentClient
chore: remove unused import and import directly for JSDoc
* feat: add api package build stage in Dockerfile for improved modularity
* docs: reorder build steps in contributing guide for clarity
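
The memory feature above is driven by a configuration surface (validKeys, tokenLimit, an optional llmConfig and custom instructions, plus the messageWindowSize described earlier). A minimal sketch of such a config object, based on the MemoryConfig interface added below in packages/api/src/agents/memory.ts — the @librechat/api import path and the placement of messageWindowSize alongside the other fields are assumptions:

import { Providers } from '@librechat/agents';
import type { MemoryConfig } from '@librechat/api'; // assumed export path

// Example values only — keys and limits are deployment choices, not defaults.
const memoryConfig: MemoryConfig & { messageWindowSize?: number } = {
  validKeys: ['preferences', 'background', 'goals'],
  tokenLimit: 2000,
  instructions: undefined, // falls back to getDefaultInstructions(validKeys, tokenLimit)
  llmConfig: { provider: Providers.OPENAI, model: 'gpt-4.1-mini' },
  messageWindowSize: 5, // default of 5 per the commit message; schema placement assumed
};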
Parent: cd7dd576c1
Commit: 29ef91b4dd
170 changed files with 5,700 additions and 3,632 deletions
packages/api/package.json
@@ -1,5 +1,5 @@
 {
-  "name": "librechat-mcp",
+  "name": "@librechat/api",
   "version": "1.2.2",
   "type": "commonjs",
   "description": "MCP services for LibreChat",
@@ -47,6 +47,7 @@
   "@rollup/plugin-replace": "^5.0.5",
   "@rollup/plugin-terser": "^0.4.4",
   "@rollup/plugin-typescript": "^12.1.2",
+  "@types/bun": "^1.2.15",
   "@types/diff": "^6.0.0",
   "@types/express": "^5.0.0",
   "@types/jest": "^29.5.2",
@@ -66,13 +67,17 @@
   "publishConfig": {
     "registry": "https://registry.npmjs.org/"
   },
-  "dependencies": {
+  "peerDependencies": {
     "@librechat/agents": "^2.4.37",
+    "@librechat/data-schemas": "*",
+    "librechat-data-provider": "*",
     "@modelcontextprotocol/sdk": "^1.11.2",
     "diff": "^7.0.0",
     "eventsource": "^3.0.2",
-    "express": "^4.21.2"
-  },
-  "peerDependencies": {
-    "keyv": "^5.3.2"
+    "express": "^4.21.2",
+    "node-fetch": "2.7.0",
+    "keyv": "^5.3.2",
+    "zod": "^3.22.4",
+    "tiktoken": "^1.0.15"
   }
 }
packages/api/src/agents/index.ts (new file, 3 lines)
@@ -0,0 +1,3 @@
export * from './memory';
export * from './resources';
export * from './run';
packages/api/src/agents/memory.ts (new file, 468 lines)
@@ -0,0 +1,468 @@
/** Memories */
import { z } from 'zod';
import { tool } from '@langchain/core/tools';
import { Tools } from 'librechat-data-provider';
import { logger } from '@librechat/data-schemas';
import { Run, Providers, GraphEvents } from '@librechat/agents';
import type {
  StreamEventData,
  ToolEndCallback,
  EventHandler,
  ToolEndData,
  LLMConfig,
} from '@librechat/agents';
import type { TAttachment, MemoryArtifact } from 'librechat-data-provider';
import type { ObjectId, MemoryMethods } from '@librechat/data-schemas';
import type { BaseMessage } from '@langchain/core/messages';
import type { Response as ServerResponse } from 'express';
import { Tokenizer } from '~/utils';

type RequiredMemoryMethods = Pick<
  MemoryMethods,
  'setMemory' | 'deleteMemory' | 'getFormattedMemories'
>;

type ToolEndMetadata = Record<string, unknown> & {
  run_id?: string;
  thread_id?: string;
};

export interface MemoryConfig {
  validKeys?: string[];
  instructions?: string;
  llmConfig?: Partial<LLMConfig>;
  tokenLimit?: number;
}

export const memoryInstructions =
  'The system automatically stores important user information and can update or delete memories based on user requests, enabling dynamic memory management.';

const getDefaultInstructions = (
  validKeys?: string[],
  tokenLimit?: number,
) => `Use the \`set_memory\` tool to save important information about the user, but ONLY when the user has explicitly provided this information. If there is nothing to note about the user specifically, END THE TURN IMMEDIATELY.

The \`delete_memory\` tool should only be used in two scenarios:
1. When the user explicitly asks to forget or remove specific information
2. When updating existing memories, use the \`set_memory\` tool instead of deleting and re-adding the memory.

${
  validKeys && validKeys.length > 0
    ? `CRITICAL INSTRUCTION: Only the following keys are valid for storing memories:
  ${validKeys.map((key) => `- ${key}`).join('\n  ')}`
    : 'You can use any appropriate key to store memories about the user.'
}

${
  tokenLimit
    ? `⚠️ TOKEN LIMIT: Each memory value must not exceed ${tokenLimit} tokens. Be concise and store only essential information.`
    : ''
}

⚠️ WARNING ⚠️
DO NOT STORE ANY INFORMATION UNLESS THE USER HAS EXPLICITLY PROVIDED IT.
ONLY store information the user has EXPLICITLY shared.
NEVER guess or assume user information.
ALL memory values must be factual statements about THIS specific user.
If nothing needs to be stored, DO NOT CALL any memory tools.
If you're unsure whether to store something, DO NOT store it.
If nothing needs to be stored, END THE TURN IMMEDIATELY.`;

/**
 * Creates a memory tool instance with user context
 */
const createMemoryTool = ({
  userId,
  setMemory,
  validKeys,
  tokenLimit,
  totalTokens = 0,
}: {
  userId: string | ObjectId;
  setMemory: MemoryMethods['setMemory'];
  validKeys?: string[];
  tokenLimit?: number;
  totalTokens?: number;
}) => {
  return tool(
    async ({ key, value }) => {
      try {
        if (validKeys && validKeys.length > 0 && !validKeys.includes(key)) {
          logger.warn(
            `Memory Agent failed to set memory: Invalid key "${key}". Must be one of: ${validKeys.join(
              ', ',
            )}`,
          );
          return `Invalid key "${key}". Must be one of: ${validKeys.join(', ')}`;
        }

        const tokenCount = Tokenizer.getTokenCount(value, 'o200k_base');

        if (tokenLimit && tokenCount > tokenLimit) {
          logger.warn(
            `Memory Agent failed to set memory: Value exceeds token limit. Value has ${tokenCount} tokens, but limit is ${tokenLimit}`,
          );
          return `Memory value too large: ${tokenCount} tokens exceeds limit of ${tokenLimit}`;
        }

        if (tokenLimit && totalTokens + tokenCount > tokenLimit) {
          const remainingCapacity = tokenLimit - totalTokens;
          logger.warn(
            `Memory Agent failed to set memory: Would exceed total token limit. Current usage: ${totalTokens}, new memory: ${tokenCount} tokens, limit: ${tokenLimit}`,
          );
          return `Cannot add memory: would exceed token limit. Current usage: ${totalTokens}/${tokenLimit} tokens. This memory requires ${tokenCount} tokens, but only ${remainingCapacity} tokens available.`;
        }

        const artifact: Record<Tools.memory, MemoryArtifact> = {
          [Tools.memory]: {
            key,
            value,
            tokenCount,
            type: 'update',
          },
        };

        const result = await setMemory({ userId, key, value, tokenCount });
        if (result.ok) {
          logger.debug(`Memory set for key "${key}" (${tokenCount} tokens) for user "${userId}"`);
          return [`Memory set for key "${key}" (${tokenCount} tokens)`, artifact];
        }
        logger.warn(`Failed to set memory for key "${key}" for user "${userId}"`);
        return [`Failed to set memory for key "${key}"`, undefined];
      } catch (error) {
        logger.error('Memory Agent failed to set memory', error);
        return [`Error setting memory for key "${key}"`, undefined];
      }
    },
    {
      name: 'set_memory',
      description: 'Saves important information about the user into memory.',
      responseFormat: 'content_and_artifact',
      schema: z.object({
        key: z
          .string()
          .describe(
            validKeys && validKeys.length > 0
              ? `The key of the memory value. Must be one of: ${validKeys.join(', ')}`
              : 'The key identifier for this memory',
          ),
        value: z
          .string()
          .describe(
            'Value MUST be a complete sentence that fully describes relevant user information.',
          ),
      }),
    },
  );
};

/**
 * Creates a delete memory tool instance with user context
 */
const createDeleteMemoryTool = ({
  userId,
  deleteMemory,
  validKeys,
}: {
  userId: string | ObjectId;
  deleteMemory: MemoryMethods['deleteMemory'];
  validKeys?: string[];
}) => {
  return tool(
    async ({ key }) => {
      try {
        if (validKeys && validKeys.length > 0 && !validKeys.includes(key)) {
          logger.warn(
            `Memory Agent failed to delete memory: Invalid key "${key}". Must be one of: ${validKeys.join(
              ', ',
            )}`,
          );
          return `Invalid key "${key}". Must be one of: ${validKeys.join(', ')}`;
        }

        const artifact: Record<Tools.memory, MemoryArtifact> = {
          [Tools.memory]: {
            key,
            type: 'delete',
          },
        };

        const result = await deleteMemory({ userId, key });
        if (result.ok) {
          logger.debug(`Memory deleted for key "${key}" for user "${userId}"`);
          return [`Memory deleted for key "${key}"`, artifact];
        }
        logger.warn(`Failed to delete memory for key "${key}" for user "${userId}"`);
        return [`Failed to delete memory for key "${key}"`, undefined];
      } catch (error) {
        logger.error('Memory Agent failed to delete memory', error);
        return [`Error deleting memory for key "${key}"`, undefined];
      }
    },
    {
      name: 'delete_memory',
      description:
        'Deletes specific memory data about the user using the provided key. For updating existing memories, use the `set_memory` tool instead',
      responseFormat: 'content_and_artifact',
      schema: z.object({
        key: z
          .string()
          .describe(
            validKeys && validKeys.length > 0
              ? `The key of the memory to delete. Must be one of: ${validKeys.join(', ')}`
              : 'The key identifier of the memory to delete',
          ),
      }),
    },
  );
};

export class BasicToolEndHandler implements EventHandler {
  private callback?: ToolEndCallback;
  constructor(callback?: ToolEndCallback) {
    this.callback = callback;
  }
  handle(
    event: string,
    data: StreamEventData | undefined,
    metadata?: Record<string, unknown>,
  ): void {
    if (!metadata) {
      console.warn(`Graph or metadata not found in ${event} event`);
      return;
    }
    const toolEndData = data as ToolEndData | undefined;
    if (!toolEndData?.output) {
      console.warn('No output found in tool_end event');
      return;
    }
    this.callback?.(toolEndData, metadata);
  }
}

export async function processMemory({
  res,
  userId,
  setMemory,
  deleteMemory,
  messages,
  memory,
  messageId,
  conversationId,
  validKeys,
  instructions,
  llmConfig,
  tokenLimit,
  totalTokens = 0,
}: {
  res: ServerResponse;
  setMemory: MemoryMethods['setMemory'];
  deleteMemory: MemoryMethods['deleteMemory'];
  userId: string | ObjectId;
  memory: string;
  messageId: string;
  conversationId: string;
  messages: BaseMessage[];
  validKeys?: string[];
  instructions: string;
  tokenLimit?: number;
  totalTokens?: number;
  llmConfig?: Partial<LLMConfig>;
}): Promise<(TAttachment | null)[] | undefined> {
  try {
    const memoryTool = createMemoryTool({ userId, tokenLimit, setMemory, validKeys, totalTokens });
    const deleteMemoryTool = createDeleteMemoryTool({
      userId,
      validKeys,
      deleteMemory,
    });

    const currentMemoryTokens = totalTokens;

    let memoryStatus = `# Existing memory:\n${memory ?? 'No existing memories'}`;

    if (tokenLimit) {
      const remainingTokens = tokenLimit - currentMemoryTokens;
      memoryStatus = `# Memory Status:
Current memory usage: ${currentMemoryTokens} tokens
Token limit: ${tokenLimit} tokens
Remaining capacity: ${remainingTokens} tokens

# Existing memory:
${memory ?? 'No existing memories'}`;
    }

    const defaultLLMConfig: LLMConfig = {
      provider: Providers.OPENAI,
      model: 'gpt-4.1-mini',
      temperature: 0.4,
      streaming: false,
      disableStreaming: true,
    };

    const finalLLMConfig = {
      ...defaultLLMConfig,
      ...llmConfig,
      /**
       * Ensure streaming is always disabled for memory processing
       */
      streaming: false,
      disableStreaming: true,
    };

    const artifactPromises: Promise<TAttachment | null>[] = [];
    const memoryCallback = createMemoryCallback({ res, artifactPromises });
    const customHandlers = {
      [GraphEvents.TOOL_END]: new BasicToolEndHandler(memoryCallback),
    };

    const run = await Run.create({
      runId: messageId,
      graphConfig: {
        type: 'standard',
        llmConfig: finalLLMConfig,
        tools: [memoryTool, deleteMemoryTool],
        instructions,
        additional_instructions: memoryStatus,
        toolEnd: true,
      },
      customHandlers,
      returnContent: true,
    });

    const config = {
      configurable: {
        provider: llmConfig?.provider,
        thread_id: `memory-run-${conversationId}`,
      },
      streamMode: 'values',
      version: 'v2',
    } as const;

    const inputs = {
      messages,
    };
    const content = await run.processStream(inputs, config);
    if (content) {
      logger.debug('Memory Agent processed memory successfully', content);
    } else {
      logger.warn('Memory Agent processed memory but returned no content');
    }
    return await Promise.all(artifactPromises);
  } catch (error) {
    logger.error('Memory Agent failed to process memory', error);
  }
}

export async function createMemoryProcessor({
  res,
  userId,
  messageId,
  memoryMethods,
  conversationId,
  config = {},
}: {
  res: ServerResponse;
  messageId: string;
  conversationId: string;
  userId: string | ObjectId;
  memoryMethods: RequiredMemoryMethods;
  config?: MemoryConfig;
}): Promise<[string, (messages: BaseMessage[]) => Promise<(TAttachment | null)[] | undefined>]> {
  const { validKeys, instructions, llmConfig, tokenLimit } = config;
  const finalInstructions = instructions || getDefaultInstructions(validKeys, tokenLimit);

  const { withKeys, withoutKeys, totalTokens } = await memoryMethods.getFormattedMemories({
    userId,
  });

  return [
    withoutKeys,
    async function (messages: BaseMessage[]): Promise<(TAttachment | null)[] | undefined> {
      try {
        return await processMemory({
          res,
          userId,
          messages,
          validKeys,
          llmConfig,
          messageId,
          tokenLimit,
          conversationId,
          memory: withKeys,
          totalTokens: totalTokens || 0,
          instructions: finalInstructions,
          setMemory: memoryMethods.setMemory,
          deleteMemory: memoryMethods.deleteMemory,
        });
      } catch (error) {
        logger.error('Memory Agent failed to process memory', error);
      }
    },
  ];
}

async function handleMemoryArtifact({
  res,
  data,
  metadata,
}: {
  res: ServerResponse;
  data: ToolEndData;
  metadata?: ToolEndMetadata;
}) {
  const output = data?.output;
  if (!output) {
    return null;
  }

  if (!output.artifact) {
    return null;
  }

  const memoryArtifact = output.artifact[Tools.memory] as MemoryArtifact | undefined;
  if (!memoryArtifact) {
    return null;
  }

  const attachment: Partial<TAttachment> = {
    type: Tools.memory,
    toolCallId: output.tool_call_id,
    messageId: metadata?.run_id ?? '',
    conversationId: metadata?.thread_id ?? '',
    [Tools.memory]: memoryArtifact,
  };
  if (!res.headersSent) {
    return attachment;
  }
  res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
  return attachment;
}

/**
 * Creates a memory callback for handling memory artifacts
 * @param params - The parameters object
 * @param params.res - The server response object
 * @param params.artifactPromises - Array to collect artifact promises
 * @returns The memory callback function
 */
export function createMemoryCallback({
  res,
  artifactPromises,
}: {
  res: ServerResponse;
  artifactPromises: Promise<Partial<TAttachment> | null>[];
}): ToolEndCallback {
  return async (data: ToolEndData, metadata?: Record<string, unknown>) => {
    const output = data?.output;
    const memoryArtifact = output?.artifact?.[Tools.memory] as MemoryArtifact;
    if (memoryArtifact == null) {
      return;
    }
    artifactPromises.push(
      handleMemoryArtifact({ res, data, metadata }).catch((error) => {
        logger.error('Error processing memory artifact content:', error);
        return null;
      }),
    );
  };
}
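
For orientation, a hedged sketch of how a caller might wire createMemoryProcessor. The wrapper function and its parameter names are illustrative; memoryMethods and the Express response come from the consuming app, and the config values are examples, not defaults:

import type { Response as ServerResponse } from 'express';
import type { BaseMessage } from '@langchain/core/messages';
import { createMemoryProcessor } from '@librechat/api'; // assumed export path

// Two-step flow: fetch formatted memories once, then run the memory agent over
// recent messages, streaming memory artifacts back over `res`.
async function runMemoryForTurn(params: {
  res: ServerResponse;
  userId: string;
  messageId: string;
  conversationId: string;
  memoryMethods: Parameters<typeof createMemoryProcessor>[0]['memoryMethods'];
  recentMessages: BaseMessage[];
}) {
  const [existingMemories, processMemory] = await createMemoryProcessor({
    res: params.res,
    userId: params.userId,
    messageId: params.messageId,
    conversationId: params.conversationId,
    memoryMethods: params.memoryMethods,
    config: { validKeys: ['preferences'], tokenLimit: 2000 }, // example values
  });
  // `existingMemories` is the key-less formatted string for prompt injection;
  // the processor returns memory attachments (or undefined on failure).
  const attachments = await processMemory(params.recentMessages);
  return { existingMemories, attachments };
}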
packages/api/src/agents/resources.test.ts (new file, 543 lines)
@@ -0,0 +1,543 @@
import { primeResources } from './resources';
import { logger } from '@librechat/data-schemas';
import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-data-provider';
import type { Request as ServerRequest } from 'express';
import type { TFile } from 'librechat-data-provider';
import type { TGetFiles } from './resources';

// Mock logger
jest.mock('@librechat/data-schemas', () => ({
  logger: {
    error: jest.fn(),
  },
}));

describe('primeResources', () => {
  let mockReq: ServerRequest;
  let mockGetFiles: jest.MockedFunction<TGetFiles>;
  let requestFileSet: Set<string>;

  beforeEach(() => {
    // Reset mocks
    jest.clearAllMocks();

    // Setup mock request
    mockReq = {
      app: {
        locals: {
          [EModelEndpoint.agents]: {
            capabilities: [AgentCapabilities.ocr],
          },
        },
      },
    } as unknown as ServerRequest;

    // Setup mock getFiles function
    mockGetFiles = jest.fn();

    // Setup request file set
    requestFileSet = new Set(['file1', 'file2', 'file3']);
  });

  describe('when OCR is enabled and tool_resources has OCR file_ids', () => {
    it('should fetch OCR files and include them in attachments', async () => {
      const mockOcrFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'ocr-file-1',
          filename: 'document.pdf',
          filepath: '/uploads/document.pdf',
          object: 'file',
          type: 'application/pdf',
          bytes: 1024,
          embedded: false,
          usage: 0,
        },
      ];

      mockGetFiles.mockResolvedValue(mockOcrFiles);

      const tool_resources = {
        [EToolResources.ocr]: {
          file_ids: ['ocr-file-1'],
        },
      };

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments: undefined,
        tool_resources,
      });

      expect(mockGetFiles).toHaveBeenCalledWith({ file_id: { $in: ['ocr-file-1'] } }, {}, {});
      expect(result.attachments).toEqual(mockOcrFiles);
      expect(result.tool_resources).toEqual(tool_resources);
    });
  });

  describe('when OCR is disabled', () => {
    it('should not fetch OCR files even if tool_resources has OCR file_ids', async () => {
      (mockReq.app as ServerRequest['app']).locals[EModelEndpoint.agents].capabilities = [];

      const tool_resources = {
        [EToolResources.ocr]: {
          file_ids: ['ocr-file-1'],
        },
      };

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments: undefined,
        tool_resources,
      });

      expect(mockGetFiles).not.toHaveBeenCalled();
      expect(result.attachments).toBeUndefined();
      expect(result.tool_resources).toEqual(tool_resources);
    });
  });

  describe('when attachments are provided', () => {
    it('should process files with fileIdentifier as execute_code resources', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'script.py',
          filepath: '/uploads/script.py',
          object: 'file',
          type: 'text/x-python',
          bytes: 512,
          embedded: false,
          usage: 0,
          metadata: {
            fileIdentifier: 'python-script',
          },
        },
      ];

      const attachments = Promise.resolve(mockFiles);

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: {},
      });

      expect(result.attachments).toEqual(mockFiles);
      expect(result.tool_resources?.[EToolResources.execute_code]?.files).toEqual(mockFiles);
    });

    it('should process embedded files as file_search resources', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file2',
          filename: 'document.txt',
          filepath: '/uploads/document.txt',
          object: 'file',
          type: 'text/plain',
          bytes: 256,
          embedded: true,
          usage: 0,
        },
      ];

      const attachments = Promise.resolve(mockFiles);

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: {},
      });

      expect(result.attachments).toEqual(mockFiles);
      expect(result.tool_resources?.[EToolResources.file_search]?.files).toEqual(mockFiles);
    });

    it('should process image files in requestFileSet as image_edit resources', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'image.png',
          filepath: '/uploads/image.png',
          object: 'file',
          type: 'image/png',
          bytes: 2048,
          embedded: false,
          usage: 0,
          height: 800,
          width: 600,
        },
      ];

      const attachments = Promise.resolve(mockFiles);

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: {},
      });

      expect(result.attachments).toEqual(mockFiles);
      expect(result.tool_resources?.[EToolResources.image_edit]?.files).toEqual(mockFiles);
    });

    it('should not process image files not in requestFileSet', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file-not-in-set',
          filename: 'image.png',
          filepath: '/uploads/image.png',
          object: 'file',
          type: 'image/png',
          bytes: 2048,
          embedded: false,
          usage: 0,
          height: 800,
          width: 600,
        },
      ];

      const attachments = Promise.resolve(mockFiles);

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: {},
      });

      expect(result.attachments).toEqual(mockFiles);
      expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
    });

    it('should not process image files without height and width', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'image.png',
          filepath: '/uploads/image.png',
          object: 'file',
          type: 'image/png',
          bytes: 2048,
          embedded: false,
          usage: 0,
          // Missing height and width
        },
      ];

      const attachments = Promise.resolve(mockFiles);

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: {},
      });

      expect(result.attachments).toEqual(mockFiles);
      expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
    });

    it('should filter out null files from attachments', async () => {
      const mockFiles: Array<TFile | null> = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'valid.txt',
          filepath: '/uploads/valid.txt',
          object: 'file',
          type: 'text/plain',
          bytes: 256,
          embedded: false,
          usage: 0,
        },
        null,
        {
          user: 'user1',
          file_id: 'file2',
          filename: 'valid2.txt',
          filepath: '/uploads/valid2.txt',
          object: 'file',
          type: 'text/plain',
          bytes: 128,
          embedded: false,
          usage: 0,
        },
      ];

      const attachments = Promise.resolve(mockFiles);

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: {},
      });

      expect(result.attachments).toHaveLength(2);
      expect(result.attachments?.[0]?.file_id).toBe('file1');
      expect(result.attachments?.[1]?.file_id).toBe('file2');
    });

    it('should merge existing tool_resources with new files', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'script.py',
          filepath: '/uploads/script.py',
          object: 'file',
          type: 'text/x-python',
          bytes: 512,
          embedded: false,
          usage: 0,
          metadata: {
            fileIdentifier: 'python-script',
          },
        },
      ];

      const existingToolResources = {
        [EToolResources.execute_code]: {
          files: [
            {
              user: 'user1',
              file_id: 'existing-file',
              filename: 'existing.py',
              filepath: '/uploads/existing.py',
              object: 'file' as const,
              type: 'text/x-python',
              bytes: 256,
              embedded: false,
              usage: 0,
            },
          ],
        },
      };

      const attachments = Promise.resolve(mockFiles);

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: existingToolResources,
      });

      expect(result.tool_resources?.[EToolResources.execute_code]?.files).toHaveLength(2);
      expect(result.tool_resources?.[EToolResources.execute_code]?.files?.[0]?.file_id).toBe(
        'existing-file',
      );
      expect(result.tool_resources?.[EToolResources.execute_code]?.files?.[1]?.file_id).toBe(
        'file1',
      );
    });
  });

  describe('when both OCR and attachments are provided', () => {
    it('should include both OCR files and attachment files', async () => {
      const mockOcrFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'ocr-file-1',
          filename: 'document.pdf',
          filepath: '/uploads/document.pdf',
          object: 'file',
          type: 'application/pdf',
          bytes: 1024,
          embedded: false,
          usage: 0,
        },
      ];

      const mockAttachmentFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'attachment.txt',
          filepath: '/uploads/attachment.txt',
          object: 'file',
          type: 'text/plain',
          bytes: 256,
          embedded: false,
          usage: 0,
        },
      ];

      mockGetFiles.mockResolvedValue(mockOcrFiles);
      const attachments = Promise.resolve(mockAttachmentFiles);

      const tool_resources = {
        [EToolResources.ocr]: {
          file_ids: ['ocr-file-1'],
        },
      };

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources,
      });

      expect(result.attachments).toHaveLength(2);
      expect(result.attachments?.[0]?.file_id).toBe('ocr-file-1');
      expect(result.attachments?.[1]?.file_id).toBe('file1');
    });
  });

  describe('error handling', () => {
    it('should handle errors gracefully and log them', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'test.txt',
          filepath: '/uploads/test.txt',
          object: 'file',
          type: 'text/plain',
          bytes: 256,
          embedded: false,
          usage: 0,
        },
      ];

      const attachments = Promise.resolve(mockFiles);
      const error = new Error('Test error');

      // Mock getFiles to throw an error when called for OCR
      mockGetFiles.mockRejectedValue(error);

      const tool_resources = {
        [EToolResources.ocr]: {
          file_ids: ['ocr-file-1'],
        },
      };

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources,
      });

      expect(logger.error).toHaveBeenCalledWith('Error priming resources', error);
      expect(result.attachments).toEqual(mockFiles);
      expect(result.tool_resources).toEqual(tool_resources);
    });

    it('should handle promise rejection in attachments', async () => {
      const error = new Error('Attachment error');
      const attachments = Promise.reject(error);

      // The function should now handle rejected attachment promises gracefully
      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments,
        tool_resources: {},
      });

      // Should log both the main error and the attachment error
      expect(logger.error).toHaveBeenCalledWith('Error priming resources', error);
      expect(logger.error).toHaveBeenCalledWith(
        'Error resolving attachments in catch block',
        error,
      );

      // Should return empty array when attachments promise is rejected
      expect(result.attachments).toEqual([]);
      expect(result.tool_resources).toEqual({});
    });
  });

  describe('edge cases', () => {
    it('should handle missing app.locals gracefully', async () => {
      const reqWithoutLocals = {} as ServerRequest;

      const result = await primeResources({
        req: reqWithoutLocals,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments: undefined,
        tool_resources: {
          [EToolResources.ocr]: {
            file_ids: ['ocr-file-1'],
          },
        },
      });

      expect(mockGetFiles).not.toHaveBeenCalled();
      // When app.locals is missing and there's an error accessing properties,
      // the function falls back to the catch block which returns an empty array
      expect(result.attachments).toEqual([]);
    });

    it('should handle undefined tool_resources', async () => {
      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet,
        attachments: undefined,
        tool_resources: undefined,
      });

      expect(result.tool_resources).toEqual({});
      expect(result.attachments).toBeUndefined();
    });

    it('should handle empty requestFileSet', async () => {
      const mockFiles: TFile[] = [
        {
          user: 'user1',
          file_id: 'file1',
          filename: 'image.png',
          filepath: '/uploads/image.png',
          object: 'file',
          type: 'image/png',
          bytes: 2048,
          embedded: false,
          usage: 0,
          height: 800,
          width: 600,
        },
      ];

      const attachments = Promise.resolve(mockFiles);
      const emptyRequestFileSet = new Set<string>();

      const result = await primeResources({
        req: mockReq,
        getFiles: mockGetFiles,
        requestFileSet: emptyRequestFileSet,
        attachments,
        tool_resources: {},
      });

      expect(result.attachments).toEqual(mockFiles);
      expect(result.tool_resources?.[EToolResources.image_edit]).toBeUndefined();
    });
  });
});
packages/api/src/agents/resources.ts (new file, 114 lines)
@@ -0,0 +1,114 @@
import { logger } from '@librechat/data-schemas';
import { EModelEndpoint, EToolResources, AgentCapabilities } from 'librechat-data-provider';
import type { FilterQuery, QueryOptions, ProjectionType } from 'mongoose';
import type { AgentToolResources, TFile } from 'librechat-data-provider';
import type { IMongoFile } from '@librechat/data-schemas';
import type { Request as ServerRequest } from 'express';

export type TGetFiles = (
  filter: FilterQuery<IMongoFile>,
  _sortOptions: ProjectionType<IMongoFile> | null | undefined,
  selectFields: QueryOptions<IMongoFile> | null | undefined,
) => Promise<Array<TFile>>;

/**
 * @param params
 * @param params.req
 * @param params.attachments
 * @param params.requestFileSet
 * @param params.tool_resources
 */
export const primeResources = async ({
  req,
  getFiles,
  requestFileSet,
  attachments: _attachments,
  tool_resources: _tool_resources,
}: {
  req: ServerRequest;
  requestFileSet: Set<string>;
  attachments: Promise<Array<TFile | null>> | undefined;
  tool_resources: AgentToolResources | undefined;
  getFiles: TGetFiles;
}): Promise<{
  attachments: Array<TFile | undefined> | undefined;
  tool_resources: AgentToolResources | undefined;
}> => {
  try {
    let attachments: Array<TFile | undefined> | undefined;
    const tool_resources = _tool_resources ?? {};
    const isOCREnabled = (req.app.locals?.[EModelEndpoint.agents]?.capabilities ?? []).includes(
      AgentCapabilities.ocr,
    );
    if (tool_resources[EToolResources.ocr]?.file_ids && isOCREnabled) {
      const context = await getFiles(
        {
          file_id: { $in: tool_resources.ocr.file_ids },
        },
        {},
        {},
      );
      attachments = (attachments ?? []).concat(context);
    }
    if (!_attachments) {
      return { attachments, tool_resources };
    }
    const files = await _attachments;
    if (!attachments) {
      attachments = [];
    }

    for (const file of files) {
      if (!file) {
        continue;
      }
      if (file.metadata?.fileIdentifier) {
        const execute_code = tool_resources[EToolResources.execute_code] ?? {};
        if (!execute_code.files) {
          tool_resources[EToolResources.execute_code] = { ...execute_code, files: [] };
        }
        tool_resources[EToolResources.execute_code]?.files?.push(file);
      } else if (file.embedded === true) {
        const file_search = tool_resources[EToolResources.file_search] ?? {};
        if (!file_search.files) {
          tool_resources[EToolResources.file_search] = { ...file_search, files: [] };
        }
        tool_resources[EToolResources.file_search]?.files?.push(file);
      } else if (
        requestFileSet.has(file.file_id) &&
        file.type.startsWith('image') &&
        file.height &&
        file.width
      ) {
        const image_edit = tool_resources[EToolResources.image_edit] ?? {};
        if (!image_edit.files) {
          tool_resources[EToolResources.image_edit] = { ...image_edit, files: [] };
        }
        tool_resources[EToolResources.image_edit]?.files?.push(file);
      }

      attachments.push(file);
    }
    return { attachments, tool_resources };
  } catch (error) {
    logger.error('Error priming resources', error);

    // Safely try to get attachments without rethrowing
    let safeAttachments: Array<TFile | undefined> = [];
    if (_attachments) {
      try {
        const attachmentFiles = await _attachments;
        safeAttachments = (attachmentFiles?.filter((file) => !!file) ?? []) as Array<TFile>;
      } catch (attachmentError) {
        // If attachments promise is also rejected, just use empty array
        logger.error('Error resolving attachments in catch block', attachmentError);
        safeAttachments = [];
      }
    }

    return {
      attachments: safeAttachments,
      tool_resources: _tool_resources,
    };
  }
};
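
A hedged sketch of a call site: getFiles is injected (for instance, a Mongo-backed file lookup), which is exactly what lets the unit tests above mock it without a database. The wrapper name and example IDs are illustrative:

import type { Request as ServerRequest } from 'express';
import { EToolResources } from 'librechat-data-provider';
import { primeResources, type TGetFiles } from '@librechat/api'; // assumed export path

async function prepareAgentFiles(req: ServerRequest, getFiles: TGetFiles) {
  const { attachments, tool_resources } = await primeResources({
    req,
    getFiles,
    requestFileSet: new Set(['file1']), // file IDs sent with the current request
    attachments: undefined, // or a Promise<Array<TFile | null>> from the upload layer
    tool_resources: { [EToolResources.ocr]: { file_ids: ['ocr-file-1'] } }, // example
  });
  return { attachments, tool_resources };
}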
packages/api/src/agents/run.ts (new file, 90 lines)
@@ -0,0 +1,90 @@
import { Run, Providers } from '@librechat/agents';
import { providerEndpointMap, KnownEndpoints } from 'librechat-data-provider';
import type { StandardGraphConfig, EventHandler, GraphEvents, IState } from '@librechat/agents';
import type { Agent } from 'librechat-data-provider';
import type * as t from '~/types';

const customProviders = new Set([
  Providers.XAI,
  Providers.OLLAMA,
  Providers.DEEPSEEK,
  Providers.OPENROUTER,
]);

/**
 * Creates a new Run instance with custom handlers and configuration.
 *
 * @param options - The options for creating the Run instance.
 * @param options.agent - The agent for this run.
 * @param options.signal - The signal for this run.
 * @param options.req - The server request.
 * @param options.runId - Optional run ID; otherwise, a new run ID will be generated.
 * @param options.customHandlers - Custom event handlers.
 * @param options.streaming - Whether to use streaming.
 * @param options.streamUsage - Whether to stream usage information.
 * @returns {Promise<Run<IState>>} A promise that resolves to a new Run instance.
 */
export async function createRun({
  runId,
  agent,
  signal,
  customHandlers,
  streaming = true,
  streamUsage = true,
}: {
  agent: Agent;
  signal: AbortSignal;
  runId?: string;
  streaming?: boolean;
  streamUsage?: boolean;
  customHandlers?: Record<GraphEvents, EventHandler>;
}): Promise<Run<IState>> {
  const provider =
    providerEndpointMap[agent.provider as keyof typeof providerEndpointMap] ?? agent.provider;
  const llmConfig: t.RunLLMConfig = Object.assign(
    {
      provider,
      streaming,
      streamUsage,
    },
    agent.model_parameters,
  );

  /** Resolves issues with new OpenAI usage field */
  if (
    customProviders.has(agent.provider) ||
    (agent.provider === Providers.OPENAI && agent.endpoint !== agent.provider)
  ) {
    llmConfig.streamUsage = false;
    llmConfig.usage = true;
  }

  let reasoningKey: 'reasoning_content' | 'reasoning' | undefined;
  if (
    llmConfig.configuration?.baseURL?.includes(KnownEndpoints.openrouter) ||
    (agent.endpoint && agent.endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
  ) {
    reasoningKey = 'reasoning';
  }

  const graphConfig: StandardGraphConfig = {
    signal,
    llmConfig,
    reasoningKey,
    tools: agent.tools,
    instructions: agent.instructions,
    additional_instructions: agent.additional_instructions,
    // toolEnd: agent.end_after_tools,
  };

  // TEMPORARY FOR TESTING
  if (agent.provider === Providers.ANTHROPIC || agent.provider === Providers.BEDROCK) {
    graphConfig.streamBuffer = 2000;
  }

  return Run.create({
    runId,
    graphConfig,
    customHandlers,
  });
}
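
A minimal invocation sketch. The agent object normally comes from the agent loader (see the loadAgent refactor in the commit message); the wrapper function name is illustrative, and customHandlers is omitted since it is optional:

import { createRun } from '@librechat/api'; // assumed export path
import type { Agent } from 'librechat-data-provider';

async function startAgentRun(agent: Agent, signal: AbortSignal) {
  // streaming/streamUsage shown at their defaults; createRun downgrades
  // streamUsage itself for providers with the OpenAI usage-field quirk.
  return createRun({
    agent,
    signal,
    streaming: true,
    streamUsage: true,
  });
}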
packages/api/src/endpoints/index.ts (new file, 1 line)
@@ -0,0 +1 @@
export * from './openai';
packages/api/src/endpoints/openai/index.ts (new file, 2 lines)
@@ -0,0 +1,2 @@
export * from './llm';
export * from './initialize';
packages/api/src/endpoints/openai/initialize.ts (new file, 176 lines)
@@ -0,0 +1,176 @@
import {
  ErrorTypes,
  EModelEndpoint,
  resolveHeaders,
  mapModelToAzureConfig,
} from 'librechat-data-provider';
import type {
  LLMConfigOptions,
  UserKeyValues,
  InitializeOpenAIOptionsParams,
  OpenAIOptionsResult,
} from '~/types';
import { createHandleLLMNewToken } from '~/utils/generators';
import { getAzureCredentials } from '~/utils/azure';
import { isUserProvided } from '~/utils/common';
import { getOpenAIConfig } from './llm';

/**
 * Initializes OpenAI options for agent usage. This function always returns configuration
 * options and never creates a client instance (equivalent to optionsOnly=true behavior).
 *
 * @param params - Configuration parameters
 * @returns Promise resolving to OpenAI configuration options
 * @throws Error if API key is missing or user key has expired
 */
export const initializeOpenAI = async ({
  req,
  overrideModel,
  endpointOption,
  overrideEndpoint,
  getUserKeyValues,
  checkUserKeyExpiry,
}: InitializeOpenAIOptionsParams): Promise<OpenAIOptionsResult> => {
  const { PROXY, OPENAI_API_KEY, AZURE_API_KEY, OPENAI_REVERSE_PROXY, AZURE_OPENAI_BASEURL } =
    process.env;

  const { key: expiresAt } = req.body;
  const modelName = overrideModel ?? req.body.model;
  const endpoint = overrideEndpoint ?? req.body.endpoint;

  if (!endpoint) {
    throw new Error('Endpoint is required');
  }

  const credentials = {
    [EModelEndpoint.openAI]: OPENAI_API_KEY,
    [EModelEndpoint.azureOpenAI]: AZURE_API_KEY,
  };

  const baseURLOptions = {
    [EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY,
    [EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL,
  };

  const userProvidesKey = isUserProvided(credentials[endpoint as keyof typeof credentials]);
  const userProvidesURL = isUserProvided(baseURLOptions[endpoint as keyof typeof baseURLOptions]);

  let userValues: UserKeyValues | null = null;
  if (expiresAt && (userProvidesKey || userProvidesURL)) {
    checkUserKeyExpiry(expiresAt, endpoint);
    userValues = await getUserKeyValues({ userId: req.user.id, name: endpoint });
  }

  let apiKey = userProvidesKey
    ? userValues?.apiKey
    : credentials[endpoint as keyof typeof credentials];
  const baseURL = userProvidesURL
    ? userValues?.baseURL
    : baseURLOptions[endpoint as keyof typeof baseURLOptions];

  const clientOptions: LLMConfigOptions = {
    proxy: PROXY ?? undefined,
    reverseProxyUrl: baseURL || undefined,
    streaming: true,
  };

  const isAzureOpenAI = endpoint === EModelEndpoint.azureOpenAI;
  const azureConfig = isAzureOpenAI && req.app.locals[EModelEndpoint.azureOpenAI];

  if (isAzureOpenAI && azureConfig) {
    const { modelGroupMap, groupMap } = azureConfig;
    const {
      azureOptions,
      baseURL: configBaseURL,
      headers = {},
      serverless,
    } = mapModelToAzureConfig({
      modelName: modelName || '',
      modelGroupMap,
      groupMap,
    });

    clientOptions.reverseProxyUrl = configBaseURL ?? clientOptions.reverseProxyUrl;
    clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) });

    const groupName = modelGroupMap[modelName || '']?.group;
    if (groupName && groupMap[groupName]) {
      clientOptions.addParams = groupMap[groupName]?.addParams;
      clientOptions.dropParams = groupMap[groupName]?.dropParams;
    }

    apiKey = azureOptions.azureOpenAIApiKey;
    clientOptions.azure = !serverless ? azureOptions : undefined;

    if (serverless === true) {
      clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion
        ? { 'api-version': azureOptions.azureOpenAIApiVersion }
        : undefined;

      if (!clientOptions.headers) {
        clientOptions.headers = {};
      }
      clientOptions.headers['api-key'] = apiKey;
    }
  } else if (isAzureOpenAI) {
    clientOptions.azure =
      userProvidesKey && userValues?.apiKey ? JSON.parse(userValues.apiKey) : getAzureCredentials();
    apiKey = clientOptions.azure?.azureOpenAIApiKey;
  }

  if (userProvidesKey && !apiKey) {
    throw new Error(
      JSON.stringify({
        type: ErrorTypes.NO_USER_KEY,
      }),
    );
  }

  if (!apiKey) {
    throw new Error(`${endpoint} API Key not provided.`);
  }

  const modelOptions = {
    ...endpointOption.model_parameters,
    model: modelName,
    user: req.user.id,
  };

  const finalClientOptions: LLMConfigOptions = {
    ...clientOptions,
    modelOptions,
  };

  const options = getOpenAIConfig(apiKey, finalClientOptions, endpoint);

  const openAIConfig = req.app.locals[EModelEndpoint.openAI];
  const allConfig = req.app.locals.all;
  const azureRate = modelName?.includes('gpt-4') ? 30 : 17;

  let streamRate: number | undefined;

  if (isAzureOpenAI && azureConfig) {
    streamRate = azureConfig.streamRate ?? azureRate;
  } else if (!isAzureOpenAI && openAIConfig) {
    streamRate = openAIConfig.streamRate;
  }

  if (allConfig?.streamRate) {
    streamRate = allConfig.streamRate;
  }

  if (streamRate) {
    options.llmConfig.callbacks = [
      {
        handleLLMNewToken: createHandleLLMNewToken(streamRate),
      },
    ];
  }

  const result: OpenAIOptionsResult = {
    ...options,
    streamRate,
  };

  return result;
};
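
A hedged sketch of calling initializeOpenAI from a server layer. The exact shapes of InitializeOpenAIOptionsParams and UserKeyValues are defined in ~/types (not shown here), so the endpointOption parameter type and the callback bodies below are assumptions; in LibreChat proper, the two callbacks would come from the user-key service:

import type { Request as ServerRequest } from 'express';
import { EModelEndpoint } from 'librechat-data-provider';
import { initializeOpenAI } from '@librechat/api'; // assumed export path

async function getAgentOpenAIOptions(
  req: ServerRequest,
  endpointOption: { model_parameters: Record<string, unknown> }, // assumed shape
) {
  return initializeOpenAI({
    req,
    endpointOption,
    overrideEndpoint: EModelEndpoint.openAI, // example override
    getUserKeyValues: async ({ name }) => {
      // App-specific lookup of user-provided credentials for endpoint `name`.
      return { apiKey: process.env.OPENAI_API_KEY, baseURL: undefined };
    },
    checkUserKeyExpiry: (expiresAt: string, endpoint: string) => {
      if (new Date(expiresAt) < new Date()) {
        throw new Error(`User key for ${endpoint} has expired`);
      }
    },
  });
}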
156
packages/api/src/endpoints/openai/llm.ts
Normal file
156
packages/api/src/endpoints/openai/llm.ts
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
import { HttpsProxyAgent } from 'https-proxy-agent';
import { KnownEndpoints } from 'librechat-data-provider';
import type * as t from '~/types';
import { sanitizeModelName, constructAzureURL } from '~/utils/azure';
import { isEnabled } from '~/utils/common';

/**
 * Generates configuration options for creating a language model (LLM) instance.
 * @param apiKey - The API key for authentication.
 * @param options - Additional options for configuring the LLM.
 * @param endpoint - The endpoint name.
 * @returns Configuration options for creating an LLM instance.
 */
export function getOpenAIConfig(
  apiKey: string,
  options: t.LLMConfigOptions = {},
  endpoint?: string | null,
): t.LLMConfigResult {
  const {
    modelOptions = {},
    reverseProxyUrl,
    defaultQuery,
    headers,
    proxy,
    azure,
    streaming = true,
    addParams,
    dropParams,
  } = options;

  const llmConfig: Partial<t.ClientOptions> & Partial<t.OpenAIParameters> = Object.assign(
    {
      streaming,
      model: modelOptions.model ?? '',
    },
    modelOptions,
  );

  if (addParams && typeof addParams === 'object') {
    Object.assign(llmConfig, addParams);
  }

  // Note: OpenAI Web Search models do not support any known parameters besides `max_tokens`
  if (modelOptions.model && /gpt-4o.*search/.test(modelOptions.model)) {
    const searchExcludeParams = [
      'frequency_penalty',
      'presence_penalty',
      'temperature',
      'top_p',
      'top_k',
      'stop',
      'logit_bias',
      'seed',
      'response_format',
      'n',
      'logprobs',
      'user',
    ];

    const updatedDropParams = dropParams || [];
    const combinedDropParams = [...new Set([...updatedDropParams, ...searchExcludeParams])];

    combinedDropParams.forEach((param) => {
      if (param in llmConfig) {
        delete llmConfig[param as keyof t.ClientOptions];
      }
    });
  } else if (dropParams && Array.isArray(dropParams)) {
    dropParams.forEach((param) => {
      if (param in llmConfig) {
        delete llmConfig[param as keyof t.ClientOptions];
      }
    });
  }

  let useOpenRouter = false;
  const configOptions: t.OpenAIConfiguration = {};

  if (
    (reverseProxyUrl && reverseProxyUrl.includes(KnownEndpoints.openrouter)) ||
    (endpoint && endpoint.toLowerCase().includes(KnownEndpoints.openrouter))
  ) {
    useOpenRouter = true;
    llmConfig.include_reasoning = true;
    configOptions.baseURL = reverseProxyUrl;
    configOptions.defaultHeaders = Object.assign(
      {
        'HTTP-Referer': 'https://librechat.ai',
        'X-Title': 'LibreChat',
      },
      headers,
    );
  } else if (reverseProxyUrl) {
    configOptions.baseURL = reverseProxyUrl;
    if (headers) {
      configOptions.defaultHeaders = headers;
    }
  }

  if (defaultQuery) {
    configOptions.defaultQuery = defaultQuery;
  }

  if (proxy) {
    const proxyAgent = new HttpsProxyAgent(proxy);
    configOptions.httpAgent = proxyAgent;
  }

  if (azure) {
    const useModelName = isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME);
    const updatedAzure = { ...azure };
    updatedAzure.azureOpenAIApiDeploymentName = useModelName
      ? sanitizeModelName(llmConfig.model || '')
      : azure.azureOpenAIApiDeploymentName;

    if (process.env.AZURE_OPENAI_DEFAULT_MODEL) {
      llmConfig.model = process.env.AZURE_OPENAI_DEFAULT_MODEL;
    }

    if (configOptions.baseURL) {
      const azureURL = constructAzureURL({
        baseURL: configOptions.baseURL,
        azureOptions: updatedAzure,
      });
      updatedAzure.azureOpenAIBasePath = azureURL.split(
        `/${updatedAzure.azureOpenAIApiDeploymentName}`,
      )[0];
    }

    Object.assign(llmConfig, updatedAzure);
    llmConfig.model = updatedAzure.azureOpenAIApiDeploymentName;
  } else {
    llmConfig.apiKey = apiKey;
  }

  if (process.env.OPENAI_ORGANIZATION && azure) {
    configOptions.organization = process.env.OPENAI_ORGANIZATION;
  }

  if (useOpenRouter && llmConfig.reasoning_effort != null) {
    llmConfig.reasoning = {
      effort: llmConfig.reasoning_effort,
    };
    delete llmConfig.reasoning_effort;
  }

  if (llmConfig.max_tokens != null) {
    llmConfig.maxTokens = llmConfig.max_tokens;
    delete llmConfig.max_tokens;
  }

  return {
    llmConfig,
    configOptions,
  };
}
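For orientation, not part of the diff: a minimal sketch of calling getOpenAIConfig as defined above. The key, proxy URL, and dropped parameter are illustrative assumptions, not values from the PR:

const { llmConfig, configOptions } = getOpenAIConfig('sk-example', {
  modelOptions: { model: 'gpt-4o-mini', temperature: 0.7 },
  reverseProxyUrl: 'https://openrouter.ai/api/v1', // contains 'openrouter', so the OpenRouter branch applies
  dropParams: ['temperature'], // deleted from llmConfig before client creation
});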
14  packages/api/src/index.ts  (Normal file)
@@ -0,0 +1,14 @@
/* MCP */
export * from './mcp/manager';
/* Utilities */
export * from './mcp/utils';
export * from './utils';
/* Flow */
export * from './flow/manager';
/* Agents */
export * from './agents';
/* Endpoints */
export * from './endpoints';
/* types */
export type * from './mcp/types';
export type * from './flow/types';
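For orientation, not part of the diff: with these barrel files, consumers import from the root of the renamed `@librechat/api` package. A sketch — `Tokenizer`, `isEnabled`, and `sendEvent` are re-exported via `./utils` above, while `getOpenAIConfig` assumes the `./endpoints` barrel re-exports the new llm.ts module:

import { Tokenizer, isEnabled, sendEvent, getOpenAIConfig } from '@librechat/api';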
@@ -11,7 +11,7 @@ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/
 import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
 import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
 import type { Logger } from 'winston';
-import type * as t from './types/mcp.js';
+import type * as t from './types';

 function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
   return 'command' in options;
@@ -87,7 +87,7 @@ export class MCPConnection extends EventEmitter {
     this.lastPingTime = Date.now();
     this.client = new Client(
       {
-        name: 'librechat-mcp-client',
+        name: '@librechat/api-client',
         version: '1.2.2',
       },
       {
@@ -2,7 +2,7 @@ import { CallToolResultSchema, ErrorCode, McpError } from '@modelcontextprotocol
 import type { RequestOptions } from '@modelcontextprotocol/sdk/shared/protocol.js';
 import type { JsonSchemaType, MCPOptions } from 'librechat-data-provider';
 import type { Logger } from 'winston';
-import type * as t from './types/mcp';
+import type * as t from './types';
 import { formatToolContent } from './parsers';
 import { MCPConnection } from './connection';
 import { CONSTANTS } from './enum';
@@ -1,4 +1,4 @@
-import type * as t from './types/mcp';
+import type * as t from './types';
 const RECOGNIZED_PROVIDERS = new Set([
   'google',
   'anthropic',
@@ -8,7 +8,6 @@ import {
   StreamableHTTPOptionsSchema,
 } from 'librechat-data-provider';
 import type { JsonSchemaType, TPlugin } from 'librechat-data-provider';
-import { ToolSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js';
 import type * as t from '@modelcontextprotocol/sdk/types.js';

 export type StdioOptions = z.infer<typeof StdioOptionsSchema>;

@@ -45,8 +44,8 @@ export interface MCPPrompt {

 export type ConnectionState = 'disconnected' | 'connecting' | 'connected' | 'error';

-export type MCPTool = z.infer<typeof ToolSchema>;
-export type MCPToolListResponse = z.infer<typeof ListToolsResultSchema>;
+export type MCPTool = z.infer<typeof t.ToolSchema>;
+export type MCPToolListResponse = z.infer<typeof t.ListToolsResultSchema>;
 export type ToolContentPart = t.TextContent | t.ImageContent | t.EmbeddedResource | t.AudioContent;
 export type ImageContent = Extract<ToolContentPart, { type: 'image' }>;
 export type MCPToolCallResponse =
19  packages/api/src/types/azure.ts  (Normal file)
@@ -0,0 +1,19 @@
/**
 * Azure OpenAI configuration interface
 */
export interface AzureOptions {
  azureOpenAIApiKey?: string;
  azureOpenAIApiInstanceName?: string;
  azureOpenAIApiDeploymentName?: string;
  azureOpenAIApiVersion?: string;
  azureOpenAIBasePath?: string;
}

/**
 * Client with azure property for setting deployment name
 */
export interface GenericClient {
  azure: {
    azureOpenAIApiDeploymentName?: string;
  };
}
4  packages/api/src/types/events.ts  (Normal file)
@@ -0,0 +1,4 @@
export type ServerSentEvent = {
  data: string | Record<string, unknown>;
  event?: string;
};
4  packages/api/src/types/index.ts  (Normal file)
@@ -0,0 +1,4 @@
export * from './azure';
export * from './events';
export * from './openai';
export * from './run';
97  packages/api/src/types/openai.ts  (Normal file)
@@ -0,0 +1,97 @@
import { z } from 'zod';
import { openAISchema, EModelEndpoint } from 'librechat-data-provider';
import type { TEndpointOption, TAzureConfig, TEndpoint } from 'librechat-data-provider';
import type { OpenAIClientOptions } from '@librechat/agents';
import type { AzureOptions } from './azure';

export type OpenAIParameters = z.infer<typeof openAISchema>;

/**
 * Configuration options for the getLLMConfig function
 */
export interface LLMConfigOptions {
  modelOptions?: Partial<OpenAIParameters>;
  reverseProxyUrl?: string;
  defaultQuery?: Record<string, string | undefined>;
  headers?: Record<string, string>;
  proxy?: string;
  azure?: AzureOptions;
  streaming?: boolean;
  addParams?: Record<string, unknown>;
  dropParams?: string[];
}

export type OpenAIConfiguration = OpenAIClientOptions['configuration'];

export type ClientOptions = OpenAIClientOptions & {
  include_reasoning?: boolean;
};

/**
 * Return type for getLLMConfig function
 */
export interface LLMConfigResult {
  llmConfig: ClientOptions;
  configOptions: OpenAIConfiguration;
}

/**
 * Interface for user values retrieved from the database
 */
export interface UserKeyValues {
  apiKey?: string;
  baseURL?: string;
}

/**
 * Request interface with only the properties we need (avoids Express typing conflicts)
 */
export interface RequestData {
  user: {
    id: string;
  };
  body: {
    model?: string;
    endpoint?: string;
    key?: string;
  };
  app: {
    locals: {
      [EModelEndpoint.azureOpenAI]?: TAzureConfig;
      [EModelEndpoint.openAI]?: TEndpoint;
      all?: TEndpoint;
    };
  };
}

/**
 * Function type for getting user key values
 */
export type GetUserKeyValuesFunction = (params: {
  userId: string;
  name: string;
}) => Promise<UserKeyValues>;

/**
 * Function type for checking user key expiry
 */
export type CheckUserKeyExpiryFunction = (expiresAt: string, endpoint: string) => void;

/**
 * Parameters for the initializeOpenAI function
 */
export interface InitializeOpenAIOptionsParams {
  req: RequestData;
  overrideModel?: string;
  overrideEndpoint?: string;
  endpointOption: Partial<TEndpointOption>;
  getUserKeyValues: GetUserKeyValuesFunction;
  checkUserKeyExpiry: CheckUserKeyExpiryFunction;
}

/**
 * Extended LLM config result with stream rate handling
 */
export interface OpenAIOptionsResult extends LLMConfigResult {
  streamRate?: number;
}
10  packages/api/src/types/run.ts  (Normal file)
@@ -0,0 +1,10 @@
import type { AgentModelParameters, EModelEndpoint } from 'librechat-data-provider';
import type { OpenAIConfiguration } from './openai';

export type RunLLMConfig = {
  provider: EModelEndpoint;
  streaming: boolean;
  streamUsage: boolean;
  usage?: boolean;
  configuration?: OpenAIConfiguration;
} & AgentModelParameters;
269  packages/api/src/utils/azure.spec.ts  (Normal file)
@@ -0,0 +1,269 @@
import {
  genAzureChatCompletion,
  getAzureCredentials,
  constructAzureURL,
  sanitizeModelName,
  genAzureEndpoint,
} from './azure';
import type { GenericClient } from '~/types';

describe('sanitizeModelName', () => {
  test('removes periods from the model name', () => {
    const sanitized = sanitizeModelName('model.name');
    expect(sanitized).toBe('modelname');
  });

  test('leaves model name unchanged if no periods are present', () => {
    const sanitized = sanitizeModelName('modelname');
    expect(sanitized).toBe('modelname');
  });
});

describe('genAzureEndpoint', () => {
  test('generates correct endpoint URL', () => {
    const url = genAzureEndpoint({
      azureOpenAIApiInstanceName: 'instanceName',
      azureOpenAIApiDeploymentName: 'deploymentName',
    });
    expect(url).toBe('https://instanceName.openai.azure.com/openai/deployments/deploymentName');
  });
});

describe('genAzureChatCompletion', () => {
  // Test with both deployment name and model name provided
  test('prefers model name over deployment name when both are provided and feature enabled', () => {
    process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
    const url = genAzureChatCompletion(
      {
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiDeploymentName: 'deploymentName',
        azureOpenAIApiVersion: 'v1',
      },
      'modelName',
    );
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
    );
  });

  // Test with only deployment name provided
  test('uses deployment name when model name is not provided', () => {
    const url = genAzureChatCompletion({
      azureOpenAIApiInstanceName: 'instanceName',
      azureOpenAIApiDeploymentName: 'deploymentName',
      azureOpenAIApiVersion: 'v1',
    });
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
    );
  });

  // Test with only model name provided
  test('uses model name when deployment name is not provided and feature enabled', () => {
    process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
    const url = genAzureChatCompletion(
      {
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiVersion: 'v1',
      },
      'modelName',
    );
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/modelName/chat/completions?api-version=v1',
    );
  });

  // Test with neither deployment name nor model name provided
  test('throws error if neither deployment name nor model name is provided', () => {
    expect(() => {
      genAzureChatCompletion({
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiVersion: 'v1',
      });
    }).toThrow(
      'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
    );
  });

  // Test with feature disabled but model name provided
  test('ignores model name and uses deployment name when feature is disabled', () => {
    process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
    const url = genAzureChatCompletion(
      {
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiDeploymentName: 'deploymentName',
        azureOpenAIApiVersion: 'v1',
      },
      'modelName',
    );
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
    );
  });

  // Test with sanitized model name
  test('sanitizes model name when used in URL', () => {
    process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
    const url = genAzureChatCompletion(
      {
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiVersion: 'v1',
      },
      'model.name',
    );
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
    );
  });

  // Test with client parameter and model name
  test('updates client with sanitized model name when provided and feature enabled', () => {
    process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'true';
    const clientMock = { azure: {} } as GenericClient;
    const url = genAzureChatCompletion(
      {
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiVersion: 'v1',
      },
      'model.name',
      clientMock,
    );
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/modelname/chat/completions?api-version=v1',
    );
    expect(clientMock.azure.azureOpenAIApiDeploymentName).toBe('modelname');
  });

  // Test with client parameter but without model name
  test('does not update client when model name is not provided', () => {
    const clientMock = { azure: {} } as GenericClient;
    const url = genAzureChatCompletion(
      {
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiDeploymentName: 'deploymentName',
        azureOpenAIApiVersion: 'v1',
      },
      undefined,
      clientMock,
    );
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
    );
    expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
  });

  // Test with client parameter and deployment name when feature is disabled
  test('does not update client when feature is disabled', () => {
    process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME = 'false';
    const clientMock = { azure: {} } as GenericClient;
    const url = genAzureChatCompletion(
      {
        azureOpenAIApiInstanceName: 'instanceName',
        azureOpenAIApiDeploymentName: 'deploymentName',
        azureOpenAIApiVersion: 'v1',
      },
      'modelName',
      clientMock,
    );
    expect(url).toBe(
      'https://instanceName.openai.azure.com/openai/deployments/deploymentName/chat/completions?api-version=v1',
    );
    expect(clientMock.azure.azureOpenAIApiDeploymentName).toBeUndefined();
  });

  // Reset environment variable after tests
  afterEach(() => {
    delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
  });
});

describe('getAzureCredentials', () => {
  beforeEach(() => {
    process.env.AZURE_API_KEY = 'testApiKey';
    process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'instanceName';
    process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'deploymentName';
    process.env.AZURE_OPENAI_API_VERSION = 'v1';
  });

  test('retrieves Azure OpenAI API credentials from environment variables', () => {
    const credentials = getAzureCredentials();
    expect(credentials).toEqual({
      azureOpenAIApiKey: 'testApiKey',
      azureOpenAIApiInstanceName: 'instanceName',
      azureOpenAIApiDeploymentName: 'deploymentName',
      azureOpenAIApiVersion: 'v1',
    });
  });
});

describe('constructAzureURL', () => {
  test('replaces both placeholders when both properties are provided', () => {
    const url = constructAzureURL({
      baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
      azureOptions: {
        azureOpenAIApiInstanceName: 'instance1',
        azureOpenAIApiDeploymentName: 'deployment1',
      },
    });
    expect(url).toBe('https://example.com/instance1/deployment1');
  });

  test('replaces only INSTANCE_NAME when only azureOpenAIApiInstanceName is provided', () => {
    const url = constructAzureURL({
      baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
      azureOptions: {
        azureOpenAIApiInstanceName: 'instance2',
      },
    });
    expect(url).toBe('https://example.com/instance2/');
  });

  test('replaces only DEPLOYMENT_NAME when only azureOpenAIApiDeploymentName is provided', () => {
    const url = constructAzureURL({
      baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
      azureOptions: {
        azureOpenAIApiDeploymentName: 'deployment2',
      },
    });
    expect(url).toBe('https://example.com//deployment2');
  });

  test('does not replace any placeholders when azure object is empty', () => {
    const url = constructAzureURL({
      baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
      azureOptions: {},
    });
    expect(url).toBe('https://example.com//');
  });

  test('returns baseURL as is when `azureOptions` object is not provided', () => {
    const url = constructAzureURL({
      baseURL: 'https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}',
    });
    expect(url).toBe('https://example.com/${INSTANCE_NAME}/${DEPLOYMENT_NAME}');
  });

  test('returns baseURL as is when no placeholders are set', () => {
    const url = constructAzureURL({
      baseURL: 'https://example.com/my_custom_instance/my_deployment',
      azureOptions: {
        azureOpenAIApiInstanceName: 'instance1',
        azureOpenAIApiDeploymentName: 'deployment1',
      },
    });
    expect(url).toBe('https://example.com/my_custom_instance/my_deployment');
  });

  test('returns regular Azure OpenAI baseURL with placeholders set', () => {
    const baseURL =
      'https://${INSTANCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_NAME}';
    const url = constructAzureURL({
      baseURL,
      azureOptions: {
        azureOpenAIApiInstanceName: 'instance1',
        azureOpenAIApiDeploymentName: 'deployment1',
      },
    });
    expect(url).toBe('https://instance1.openai.azure.com/openai/deployments/deployment1');
  });
});
120  packages/api/src/utils/azure.ts  (Normal file)
@@ -0,0 +1,120 @@
import { isEnabled } from './common';
import type { AzureOptions, GenericClient } from '~/types';

/**
 * Sanitizes the model name to be used in the URL by removing or replacing disallowed characters.
 * @param modelName - The model name to be sanitized.
 * @returns The sanitized model name.
 */
export const sanitizeModelName = (modelName: string): string => {
  // Replace periods with empty strings and other disallowed characters as needed.
  return modelName.replace(/\./g, '');
};

/**
 * Generates the Azure OpenAI API endpoint URL.
 * @param params - The parameters object.
 * @param params.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
 * @param params.azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name.
 * @returns The complete endpoint URL for the Azure OpenAI API.
 */
export const genAzureEndpoint = ({
  azureOpenAIApiInstanceName,
  azureOpenAIApiDeploymentName,
}: {
  azureOpenAIApiInstanceName: string;
  azureOpenAIApiDeploymentName: string;
}): string => {
  return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${azureOpenAIApiDeploymentName}`;
};

/**
 * Generates the Azure OpenAI API chat completion endpoint URL with the API version.
 * If both deploymentName and modelName are provided, modelName takes precedence.
 * @param azureConfig - The Azure configuration object.
 * @param azureConfig.azureOpenAIApiInstanceName - The Azure OpenAI API instance name.
 * @param azureConfig.azureOpenAIApiDeploymentName - The Azure OpenAI API deployment name (optional).
 * @param azureConfig.azureOpenAIApiVersion - The Azure OpenAI API version.
 * @param modelName - The model name to be included in the deployment name (optional).
 * @param client - The API Client class for optionally setting properties (optional).
 * @returns The complete chat completion endpoint URL for the Azure OpenAI API.
 * @throws Error if neither azureOpenAIApiDeploymentName nor modelName is provided.
 */
export const genAzureChatCompletion = (
  {
    azureOpenAIApiInstanceName,
    azureOpenAIApiDeploymentName,
    azureOpenAIApiVersion,
  }: {
    azureOpenAIApiInstanceName: string;
    azureOpenAIApiDeploymentName?: string;
    azureOpenAIApiVersion: string;
  },
  modelName?: string,
  client?: GenericClient,
): string => {
  // Determine the deployment segment of the URL based on provided modelName or azureOpenAIApiDeploymentName
  let deploymentSegment: string;
  if (isEnabled(process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME) && modelName) {
    const sanitizedModelName = sanitizeModelName(modelName);
    deploymentSegment = sanitizedModelName;
    if (client && typeof client === 'object') {
      client.azure.azureOpenAIApiDeploymentName = sanitizedModelName;
    }
  } else if (azureOpenAIApiDeploymentName) {
    deploymentSegment = azureOpenAIApiDeploymentName;
  } else if (!process.env.AZURE_OPENAI_BASEURL) {
    throw new Error(
      'Either a model name with the `AZURE_USE_MODEL_AS_DEPLOYMENT_NAME` setting or a deployment name must be provided if `AZURE_OPENAI_BASEURL` is omitted.',
    );
  } else {
    deploymentSegment = '';
  }

  return `https://${azureOpenAIApiInstanceName}.openai.azure.com/openai/deployments/${deploymentSegment}/chat/completions?api-version=${azureOpenAIApiVersion}`;
};

/**
 * Retrieves the Azure OpenAI API credentials from environment variables.
 * @returns An object containing the Azure OpenAI API credentials.
 */
export const getAzureCredentials = (): AzureOptions => {
  return {
    azureOpenAIApiKey: process.env.AZURE_API_KEY ?? process.env.AZURE_OPENAI_API_KEY,
    azureOpenAIApiInstanceName: process.env.AZURE_OPENAI_API_INSTANCE_NAME,
    azureOpenAIApiDeploymentName: process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME,
    azureOpenAIApiVersion: process.env.AZURE_OPENAI_API_VERSION,
  };
};

/**
 * Constructs a URL by replacing placeholders in the baseURL with values from the azure object.
 * It specifically looks for '${INSTANCE_NAME}' and '${DEPLOYMENT_NAME}' within the baseURL and replaces
 * them with 'azureOpenAIApiInstanceName' and 'azureOpenAIApiDeploymentName' from the azure object.
 * If the respective azure property is not provided, the placeholder is replaced with an empty string.
 *
 * @param params - The parameters object.
 * @param params.baseURL - The baseURL to inspect for replacement placeholders.
 * @param params.azureOptions - The azure options object containing the instance and deployment names.
 * @returns The complete baseURL with credentials injected for the Azure OpenAI API.
 */
export function constructAzureURL({
  baseURL,
  azureOptions,
}: {
  baseURL: string;
  azureOptions?: AzureOptions;
}): string {
  let finalURL = baseURL;

  // Replace INSTANCE_NAME and DEPLOYMENT_NAME placeholders with actual values if available
  if (azureOptions) {
    finalURL = finalURL.replace('${INSTANCE_NAME}', azureOptions.azureOpenAIApiInstanceName ?? '');
    finalURL = finalURL.replace(
      '${DEPLOYMENT_NAME}',
      azureOptions.azureOpenAIApiDeploymentName ?? '',
    );
  }

  return finalURL;
}
55  packages/api/src/utils/common.spec.ts  (Normal file)
@@ -0,0 +1,55 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { isEnabled } from './common';

describe('isEnabled', () => {
  test('should return true when input is "true"', () => {
    expect(isEnabled('true')).toBe(true);
  });

  test('should return true when input is "TRUE"', () => {
    expect(isEnabled('TRUE')).toBe(true);
  });

  test('should return true when input is true', () => {
    expect(isEnabled(true)).toBe(true);
  });

  test('should return false when input is "false"', () => {
    expect(isEnabled('false')).toBe(false);
  });

  test('should return false when input is false', () => {
    expect(isEnabled(false)).toBe(false);
  });

  test('should return false when input is null', () => {
    expect(isEnabled(null)).toBe(false);
  });

  test('should return false when input is undefined', () => {
    expect(isEnabled()).toBe(false);
  });

  test('should return false when input is an empty string', () => {
    expect(isEnabled('')).toBe(false);
  });

  test('should return false when input is a whitespace string', () => {
    expect(isEnabled(' ')).toBe(false);
  });

  test('should return false when input is a number', () => {
    // @ts-expect-error
    expect(isEnabled(123)).toBe(false);
  });

  test('should return false when input is an object', () => {
    // @ts-expect-error
    expect(isEnabled({})).toBe(false);
  });

  test('should return false when input is an array', () => {
    // @ts-expect-error
    expect(isEnabled([])).toBe(false);
  });
});
48  packages/api/src/utils/common.ts  (Normal file)
@@ -0,0 +1,48 @@
/**
 * Checks if the given value is truthy by being either the boolean `true` or a string
 * that case-insensitively matches 'true'.
 *
 * @param value - The value to check.
 * @returns Returns `true` if the value is the boolean `true` or a case-insensitive
 * match for the string 'true', otherwise returns `false`.
 * @example
 *
 * isEnabled("True");  // returns true
 * isEnabled("TRUE");  // returns true
 * isEnabled(true);    // returns true
 * isEnabled("false"); // returns false
 * isEnabled(false);   // returns false
 * isEnabled(null);    // returns false
 * isEnabled();        // returns false
 */
export function isEnabled(value?: string | boolean | null | undefined): boolean {
  if (typeof value === 'boolean') {
    return value;
  }
  if (typeof value === 'string') {
    return value.toLowerCase().trim() === 'true';
  }
  return false;
}

/**
 * Checks if the provided value is 'user_provided'.
 *
 * @param value - The value to check.
 * @returns - Returns true if the value is 'user_provided', otherwise false.
 */
export const isUserProvided = (value?: string): boolean => value === 'user_provided';

/**
 * Returns the first value that is neither `undefined`, `null`, nor an empty string;
 * if no candidate qualifies, returns the last value in the list.
 *
 * @param values - The candidate values, checked in order.
 */
export function optionalChainWithEmptyCheck(
  ...values: (string | number | undefined)[]
): string | number | undefined {
  for (const value of values) {
    if (value !== undefined && value !== null && value !== '') {
      return value;
    }
  }
  return values[values.length - 1];
}
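A quick illustration, not part of the diff, of how optionalChainWithEmptyCheck skips only `undefined`, `null`, and the empty string:

optionalChainWithEmptyCheck(undefined, '', 'fallback'); // 'fallback'
optionalChainWithEmptyCheck('', 0, 42);                 // 0 (numbers, even 0, qualify)
optionalChainWithEmptyCheck(undefined, '');             // '' (nothing qualifies, so the last value is returned)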
16  packages/api/src/utils/events.ts  (Normal file)
@@ -0,0 +1,16 @@
import type { Response as ServerResponse } from 'express';
import type { ServerSentEvent } from '~/types';

/**
 * Sends message data in Server Sent Events format.
 * @param res - The server response.
 * @param event - The message event.
 * @param event.event - The type of event.
 * @param event.data - The message to be sent.
 */
export function sendEvent(res: ServerResponse, event: ServerSentEvent): void {
  if (typeof event.data === 'string' && event.data.length === 0) {
    return;
  }
  res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`);
}
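For clarity, not part of the diff: sendEvent serializes the entire event object as the SSE payload, always under the `message` event name. With an illustrative delta event:

sendEvent(res, { event: 'on_message_delta', data: { delta: 'Hi' } }); // 'on_message_delta' is a hypothetical event name
// writes: event: message\ndata: {"event":"on_message_delta","data":{"delta":"Hi"}}\n\n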
115  packages/api/src/utils/files.spec.ts  (Normal file)
@@ -0,0 +1,115 @@
import { sanitizeFilename } from './files';

jest.mock('node:crypto', () => {
  const actualModule = jest.requireActual('node:crypto');
  return {
    ...actualModule,
    randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
  };
});

describe('sanitizeFilename', () => {
  test('removes directory components (1/2)', () => {
    expect(sanitizeFilename('/path/to/file.txt')).toBe('file.txt');
  });

  test('removes directory components (2/2)', () => {
    expect(sanitizeFilename('../../../../file.txt')).toBe('file.txt');
  });

  test('replaces non-alphanumeric characters', () => {
    expect(sanitizeFilename('file name@#$.txt')).toBe('file_name___.txt');
  });

  test('preserves dots and hyphens', () => {
    expect(sanitizeFilename('file-name.with.dots.txt')).toBe('file-name.with.dots.txt');
  });

  test('prepends underscore to filenames starting with a dot', () => {
    expect(sanitizeFilename('.hiddenfile')).toBe('_.hiddenfile');
  });

  test('truncates long filenames', () => {
    const longName = 'a'.repeat(300) + '.txt';
    const result = sanitizeFilename(longName);
    expect(result.length).toBe(255);
    expect(result).toMatch(/^a+-abc123\.txt$/);
  });

  test('handles filenames with no extension', () => {
    const longName = 'a'.repeat(300);
    const result = sanitizeFilename(longName);
    expect(result.length).toBe(255);
    expect(result).toMatch(/^a+-abc123$/);
  });

  test('handles empty input', () => {
    expect(sanitizeFilename('')).toBe('_');
  });

  test('handles input with only special characters', () => {
    expect(sanitizeFilename('@#$%^&*')).toBe('_______');
  });
});

describe('sanitizeFilename with real crypto', () => {
  // Temporarily unmock crypto for these tests
  beforeAll(() => {
    jest.resetModules();
    jest.unmock('node:crypto');
  });

  afterAll(() => {
    jest.resetModules();
    jest.mock('node:crypto', () => {
      const actualModule = jest.requireActual('node:crypto');
      return {
        ...actualModule,
        randomBytes: jest.fn().mockReturnValue(Buffer.from('abc123', 'hex')),
      };
    });
  });

  test('truncates long filenames with real crypto', async () => {
    const { sanitizeFilename: realSanitizeFilename } = await import('./files');
    const longName = 'b'.repeat(300) + '.pdf';
    const result = realSanitizeFilename(longName);

    expect(result.length).toBe(255);
    expect(result).toMatch(/^b+-[a-f0-9]{6}\.pdf$/);
    expect(result.endsWith('.pdf')).toBe(true);
  });

  test('handles filenames with no extension with real crypto', async () => {
    const { sanitizeFilename: realSanitizeFilename } = await import('./files');
    const longName = 'c'.repeat(300);
    const result = realSanitizeFilename(longName);

    expect(result.length).toBe(255);
    expect(result).toMatch(/^c+-[a-f0-9]{6}$/);
    expect(result).not.toContain('.');
  });

  test('generates unique suffixes for identical long filenames', async () => {
    const { sanitizeFilename: realSanitizeFilename } = await import('./files');
    const longName = 'd'.repeat(300) + '.doc';
    const result1 = realSanitizeFilename(longName);
    const result2 = realSanitizeFilename(longName);

    expect(result1.length).toBe(255);
    expect(result2.length).toBe(255);
    expect(result1).not.toBe(result2); // Should be different due to random suffix
    expect(result1.endsWith('.doc')).toBe(true);
    expect(result2.endsWith('.doc')).toBe(true);
  });

  test('real crypto produces valid hex strings', async () => {
    const { sanitizeFilename: realSanitizeFilename } = await import('./files');
    const longName = 'test'.repeat(100) + '.txt';
    const result = realSanitizeFilename(longName);

    const hexMatch = result.match(/-([a-f0-9]{6})\.txt$/);
    expect(hexMatch).toBeTruthy();
    expect(hexMatch![1]).toMatch(/^[a-f0-9]{6}$/);
  });
});
33  packages/api/src/utils/files.ts  (Normal file)
@@ -0,0 +1,33 @@
import path from 'path';
import crypto from 'node:crypto';

/**
 * Sanitizes a filename by removing directory components, replacing non-alphanumeric
 * characters (other than '.' and '-') with underscores, and truncating overlong names.
 * @param inputName - The filename to sanitize.
 */
export function sanitizeFilename(inputName: string): string {
  // Remove any directory components
  let name = path.basename(inputName);

  // Replace any non-alphanumeric characters except for '.' and '-'
  name = name.replace(/[^a-zA-Z0-9.-]/g, '_');

  // Ensure the name doesn't start with a dot (hidden file in Unix-like systems)
  if (name.startsWith('.') || name === '') {
    name = '_' + name;
  }

  // Limit the length of the filename
  const MAX_LENGTH = 255;
  if (name.length > MAX_LENGTH) {
    const ext = path.extname(name);
    const nameWithoutExt = path.basename(name, ext);
    name =
      nameWithoutExt.slice(0, MAX_LENGTH - ext.length - 7) +
      '-' +
      crypto.randomBytes(3).toString('hex') +
      ext;
  }

  return name;
}
75  packages/api/src/utils/generators.ts  (Normal file)
@@ -0,0 +1,75 @@
import fetch from 'node-fetch';
import { logger } from '@librechat/data-schemas';
import { GraphEvents, sleep } from '@librechat/agents';
import type { Response as ServerResponse } from 'express';
import type { ServerSentEvent } from '~/types';
import { sendEvent } from './events';

/**
 * Creates a fetch function that makes HTTP requests and logs the process.
 * @param params
 * @param params.directEndpoint - Whether to use a direct endpoint.
 * @param params.reverseProxyUrl - The reverse proxy URL to use for the request.
 * @returns A fetch function that performs the request and logs it.
 */
export function createFetch({
  directEndpoint = false,
  reverseProxyUrl = '',
}: {
  directEndpoint?: boolean;
  reverseProxyUrl?: string;
}) {
  /**
   * Makes an HTTP request and logs the process.
   * @param url - The URL to make the request to. Can be a string or a Request object.
   * @param init - Optional init options for the request.
   * @returns A promise that resolves to the response of the fetch request.
   */
  return async function (
    _url: fetch.RequestInfo,
    init: fetch.RequestInit,
  ): Promise<fetch.Response> {
    let url = _url;
    if (directEndpoint) {
      url = reverseProxyUrl;
    }
    logger.debug(`Making request to ${url}`);
    if (typeof Bun !== 'undefined') {
      return await fetch(url, init);
    }
    return await fetch(url, init);
  };
}

/**
 * Creates event handlers for stream events that don't capture client references
 * @param res - The response object to send events to
 * @returns Object containing handler functions
 */
export function createStreamEventHandlers(res: ServerResponse) {
  return {
    [GraphEvents.ON_RUN_STEP]: function (event: ServerSentEvent) {
      if (res) {
        sendEvent(res, event);
      }
    },
    [GraphEvents.ON_MESSAGE_DELTA]: function (event: ServerSentEvent) {
      if (res) {
        sendEvent(res, event);
      }
    },
    [GraphEvents.ON_REASONING_DELTA]: function (event: ServerSentEvent) {
      if (res) {
        sendEvent(res, event);
      }
    },
  };
}

export function createHandleLLMNewToken(streamRate: number) {
  return async function () {
    if (streamRate) {
      await sleep(streamRate);
    }
  };
}
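For orientation, not part of the diff: a sketch of how these helpers compose — the handlers forward run-step, message-delta, and reasoning-delta events as SSE, while the token callback throttles streaming, as in the initializer earlier in this diff. The 15 ms rate is an illustrative value:

const handlers = createStreamEventHandlers(res); // keyed by GraphEvents.* constants
const callbacks = [{ handleLLMNewToken: createHandleLLMNewToken(15) }]; // ~15 ms pause per token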
5  packages/api/src/utils/index.ts  (Normal file)
@@ -0,0 +1,5 @@
export * from './azure';
export * from './common';
export * from './events';
export * from './generators';
export { default as Tokenizer } from './tokenizer';
143  packages/api/src/utils/tokenizer.spec.ts  (Normal file)
@@ -0,0 +1,143 @@
/**
 * @file tokenizer.spec.ts
 *
 * Tests the real TokenizerSingleton (no mocking of `tiktoken`).
 * Make sure to install `tiktoken` and have it configured properly.
 */

import { logger } from '@librechat/data-schemas';
import type { Tiktoken } from 'tiktoken';
import Tokenizer from './tokenizer';

jest.mock('@librechat/data-schemas', () => ({
  logger: {
    error: jest.fn(),
  },
}));

describe('Tokenizer', () => {
  it('should be a singleton (same instance)', async () => {
    const AnotherTokenizer = await import('./tokenizer'); // same path
    expect(Tokenizer).toBe(AnotherTokenizer.default);
  });

  describe('getTokenizer', () => {
    it('should create an encoder for an explicit model name (e.g., "gpt-4")', () => {
      // The real `encoding_for_model` will be called internally
      // as soon as we pass isModelName = true.
      const tokenizer = Tokenizer.getTokenizer('gpt-4', true);

      // Basic sanity checks
      expect(tokenizer).toBeDefined();
      // You can optionally check certain properties from `tiktoken` if they exist
      // e.g., expect(typeof tokenizer.encode).toBe('function');
    });

    it('should create an encoder for a known encoding (e.g., "cl100k_base")', () => {
      // The real `get_encoding` will be called internally
      // as soon as we pass isModelName = false.
      const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);

      expect(tokenizer).toBeDefined();
      // e.g., expect(typeof tokenizer.encode).toBe('function');
    });

    it('should return cached tokenizer if previously fetched', () => {
      const tokenizer1 = Tokenizer.getTokenizer('cl100k_base', false);
      const tokenizer2 = Tokenizer.getTokenizer('cl100k_base', false);
      // Should be the exact same instance from the cache
      expect(tokenizer1).toBe(tokenizer2);
    });
  });

  describe('freeAndResetAllEncoders', () => {
    beforeEach(() => {
      jest.clearAllMocks();
    });

    it('should free all encoders and reset tokenizerCallsCount to 1', () => {
      // By creating two different encodings, we populate the cache
      Tokenizer.getTokenizer('cl100k_base', false);
      Tokenizer.getTokenizer('r50k_base', false);

      // Now free them
      Tokenizer.freeAndResetAllEncoders();

      // The internal cache is cleared
      expect(Tokenizer.tokenizersCache['cl100k_base']).toBeUndefined();
      expect(Tokenizer.tokenizersCache['r50k_base']).toBeUndefined();

      // tokenizerCallsCount is reset to 1
      expect(Tokenizer.tokenizerCallsCount).toBe(1);
    });

    it('should catch and log errors if freeing fails', () => {
      // Mock logger.error before the test
      const mockLoggerError = jest.spyOn(logger, 'error');

      // Set up a problematic tokenizer in the cache
      Tokenizer.tokenizersCache['cl100k_base'] = {
        free() {
          throw new Error('Intentional free error');
        },
      } as unknown as Tiktoken;

      // Should not throw uncaught errors
      Tokenizer.freeAndResetAllEncoders();

      // Verify logger.error was called with correct arguments
      expect(mockLoggerError).toHaveBeenCalledWith(
        '[Tokenizer] Free and reset encoders error',
        expect.any(Error),
      );

      // Clean up
      mockLoggerError.mockRestore();
      Tokenizer.tokenizersCache = {};
    });
  });

  describe('getTokenCount', () => {
    beforeEach(() => {
      jest.clearAllMocks();
      Tokenizer.freeAndResetAllEncoders();
    });

    it('should return the number of tokens in the given text', () => {
      const text = 'Hello, world!';
      const count = Tokenizer.getTokenCount(text, 'cl100k_base');
      expect(count).toBeGreaterThan(0);
    });

    it('should reset encoders if an error is thrown', () => {
      // We can simulate an error by temporarily overriding the selected tokenizer's `encode` method.
      const tokenizer = Tokenizer.getTokenizer('cl100k_base', false);
      const originalEncode = tokenizer.encode;
      tokenizer.encode = () => {
        throw new Error('Forced error');
      };

      // Despite the forced error, the code should catch and reset, then re-encode
      const count = Tokenizer.getTokenCount('Hello again', 'cl100k_base');
      expect(count).toBeGreaterThan(0);

      // Restore the original encode
      tokenizer.encode = originalEncode;
    });

    it('should reset tokenizers after 25 calls', () => {
      // Spy on freeAndResetAllEncoders
      const resetSpy = jest.spyOn(Tokenizer, 'freeAndResetAllEncoders');

      // Make 24 calls; should NOT reset yet
      for (let i = 0; i < 24; i++) {
        Tokenizer.getTokenCount('test text', 'cl100k_base');
      }
      expect(resetSpy).not.toHaveBeenCalled();

      // 25th call triggers the reset
      Tokenizer.getTokenCount('the 25th call!', 'cl100k_base');
      expect(resetSpy).toHaveBeenCalledTimes(1);
    });
  });
});
78  packages/api/src/utils/tokenizer.ts  (Normal file)
@@ -0,0 +1,78 @@
import { logger } from '@librechat/data-schemas';
import { encoding_for_model as encodingForModel, get_encoding as getEncoding } from 'tiktoken';
import type { Tiktoken, TiktokenModel, TiktokenEncoding } from 'tiktoken';

interface TokenizerOptions {
  debug?: boolean;
}

class Tokenizer {
  tokenizersCache: Record<string, Tiktoken>;
  tokenizerCallsCount: number;
  private options?: TokenizerOptions;

  constructor() {
    this.tokenizersCache = {};
    this.tokenizerCallsCount = 0;
  }

  getTokenizer(
    encoding: TiktokenModel | TiktokenEncoding,
    isModelName = false,
    extendSpecialTokens: Record<string, number> = {},
  ): Tiktoken {
    let tokenizer: Tiktoken;
    if (this.tokenizersCache[encoding]) {
      tokenizer = this.tokenizersCache[encoding];
    } else {
      if (isModelName) {
        tokenizer = encodingForModel(encoding as TiktokenModel, extendSpecialTokens);
      } else {
        tokenizer = getEncoding(encoding as TiktokenEncoding, extendSpecialTokens);
      }
      this.tokenizersCache[encoding] = tokenizer;
    }
    return tokenizer;
  }

  freeAndResetAllEncoders(): void {
    try {
      Object.keys(this.tokenizersCache).forEach((key) => {
        if (this.tokenizersCache[key]) {
          this.tokenizersCache[key].free();
          delete this.tokenizersCache[key];
        }
      });
      this.tokenizerCallsCount = 1;
    } catch (error) {
      logger.error('[Tokenizer] Free and reset encoders error', error);
    }
  }

  resetTokenizersIfNecessary(): void {
    if (this.tokenizerCallsCount >= 25) {
      if (this.options?.debug) {
        logger.debug('[Tokenizer] freeAndResetAllEncoders: reached 25 encodings, resetting...');
      }
      this.freeAndResetAllEncoders();
    }
    this.tokenizerCallsCount++;
  }

  getTokenCount(text: string, encoding: TiktokenModel | TiktokenEncoding = 'cl100k_base'): number {
    this.resetTokenizersIfNecessary();
    try {
      const tokenizer = this.getTokenizer(encoding);
      return tokenizer.encode(text, 'all').length;
    } catch (error) {
      logger.error('[Tokenizer] Error getting token count:', error);
      this.freeAndResetAllEncoders();
      const tokenizer = this.getTokenizer(encoding);
      return tokenizer.encode(text, 'all').length;
    }
  }
}

const TokenizerSingleton = new Tokenizer();

export default TokenizerSingleton;
@@ -18,7 +18,10 @@
     "isolatedModules": true,
     "noEmit": true,
     "sourceMap": true,
-    "baseUrl": "."
+    "baseUrl": ".",
+    "paths": {
+      "~/*": ["./src/*"]
+    }
   },
   "ts-node": {
     "experimentalSpecifierResolution": "node",
@@ -1,6 +1,6 @@
 {
   "name": "librechat-data-provider",
-  "version": "0.7.86",
+  "version": "0.7.87",
   "description": "data services for librechat apps",
   "main": "dist/index.js",
   "module": "dist/index.es.js",
@@ -254,6 +254,7 @@ export const getAllPromptGroups = () => `${prompts()}/all`;
 export const roles = () => '/api/roles';
 export const getRole = (roleName: string) => `${roles()}/${roleName.toLowerCase()}`;
 export const updatePromptPermissions = (roleName: string) => `${getRole(roleName)}/prompts`;
+export const updateMemoryPermissions = (roleName: string) => `${getRole(roleName)}/memories`;
 export const updateAgentPermissions = (roleName: string) => `${getRole(roleName)}/agents`;

 /* Conversation Tags */
@@ -283,3 +284,8 @@ export const confirmTwoFactor = () => '/api/auth/2fa/confirm';
 export const disableTwoFactor = () => '/api/auth/2fa/disable';
 export const regenerateBackupCodes = () => '/api/auth/2fa/backup/regenerate';
 export const verifyTwoFactorTemp = () => '/api/auth/2fa/verify-temp';
+
+/* Memories */
+export const memories = () => '/api/memories';
+export const memory = (key: string) => `${memories()}/${encodeURIComponent(key)}`;
+export const memoryPreferences = () => `${memories()}/preferences`;
@@ -493,6 +493,7 @@ export const intefaceSchema = z
     sidePanel: z.boolean().optional(),
     multiConvo: z.boolean().optional(),
     bookmarks: z.boolean().optional(),
+    memories: z.boolean().optional(),
     presets: z.boolean().optional(),
     prompts: z.boolean().optional(),
     agents: z.boolean().optional(),

@@ -508,6 +509,7 @@ export const intefaceSchema = z
     presets: true,
     multiConvo: true,
     bookmarks: true,
+    memories: true,
     prompts: true,
     agents: true,
     temporaryChat: true,
@@ -649,11 +651,35 @@ export const balanceSchema = z.object({
   refillAmount: z.number().optional().default(10000),
 });

+export const memorySchema = z.object({
+  disabled: z.boolean().optional(),
+  validKeys: z.array(z.string()).optional(),
+  tokenLimit: z.number().optional(),
+  personalize: z.boolean().default(true),
+  messageWindowSize: z.number().optional().default(5),
+  agent: z
+    .union([
+      z.object({
+        id: z.string(),
+      }),
+      z.object({
+        provider: z.string(),
+        model: z.string(),
+        instructions: z.string().optional(),
+        model_parameters: z.record(z.any()).optional(),
+      }),
+    ])
+    .optional(),
+});
+
+export type TMemoryConfig = z.infer<typeof memorySchema>;
+
 export const configSchema = z.object({
   version: z.string(),
   cache: z.boolean().default(true),
   ocr: ocrSchema.optional(),
   webSearch: webSearchSchema.optional(),
+  memory: memorySchema.optional(),
   secureImageLinks: z.boolean().optional(),
   imageOutputType: z.nativeEnum(EImageOutputType).default(EImageOutputType.PNG),
   includedTools: z.array(z.string()).optional(),
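For orientation, not part of the diff: a minimal object that should satisfy memorySchema above, using the inline provider/model variant of the agent union. All values are illustrative assumptions:

const memoryConfig = memorySchema.parse({
  validKeys: ['preferences', 'facts'], // hypothetical keys
  tokenLimit: 2000,
  agent: {
    provider: 'openai', // inline agent variant of the union
    model: 'gpt-4o-mini',
    instructions: 'Store concise, durable facts about the user.',
  },
});
// memoryConfig.personalize === true and memoryConfig.messageWindowSize === 5 via schema defaults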
@@ -1291,6 +1317,10 @@ export enum SettingsTabValues {
    * Chat input commands
    */
   COMMANDS = 'commands',
+  /**
+   * Tab for Personalization Settings
+   */
+  PERSONALIZATION = 'personalization',
 }

 export enum STTProviders {
@@ -13,11 +13,11 @@ export default function createPayload(submission: t.TSubmission) {
     ephemeralAgent,
   } = submission;
   const { conversationId } = s.tConvoUpdateSchema.parse(conversation);
-  const { endpoint, endpointType } = endpointOption as {
+  const { endpoint: _e, endpointType } = endpointOption as {
     endpoint: s.EModelEndpoint;
     endpointType?: s.EModelEndpoint;
   };

+  const endpoint = _e as s.EModelEndpoint;
   let server = EndpointURLs[endpointType ?? endpoint];
   const isEphemeral = s.isEphemeralAgent(endpoint, ephemeralAgent);

@@ -32,6 +32,7 @@ export default function createPayload(submission: t.TSubmission) {
   const payload: t.TPayload = {
     ...userMessage,
     ...endpointOption,
+    endpoint,
     ephemeralAgent: isEphemeral ? ephemeralAgent : undefined,
     isContinued: !!(isEdited && isContinued),
     conversationId,
@@ -718,6 +718,12 @@ export function updateAgentPermissions(
   return request.put(endpoints.updateAgentPermissions(variables.roleName), variables.updates);
 }

+export function updateMemoryPermissions(
+  variables: m.UpdateMemoryPermVars,
+): Promise<m.UpdatePermResponse> {
+  return request.put(endpoints.updateMemoryPermissions(variables.roleName), variables.updates);
+}
+
 /* Tags */
 export function getConversationTags(): Promise<t.TConversationTagsResponse> {
   return request.get(endpoints.conversationTags());
@@ -799,3 +805,33 @@ export function verifyTwoFactorTemp(
 ): Promise<t.TVerify2FATempResponse> {
   return request.post(endpoints.verifyTwoFactorTemp(), payload);
 }

+/* Memories */
+export const getMemories = (): Promise<q.MemoriesResponse> => {
+  return request.get(endpoints.memories());
+};
+
+export const deleteMemory = (key: string): Promise<void> => {
+  return request.delete(endpoints.memory(key));
+};
+
+export const updateMemory = (
+  key: string,
+  value: string,
+  originalKey?: string,
+): Promise<q.TUserMemory> => {
+  return request.patch(endpoints.memory(originalKey || key), { key, value });
+};
+
+export const updateMemoryPreferences = (preferences: {
+  memories: boolean;
+}): Promise<{ updated: boolean; preferences: { memories: boolean } }> => {
+  return request.patch(endpoints.memoryPreferences(), preferences);
+};
+
+export const createMemory = (data: {
+  key: string;
+  value: string;
+}): Promise<{ created: boolean; memory: q.TUserMemory }> => {
+  return request.post(endpoints.memories(), data);
+};
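Usage sketch, not part of the diff: updateMemory targets the original key's URL when a memory is renamed, while the request body carries the new key. Keys and values below are illustrative:

await createMemory({ key: 'timezone', value: 'UTC+2' });
await updateMemory('home_timezone', 'UTC+2', 'timezone'); // PATCH /api/memories/timezone, renaming the key
await deleteMemory('home_timezone');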
@@ -16,6 +16,8 @@ export * from './models';
 export * from './mcp';
 /* web search */
 export * from './web';
+/* memory */
+export * from './memory';
 /* RBAC */
 export * from './permissions';
 export * from './roles';
@@ -46,6 +46,8 @@ export enum QueryKeys {
   health = 'health',
   userTerms = 'userTerms',
   banner = 'banner',
+  /* Memories */
+  memories = 'memories',
 }

 export enum MutationKeys {

@@ -70,4 +72,5 @@ export enum MutationKeys {
   updateRole = 'updateRole',
   enableTwoFactor = 'enableTwoFactor',
   verifyTwoFactor = 'verifyTwoFactor',
+  updateMemoryPreferences = 'updateMemoryPreferences',
 }
62  packages/data-provider/src/memory.ts  (Normal file)
@@ -0,0 +1,62 @@
import type { TCustomConfig, TMemoryConfig } from './config';

/**
 * Loads the memory configuration and validates it
 * @param config - The memory configuration from librechat.yaml
 * @returns The validated memory configuration
 */
export function loadMemoryConfig(config: TCustomConfig['memory']): TMemoryConfig | undefined {
  if (!config) {
    return undefined;
  }

  // If disabled is explicitly true, return the config as-is
  if (config.disabled === true) {
    return config;
  }

  // Check if the agent configuration is valid
  const hasValidAgent =
    config.agent &&
    (('id' in config.agent && !!config.agent.id) ||
      ('provider' in config.agent &&
        'model' in config.agent &&
        !!config.agent.provider &&
        !!config.agent.model));

  // If agent config is invalid, treat as disabled
  if (!hasValidAgent) {
    return {
      ...config,
      disabled: true,
    };
  }

  return config;
}

/**
 * Checks if memory feature is enabled based on the configuration
 * @param config - The memory configuration
 * @returns True if memory is enabled, false otherwise
 */
export function isMemoryEnabled(config: TMemoryConfig | undefined): boolean {
  if (!config) {
    return false;
  }

  if (config.disabled === true) {
    return false;
  }

  // Check if agent configuration is valid
  const hasValidAgent =
    config.agent &&
    (('id' in config.agent && !!config.agent.id) ||
      ('provider' in config.agent &&
        'model' in config.agent &&
        !!config.agent.provider &&
        !!config.agent.model));

  return !!hasValidAgent;
}
|
||||
|
|
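From the checks above, either an agent `id` or a `provider` + `model` pair makes the config valid; anything else is coerced to `disabled: true`. A quick sketch with hypothetical values (shapes inferred from the checks, not from the full `TMemoryConfig` type, which may carry more fields):

```ts
// Valid: reference an existing agent by id
const byId = loadMemoryConfig({ agent: { id: 'agent_memory' } });

// Valid: inline provider + model pair
const inline = loadMemoryConfig({ agent: { provider: 'openai', model: 'gpt-4o-mini' } });

// Invalid agent (provider without model): returned with disabled: true
const coerced = loadMemoryConfig({ agent: { provider: 'openai' } });
```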
@@ -225,13 +225,15 @@ const extractOmniVersion = (modelStr: string): string => {
 export const getResponseSender = (endpointOption: t.TEndpointOption): string => {
   const {
     model: _m,
-    endpoint,
+    endpoint: _e,
     endpointType,
     modelDisplayLabel: _mdl,
     chatGptLabel: _cgl,
     modelLabel: _ml,
   } = endpointOption;

+  const endpoint = _e as EModelEndpoint;
+
   const model = _m ?? '';
   const modelDisplayLabel = _mdl ?? '';
   const chatGptLabel = _cgl ?? '';
@@ -16,6 +16,10 @@ export enum PermissionTypes {
   * Type for Agent Permissions
   */
  AGENTS = 'AGENTS',
  /**
   * Type for Memory Permissions
   */
  MEMORIES = 'MEMORIES',
  /**
   * Type for Multi-Conversation Permissions
   */

@@ -45,6 +49,8 @@ export enum Permissions {
  READ = 'READ',
  READ_AUTHOR = 'READ_AUTHOR',
  SHARE = 'SHARE',
  /** Can disable if desired */
  OPT_OUT = 'OPT_OUT',
}

export const promptPermissionsSchema = z.object({
@@ -60,6 +66,15 @@ export const bookmarkPermissionsSchema = z.object({
});
export type TBookmarkPermissions = z.infer<typeof bookmarkPermissionsSchema>;

export const memoryPermissionsSchema = z.object({
  [Permissions.USE]: z.boolean().default(true),
  [Permissions.CREATE]: z.boolean().default(true),
  [Permissions.UPDATE]: z.boolean().default(true),
  [Permissions.READ]: z.boolean().default(true),
  [Permissions.OPT_OUT]: z.boolean().default(true),
});
export type TMemoryPermissions = z.infer<typeof memoryPermissionsSchema>;

export const agentPermissionsSchema = z.object({
  [Permissions.SHARED_GLOBAL]: z.boolean().default(false),
  [Permissions.USE]: z.boolean().default(true),
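Since every field defaults to `true`, parsing an empty object yields a fully-permissive memory policy; admins only need to send the flags they restrict. A quick sketch of the zod behavior:

```ts
const full = memoryPermissionsSchema.parse({});
// -> { USE: true, CREATE: true, UPDATE: true, READ: true, OPT_OUT: true }

const readOnly = memoryPermissionsSchema.parse({ CREATE: false, UPDATE: false });
// -> { USE: true, CREATE: false, UPDATE: false, READ: true, OPT_OUT: true }
```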
@@ -92,6 +107,7 @@ export type TWebSearchPermissions = z.infer<typeof webSearchPermissionsSchema>;
export const permissionsSchema = z.object({
  [PermissionTypes.PROMPTS]: promptPermissionsSchema,
  [PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema,
  [PermissionTypes.MEMORIES]: memoryPermissionsSchema,
  [PermissionTypes.AGENTS]: agentPermissionsSchema,
  [PermissionTypes.MULTI_CONVO]: multiConvoPermissionsSchema,
  [PermissionTypes.TEMPORARY_CHAT]: temporaryChatPermissionsSchema,
@@ -5,6 +5,7 @@ import {
  permissionsSchema,
  agentPermissionsSchema,
  promptPermissionsSchema,
  memoryPermissionsSchema,
  runCodePermissionsSchema,
  webSearchPermissionsSchema,
  bookmarkPermissionsSchema,

@@ -48,6 +49,13 @@ const defaultRolesSchema = z.object({
  [PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema.extend({
    [Permissions.USE]: z.boolean().default(true),
  }),
  [PermissionTypes.MEMORIES]: memoryPermissionsSchema.extend({
    [Permissions.USE]: z.boolean().default(true),
    [Permissions.CREATE]: z.boolean().default(true),
    [Permissions.UPDATE]: z.boolean().default(true),
    [Permissions.READ]: z.boolean().default(true),
    [Permissions.OPT_OUT]: z.boolean().default(true),
  }),
  [PermissionTypes.AGENTS]: agentPermissionsSchema.extend({
    [Permissions.SHARED_GLOBAL]: z.boolean().default(true),
    [Permissions.USE]: z.boolean().default(true),
@@ -86,6 +94,13 @@ export const roleDefaults = defaultRolesSchema.parse({
  [PermissionTypes.BOOKMARKS]: {
    [Permissions.USE]: true,
  },
  [PermissionTypes.MEMORIES]: {
    [Permissions.USE]: true,
    [Permissions.CREATE]: true,
    [Permissions.UPDATE]: true,
    [Permissions.READ]: true,
    [Permissions.OPT_OUT]: true,
  },
  [PermissionTypes.AGENTS]: {
    [Permissions.SHARED_GLOBAL]: true,
    [Permissions.USE]: true,

@@ -110,6 +125,7 @@ export const roleDefaults = defaultRolesSchema.parse({
  permissions: {
    [PermissionTypes.PROMPTS]: {},
    [PermissionTypes.BOOKMARKS]: {},
    [PermissionTypes.MEMORIES]: {},
    [PermissionTypes.AGENTS]: {},
    [PermissionTypes.MULTI_CONVO]: {},
    [PermissionTypes.TEMPORARY_CHAT]: {},
@@ -522,11 +522,19 @@ export const tMessageSchema = z.object({
  feedback: feedbackSchema.optional(),
});

export type MemoryArtifact = {
  key: string;
  value?: string;
  tokenCount?: number;
  type: 'update' | 'delete';
};

export type TAttachmentMetadata = {
  type?: Tools;
  messageId: string;
  toolCallId: string;
  [Tools.web_search]?: SearchResultData;
  [Tools.memory]?: MemoryArtifact;
};

export type TAttachment =
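How these types compose in practice: a memory tool attachment carries a `MemoryArtifact` under the computed `Tools.memory` key. A sketch with illustrative ids (the import path is an assumption about the package's public exports):

```ts
import { Tools, type TAttachmentMetadata } from 'librechat-data-provider';

const metadata: TAttachmentMetadata = {
  type: Tools.memory,
  messageId: 'msg_123', // illustrative
  toolCallId: 'call_456', // illustrative
  [Tools.memory]: {
    key: 'preferred_name',
    value: 'Danny',
    tokenCount: 3,
    type: 'update',
  },
};
```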
@@ -1,17 +1,18 @@
 import type OpenAI from 'openai';
 import type { InfiniteData } from '@tanstack/react-query';
 import type {
+  TBanner,
   TMessage,
   TResPlugin,
   ImageDetail,
   TSharedLink,
   TConversation,
   EModelEndpoint,
   TConversationTag,
-  TBanner,
   TAttachment,
 } from './schemas';
-import { TMinimalFeedback } from './feedback';
-import { SettingDefinition } from './generate';
+import type { SettingDefinition } from './generate';
+import type { TMinimalFeedback } from './feedback';
+import type { Agent } from './types/assistants';

 export type TOpenAIMessage = OpenAI.Chat.ChatCompletionMessageParam;
@@ -20,28 +21,78 @@ export * from './schemas';

 export type TMessages = TMessage[];

 /* TODO: Cleanup EndpointOption types */
-export type TEndpointOption = {
-  spec?: string | null;
-  iconURL?: string | null;
-  endpoint: EModelEndpoint;
-  endpointType?: EModelEndpoint;
+export type TEndpointOption = Pick<
+  TConversation,
+  // Core conversation fields
+  | 'endpoint'
+  | 'endpointType'
+  | 'model'
+  | 'modelLabel'
+  | 'chatGptLabel'
+  | 'promptPrefix'
+  | 'temperature'
+  | 'topP'
+  | 'topK'
+  | 'top_p'
+  | 'frequency_penalty'
+  | 'presence_penalty'
+  | 'maxOutputTokens'
+  | 'maxContextTokens'
+  | 'max_tokens'
+  | 'maxTokens'
+  | 'resendFiles'
+  | 'imageDetail'
+  | 'reasoning_effort'
+  | 'instructions'
+  | 'additional_instructions'
+  | 'append_current_datetime'
+  | 'tools'
+  | 'stop'
+  | 'region'
+  | 'additionalModelRequestFields'
+  // Anthropic-specific
+  | 'promptCache'
+  | 'thinking'
+  | 'thinkingBudget'
+  // Assistant/Agent fields
+  | 'assistant_id'
+  | 'agent_id'
+  // UI/Display fields
+  | 'iconURL'
+  | 'greeting'
+  | 'spec'
+  // Artifacts
+  | 'artifacts'
+  // Files
+  | 'file_ids'
+  // System field
+  | 'system'
+  // Google examples
+  | 'examples'
+  // Context
+  | 'context'
+> & {
+  // Fields specific to endpoint options that don't exist on TConversation
+  modelDisplayLabel?: string;
+  resendFiles?: boolean;
+  promptCache?: boolean;
+  maxContextTokens?: number;
+  imageDetail?: ImageDetail;
+  model?: string | null;
+  promptPrefix?: string;
+  temperature?: number;
+  chatGptLabel?: string | null;
+  modelLabel?: string | null;
+  jailbreak?: boolean;
+  key?: string | null;
+  /* assistant */
+  /** @deprecated Assistants API */
+  thread_id?: string;
+  /* multi-response stream */
+  // Conversation identifiers for multi-response streams
+  overrideConvoId?: string;
+  overrideUserMessageId?: string;
+  // Model parameters (used by different endpoints)
+  modelOptions?: Record<string, unknown>;
+  model_parameters?: Record<string, unknown>;
+  // Configuration data (added by middleware)
+  modelsConfig?: TModelsConfig;
+  // File attachments (processed by middleware)
+  attachments?: TAttachment[];
+  // Generated prompts
+  artifactsPrompt?: string;
+  // Agent-specific fields
+  agent?: Promise<Agent>;
+  // Client-specific options
+  clientOptions?: Record<string, unknown>;
+};

 export type TEphemeralAgent = {
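Several keys (e.g. `resendFiles`, `promptCache`, `imageDetail`, `maxContextTokens`) appear both in the `Pick` list and in the trailing intersection; intersecting a re-declared property narrows its type rather than widening it. A minimal sketch of the pattern, detached from these types:

```ts
type Loose = { resendFiles?: boolean | null };
type Narrowed = Pick<Loose, 'resendFiles'> & { resendFiles?: boolean };

// `null` is intersected away: (boolean | null) & boolean -> boolean
const ok: Narrowed = { resendFiles: true };
// const bad: Narrowed = { resendFiles: null }; // would not type-check
```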
@@ -130,6 +181,9 @@ export type TUser = {
  plugins?: string[];
  twoFactorEnabled?: boolean;
  backupCodes?: TBackupCode[];
  personalization?: {
    memories?: boolean;
  };
  createdAt: string;
  updatedAt: string;
};
@@ -557,7 +611,7 @@ export type TUpdateFeedbackResponse = {
    messageId: string;
    conversationId: string;
    feedback?: TMinimalFeedback;
  }
};

export type TBalanceResponse = {
  tokenCredits: number;
@@ -22,6 +22,7 @@ export enum Tools {
  web_search = 'web_search',
  retrieval = 'retrieval',
  function = 'function',
  memory = 'memory',
}

export enum EToolResources {
@@ -278,7 +278,7 @@ export type UpdatePermVars<T> = {
};

export type UpdatePromptPermVars = UpdatePermVars<p.TPromptPermissions>;
export type UpdateMemoryPermVars = UpdatePermVars<p.TMemoryPermissions>;
export type UpdateAgentPermVars = UpdatePermVars<p.TAgentPermissions>;

export type UpdatePermResponse = r.TRole;

@@ -290,6 +290,13 @@ export type UpdatePromptPermOptions = MutationOptions<
  types.TError | null | undefined
>;

export type UpdateMemoryPermOptions = MutationOptions<
  UpdatePermResponse,
  UpdateMemoryPermVars,
  unknown,
  types.TError | null | undefined
>;

export type UpdateAgentPermOptions = MutationOptions<
  UpdatePermResponse,
  UpdateAgentPermVars,
@@ -109,3 +109,18 @@ export type VerifyToolAuthResponse = {

export type GetToolCallParams = { conversationId: string };
export type ToolCallResults = a.ToolCallResult[];

/* Memories */
export type TUserMemory = {
  key: string;
  value: string;
  updated_at: string;
  tokenCount?: number;
};

export type MemoriesResponse = {
  memories: TUserMemory[];
  totalTokens: number;
  tokenLimit: number | null;
  usagePercentage: number | null;
};
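Note that `tokenLimit` and `usagePercentage` are nullable, which suggests no limit is configured for some deployments. A defensive rendering sketch (the function name is illustrative, not part of this diff):

```ts
function formatMemoryUsage(res: MemoriesResponse): string {
  if (res.tokenLimit == null || res.usagePercentage == null) {
    return `${res.totalTokens} tokens stored`;
  }
  return `${res.totalTokens}/${res.tokenLimit} tokens (${res.usagePercentage}%)`;
}
```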
@@ -5,8 +5,8 @@ import type {
   SearchProviders,
   TWebSearchConfig,
 } from './config';
-import { extractVariableName } from './utils';
 import { SearchCategories, SafeSearchTypes } from './config';
+import { extractVariableName } from './utils';
 import { AuthType } from './schemas';

 export function loadWebSearchConfig(
@@ -64,23 +64,29 @@ export const webSearchAuth = {
 /**
  * Extracts all API keys from the webSearchAuth configuration object
  */
-export const webSearchKeys: TWebSearchKeys[] = [];
+export function getWebSearchKeys(): TWebSearchKeys[] {
+  const keys: TWebSearchKeys[] = [];

-// Iterate through each category (providers, scrapers, rerankers)
-for (const category of Object.keys(webSearchAuth)) {
-  const categoryObj = webSearchAuth[category as TWebSearchCategories];
+  // Iterate through each category (providers, scrapers, rerankers)
+  for (const category of Object.keys(webSearchAuth)) {
+    const categoryObj = webSearchAuth[category as TWebSearchCategories];

-  // Iterate through each service within the category
-  for (const service of Object.keys(categoryObj)) {
-    const serviceObj = categoryObj[service as keyof typeof categoryObj];
+    // Iterate through each service within the category
+    for (const service of Object.keys(categoryObj)) {
+      const serviceObj = categoryObj[service as keyof typeof categoryObj];

-    // Extract the API keys from the service
-    for (const key of Object.keys(serviceObj)) {
-      webSearchKeys.push(key as TWebSearchKeys);
+      // Extract the API keys from the service
+      for (const key of Object.keys(serviceObj)) {
+        keys.push(key as TWebSearchKeys);
+      }
     }
   }
-}
+
+  return keys;
+}
+
+export const webSearchKeys: TWebSearchKeys[] = getWebSearchKeys();

 export function extractWebSearchEnvVars({
   keys,
   config,
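Wrapping the extraction in a function removes the import-time loop side effect while `webSearchKeys` preserves the old import surface; call sites can also recompute on demand. A usage sketch, assuming both names stay in the package's public exports:

```ts
import { webSearchKeys, getWebSearchKeys } from 'librechat-data-provider';

// Same contents at startup; getWebSearchKeys() returns a fresh array,
// so callers can mutate their copy without touching the shared export.
const mine = getWebSearchKeys();
console.log(mine.length === webSearchKeys.length); // true
```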
@@ -1,114 +0,0 @@
# `@librechat/data-schemas`

Mongoose schemas and models for LibreChat. This package provides a comprehensive collection of Mongoose schemas used across the LibreChat project, enabling robust data modeling and validation for various entities such as actions, agents, messages, users, and more.

## Features

- **Modular Schemas:** Includes schemas for actions, agents, assistants, balance, banners, categories, conversation tags, conversations, files, keys, messages, plugin authentication, presets, projects, prompts, prompt groups, roles, sessions, shared links, tokens, tool calls, transactions, and users.
- **TypeScript Support:** Provides TypeScript definitions for type-safe development.
- **Ready for Mongoose Integration:** Easily integrate with Mongoose to create models and interact with your MongoDB database.
- **Flexible & Extensible:** Designed to support the evolving needs of LibreChat while being adaptable to other projects.

## Installation

Install the package via npm or yarn:

```bash
npm install @librechat/data-schemas
```

Or with yarn:

```bash
yarn add @librechat/data-schemas
```

## Usage

After installation, you can import and use the schemas in your project. For example, to create a Mongoose model for a user:

```js
import mongoose from 'mongoose';
import { userSchema } from '@librechat/data-schemas';

const UserModel = mongoose.model('User', userSchema);

// Now you can use UserModel to create, read, update, and delete user documents.
```

You can also import other schemas as needed:

```js
import { actionSchema, agentSchema, messageSchema } from '@librechat/data-schemas';
```

Each schema is designed to integrate seamlessly with Mongoose and provides indexes, timestamps, and validations tailored for LibreChat’s use cases.

## Development

This package uses Rollup and TypeScript for building and bundling.

### Available Scripts

- **Build:**
  Cleans the `dist` directory and builds the package.
  ```bash
  npm run build
  ```

- **Build Watch:**
  Rebuilds automatically on file changes.
  ```bash
  npm run build:watch
  ```

- **Test:**
  Runs tests with coverage in watch mode.
  ```bash
  npm run test
  ```

- **Test (CI):**
  Runs tests with coverage for CI environments.
  ```bash
  npm run test:ci
  ```

- **Verify:**
  Runs tests in CI mode to verify code integrity.
  ```bash
  npm run verify
  ```

- **Clean:**
  Removes the `dist` directory.
  ```bash
  npm run clean
  ```

For those using Bun, equivalent scripts are available:
- **Bun Clean:** `bun run b:clean`
- **Bun Build:** `bun run b:build`

## Repository & Issues

The source code is maintained on GitHub.
- **Repository:** [LibreChat Repository](https://github.com/danny-avila/LibreChat.git)
- **Issues & Bug Reports:** [LibreChat Issues](https://github.com/danny-avila/LibreChat/issues)

## License

This project is licensed under the [MIT License](LICENSE).

## Contributing

Contributions to improve and expand the data schemas are welcome. If you have suggestions, improvements, or bug fixes, please open an issue or submit a pull request on the [GitHub repository](https://github.com/danny-avila/LibreChat/issues).

For more detailed documentation on each schema and model, please refer to the source code or visit the [LibreChat website](https://librechat.ai).
@@ -3,5 +3,6 @@ export * from './schema';
export { createModels } from './models';
export { createMethods } from './methods';
export type * from './types';
export type * from './methods';
export { default as logger } from './config/winston';
export { default as meiliLogger } from './config/meiliLogger';
@@ -2,6 +2,8 @@ import { createUserMethods, type UserMethods } from './user';
 import { createSessionMethods, type SessionMethods } from './session';
 import { createTokenMethods, type TokenMethods } from './token';
 import { createRoleMethods, type RoleMethods } from './role';
+/* Memories */
+import { createMemoryMethods, type MemoryMethods } from './memory';

 /**
  * Creates all database methods for all collections

@@ -12,7 +14,9 @@ export function createMethods(mongoose: typeof import('mongoose')) {
     ...createSessionMethods(mongoose),
     ...createTokenMethods(mongoose),
     ...createRoleMethods(mongoose),
+    ...createMemoryMethods(mongoose),
   };
 }

-export type AllMethods = UserMethods & SessionMethods & TokenMethods & RoleMethods;
+export type { MemoryMethods };
+export type AllMethods = UserMethods & SessionMethods & TokenMethods & RoleMethods & MemoryMethods;
packages/data-schemas/src/methods/memory.ts (new file, +168)

@@ -0,0 +1,168 @@
import { Types } from 'mongoose';
import logger from '~/config/winston';
import type * as t from '~/types';

/**
 * Formats a date in YYYY-MM-DD format
 */
const formatDate = (date: Date): string => {
  return date.toISOString().split('T')[0];
};

// Factory function that takes mongoose instance and returns the methods
export function createMemoryMethods(mongoose: typeof import('mongoose')) {
  const MemoryEntry = mongoose.models.MemoryEntry;

  /**
   * Creates a new memory entry for a user
   * Throws an error if a memory with the same key already exists
   */
  async function createMemory({
    userId,
    key,
    value,
    tokenCount = 0,
  }: t.SetMemoryParams): Promise<t.MemoryResult> {
    try {
      if (key?.toLowerCase() === 'nothing') {
        return { ok: false };
      }

      const existingMemory = await MemoryEntry.findOne({ userId, key });
      if (existingMemory) {
        throw new Error('Memory with this key already exists');
      }

      await MemoryEntry.create({
        userId,
        key,
        value,
        tokenCount,
        updated_at: new Date(),
      });

      return { ok: true };
    } catch (error) {
      throw new Error(
        `Failed to create memory: ${error instanceof Error ? error.message : 'Unknown error'}`,
      );
    }
  }

  /**
   * Sets or updates a memory entry for a user
   */
  async function setMemory({
    userId,
    key,
    value,
    tokenCount = 0,
  }: t.SetMemoryParams): Promise<t.MemoryResult> {
    try {
      if (key?.toLowerCase() === 'nothing') {
        return { ok: false };
      }

      await MemoryEntry.findOneAndUpdate(
        { userId, key },
        {
          value,
          tokenCount,
          updated_at: new Date(),
        },
        {
          upsert: true,
          new: true,
        },
      );

      return { ok: true };
    } catch (error) {
      throw new Error(
        `Failed to set memory: ${error instanceof Error ? error.message : 'Unknown error'}`,
      );
    }
  }

  /**
   * Deletes a specific memory entry for a user
   */
  async function deleteMemory({ userId, key }: t.DeleteMemoryParams): Promise<t.MemoryResult> {
    try {
      const result = await MemoryEntry.findOneAndDelete({ userId, key });
      return { ok: !!result };
    } catch (error) {
      throw new Error(
        `Failed to delete memory: ${error instanceof Error ? error.message : 'Unknown error'}`,
      );
    }
  }

  /**
   * Gets all memory entries for a user
   */
  async function getAllUserMemories(
    userId: string | Types.ObjectId,
  ): Promise<t.IMemoryEntryLean[]> {
    try {
      return (await MemoryEntry.find({ userId }).lean()) as t.IMemoryEntryLean[];
    } catch (error) {
      throw new Error(
        `Failed to get all memories: ${error instanceof Error ? error.message : 'Unknown error'}`,
      );
    }
  }

  /**
   * Gets and formats all memories for a user in two different formats
   */
  async function getFormattedMemories({
    userId,
  }: t.GetFormattedMemoriesParams): Promise<t.FormattedMemoriesResult> {
    try {
      const memories = await getAllUserMemories(userId);

      if (!memories || memories.length === 0) {
        return { withKeys: '', withoutKeys: '', totalTokens: 0 };
      }

      const sortedMemories = memories.sort(
        (a, b) => new Date(a.updated_at!).getTime() - new Date(b.updated_at!).getTime(),
      );

      const totalTokens = sortedMemories.reduce((sum, memory) => {
        return sum + (memory.tokenCount || 0);
      }, 0);

      const withKeys = sortedMemories
        .map((memory, index) => {
          const date = formatDate(new Date(memory.updated_at!));
          const tokenInfo = memory.tokenCount ? ` [${memory.tokenCount} tokens]` : '';
          return `${index + 1}. [${date}]. ["key": "${memory.key}"]${tokenInfo}. ["value": "${memory.value}"]`;
        })
        .join('\n\n');

      const withoutKeys = sortedMemories
        .map((memory, index) => {
          const date = formatDate(new Date(memory.updated_at!));
          return `${index + 1}. [${date}]. ${memory.value}`;
        })
        .join('\n\n');

      return { withKeys, withoutKeys, totalTokens };
    } catch (error) {
      logger.error('Failed to get formatted memories:', error);
      return { withKeys: '', withoutKeys: '', totalTokens: 0 };
    }
  }

  return {
    setMemory,
    createMemory,
    deleteMemory,
    getAllUserMemories,
    getFormattedMemories,
  };
}

export type MemoryMethods = ReturnType<typeof createMemoryMethods>;
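A minimal wiring sketch, assuming a standalone script and an illustrative connection string. Note that `createModels` must run before `createMethods`, since the memory factory captures `mongoose.models.MemoryEntry` at call time:

```ts
import mongoose from 'mongoose';
import { createModels, createMethods } from '@librechat/data-schemas';

async function main() {
  await mongoose.connect('mongodb://localhost:27017/librechat'); // illustrative URI
  createModels(mongoose); // registers MemoryEntry (and the other models)
  const { setMemory, getFormattedMemories } = createMethods(mongoose);

  // userId must be a real ObjectId in practice; this value is illustrative
  await setMemory({ userId: 'user-object-id', key: 'timezone', value: 'UTC+1', tokenCount: 4 });
  const { withKeys, totalTokens } = await getFormattedMemories({ userId: 'user-object-id' });
  console.log(withKeys, totalTokens);
}

main().catch(console.error);
```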
@@ -170,6 +170,35 @@ export function createUserMethods(mongoose: typeof import('mongoose')) {
    });
  }

  /**
   * Update a user's personalization memories setting.
   * Handles the edge case where the personalization object doesn't exist.
   */
  async function toggleUserMemories(
    userId: string,
    memoriesEnabled: boolean,
  ): Promise<IUser | null> {
    const User = mongoose.models.User;

    // First, ensure the personalization object exists
    const user = await User.findById(userId);
    if (!user) {
      return null;
    }

    // Use $set to update the nested field, which will create the personalization object if it doesn't exist
    const updateOperation = {
      $set: {
        'personalization.memories': memoriesEnabled,
      },
    };

    return (await User.findByIdAndUpdate(userId, updateOperation, {
      new: true,
      runValidators: true,
    }).lean()) as IUser | null;
  }

  // Return all methods
  return {
    findUser,

@@ -179,6 +208,7 @@ export function createUserMethods(mongoose: typeof import('mongoose')) {
    getUserById,
    deleteUserById,
    generateToken,
    toggleUserMemories,
  };
}
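The dot-path `$set` is what makes the missing-object case safe: MongoDB materializes `personalization` on first write, so no separate initialization update is needed. Illustrative before/after:

```ts
// Stored user before: { _id, name: 'Ada' }                  (no personalization)
// After toggleUserMemories(id, false):
// { _id, name: 'Ada', personalization: { memories: false } }
```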
@@ -20,6 +20,7 @@ import { createPromptGroupModel } from './promptGroup';
import { createConversationTagModel } from './conversationTag';
import { createSharedLinkModel } from './sharedLink';
import { createToolCallModel } from './toolCall';
import { createMemoryModel } from './memory';

/**
 * Creates all database models for all collections

@@ -48,5 +49,6 @@ export function createModels(mongoose: typeof import('mongoose')) {
  ConversationTag: createConversationTagModel(mongoose),
  SharedLink: createSharedLinkModel(mongoose),
  ToolCall: createToolCallModel(mongoose),
  MemoryEntry: createMemoryModel(mongoose),
};
}
packages/data-schemas/src/models/memory.ts (new file, +6)

@@ -0,0 +1,6 @@
import memorySchema from '~/schema/memory';
import type { IMemoryEntry } from '~/types/memory';

export function createMemoryModel(mongoose: typeof import('mongoose')) {
  return mongoose.models.MemoryEntry || mongoose.model<IMemoryEntry>('MemoryEntry', memorySchema);
}
@@ -1,6 +1,14 @@
 import _ from 'lodash';
 import { MeiliSearch, Index } from 'meilisearch';
-import type { FilterQuery, Types, Schema, Document, Model, Query } from 'mongoose';
+import type {
+  CallbackWithoutResultAndOptionalError,
+  FilterQuery,
+  Document,
+  Schema,
+  Query,
+  Types,
+  Model,
+} from 'mongoose';
 import logger from '~/config/meiliLogger';

 interface MongoMeiliOptions {
@@ -24,12 +32,12 @@ interface ContentItem {
 interface DocumentWithMeiliIndex extends Document {
   _meiliIndex?: boolean;
   preprocessObjectForIndex?: () => Record<string, unknown>;
-  addObjectToMeili?: () => Promise<void>;
-  updateObjectToMeili?: () => Promise<void>;
-  deleteObjectFromMeili?: () => Promise<void>;
-  postSaveHook?: () => void;
-  postUpdateHook?: () => void;
-  postRemoveHook?: () => void;
+  addObjectToMeili?: (next: CallbackWithoutResultAndOptionalError) => Promise<void>;
+  updateObjectToMeili?: (next: CallbackWithoutResultAndOptionalError) => Promise<void>;
+  deleteObjectFromMeili?: (next: CallbackWithoutResultAndOptionalError) => Promise<void>;
+  postSaveHook?: (next: CallbackWithoutResultAndOptionalError) => void;
+  postUpdateHook?: (next: CallbackWithoutResultAndOptionalError) => void;
+  postRemoveHook?: (next: CallbackWithoutResultAndOptionalError) => void;
   conversationId?: string;
   content?: ContentItem[];
   messageId?: string;
@@ -220,7 +228,7 @@ const createMeiliMongooseModel = ({
       );
     }
   } catch (error) {
-    logger.error('[syncWithMeili] Error adding document to Meili', error);
+    logger.error('[syncWithMeili] Error adding document to Meili:', error);
   }
 }
@@ -306,28 +314,48 @@ const createMeiliMongooseModel = ({
   /**
    * Adds the current document to the MeiliSearch index
    */
-  async addObjectToMeili(this: DocumentWithMeiliIndex): Promise<void> {
+  async addObjectToMeili(
+    this: DocumentWithMeiliIndex,
+    next: CallbackWithoutResultAndOptionalError,
+  ): Promise<void> {
     const object = this.preprocessObjectForIndex!();
     try {
       await index.addDocuments([object]);
     } catch (error) {
-      logger.error('[addObjectToMeili] Error adding document to Meili', error);
+      logger.error('[addObjectToMeili] Error adding document to Meili:', error);
+      return next();
     }

-    await this.collection.updateMany(
-      { _id: this._id as Types.ObjectId },
-      { $set: { _meiliIndex: true } },
-    );
+    try {
+      await this.collection.updateMany(
+        { _id: this._id as Types.ObjectId },
+        { $set: { _meiliIndex: true } },
+      );
+    } catch (error) {
+      logger.error('[addObjectToMeili] Error updating _meiliIndex field:', error);
+      return next();
+    }
+
+    next();
   }

   /**
    * Updates the current document in the MeiliSearch index
    */
-  async updateObjectToMeili(this: DocumentWithMeiliIndex): Promise<void> {
-    const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) =>
-      k.startsWith('$'),
-    );
-    await index.updateDocuments([object]);
+  async updateObjectToMeili(
+    this: DocumentWithMeiliIndex,
+    next: CallbackWithoutResultAndOptionalError,
+  ): Promise<void> {
+    try {
+      const object = _.omitBy(_.pick(this.toJSON(), attributesToIndex), (v, k) =>
+        k.startsWith('$'),
+      );
+      await index.updateDocuments([object]);
+      next();
+    } catch (error) {
+      logger.error('[updateObjectToMeili] Error updating document in Meili:', error);
+      return next();
+    }
+  }

   /**
@@ -335,8 +363,17 @@ const createMeiliMongooseModel = ({
   *
   * @returns {Promise<void>}
   */
-  async deleteObjectFromMeili(this: DocumentWithMeiliIndex): Promise<void> {
-    await index.deleteDocument(this._id as string);
+  async deleteObjectFromMeili(
+    this: DocumentWithMeiliIndex,
+    next: CallbackWithoutResultAndOptionalError,
+  ): Promise<void> {
+    try {
+      await index.deleteDocument(this._id as string);
+      next();
+    } catch (error) {
+      logger.error('[deleteObjectFromMeili] Error deleting document from Meili:', error);
+      return next();
+    }
   }

   /**
@@ -345,11 +382,11 @@ const createMeiliMongooseModel = ({
   * If the document is already indexed (i.e. `_meiliIndex` is true), it updates it;
   * otherwise, it adds the document to the index.
   */
-  postSaveHook(this: DocumentWithMeiliIndex): void {
+  postSaveHook(this: DocumentWithMeiliIndex, next: CallbackWithoutResultAndOptionalError): void {
     if (this._meiliIndex) {
-      this.updateObjectToMeili!();
+      this.updateObjectToMeili!(next);
     } else {
-      this.addObjectToMeili!();
+      this.addObjectToMeili!(next);
     }
   }
|
@ -359,9 +396,14 @@ const createMeiliMongooseModel = ({
|
|||
* This hook is triggered after a document update, ensuring that changes are
|
||||
* propagated to the MeiliSearch index if the document is indexed.
|
||||
*/
|
||||
postUpdateHook(this: DocumentWithMeiliIndex): void {
|
||||
postUpdateHook(
|
||||
this: DocumentWithMeiliIndex,
|
||||
next: CallbackWithoutResultAndOptionalError,
|
||||
): void {
|
||||
if (this._meiliIndex) {
|
||||
this.updateObjectToMeili!();
|
||||
this.updateObjectToMeili!(next);
|
||||
} else {
|
||||
next();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@@ -371,9 +413,14 @@ const createMeiliMongooseModel = ({
   * This hook is triggered after a document is removed, ensuring that the document
   * is also removed from the MeiliSearch index if it was previously indexed.
   */
-  postRemoveHook(this: DocumentWithMeiliIndex): void {
+  postRemoveHook(
+    this: DocumentWithMeiliIndex,
+    next: CallbackWithoutResultAndOptionalError,
+  ): void {
     if (this._meiliIndex) {
-      this.deleteObjectFromMeili!();
+      this.deleteObjectFromMeili!(next);
     } else {
+      next();
     }
   }
 }
@@ -429,16 +476,16 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions):
   schema.loadClass(createMeiliMongooseModel({ index, attributesToIndex }));

   // Register Mongoose hooks
-  schema.post('save', function (doc: DocumentWithMeiliIndex) {
-    doc.postSaveHook?.();
+  schema.post('save', function (doc: DocumentWithMeiliIndex, next) {
+    doc.postSaveHook?.(next);
   });

-  schema.post('updateOne', function (doc: DocumentWithMeiliIndex) {
-    doc.postUpdateHook?.();
+  schema.post('updateOne', function (doc: DocumentWithMeiliIndex, next) {
+    doc.postUpdateHook?.(next);
   });

-  schema.post('deleteOne', function (doc: DocumentWithMeiliIndex) {
-    doc.postRemoveHook?.();
+  schema.post('deleteOne', function (doc: DocumentWithMeiliIndex, next) {
+    doc.postRemoveHook?.(next);
   });

   // Pre-deleteMany hook: remove corresponding documents from MeiliSearch when multiple documents are deleted.
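The point of threading `next` through every path: Mongoose post hooks that declare a second parameter are treated as callback-style middleware, and the chain stalls until `next()` fires, which is why each failure branch above still ends in `next()`. A detached sketch of the signature:

```ts
import { Schema } from 'mongoose';

const demoSchema = new Schema({ title: String });

// Declaring (doc, next) opts into callback-style post middleware;
// next() must run on success *and* failure, or later hooks never execute.
demoSchema.post('save', function (doc, next) {
  Promise.resolve()
    .then(() => {
      /* push `doc` to an external index here */
      next();
    })
    .catch((error) => {
      console.error('indexing failed:', error);
      next();
    });
});
```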
@@ -486,13 +533,13 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions):
   });

   // Post-findOneAndUpdate hook
-  schema.post('findOneAndUpdate', async function (doc: DocumentWithMeiliIndex) {
+  schema.post('findOneAndUpdate', async function (doc: DocumentWithMeiliIndex, next) {
     if (!meiliEnabled) {
-      return;
+      return next();
     }

     if (doc.unfinished) {
-      return;
+      return next();
     }

     let meiliDoc: Record<string, unknown> | undefined;

@@ -509,9 +556,9 @@ export default function mongoMeili(schema: Schema, options: MongoMeiliOptions):
     }

     if (meiliDoc && meiliDoc.title === doc.title) {
-      return;
+      return next();
     }

-    doc.postSaveHook?.();
+    doc.postSaveHook?.(next);
   });
 }
@@ -21,3 +21,4 @@ export { default as tokenSchema } from './token';
export { default as toolCallSchema } from './toolCall';
export { default as transactionSchema } from './transaction';
export { default as userSchema } from './user';
export { default as memorySchema } from './memory';
packages/data-schemas/src/schema/memory.ts (new file, +33)

@@ -0,0 +1,33 @@
import { Schema } from 'mongoose';
import type { IMemoryEntry } from '~/types/memory';

const MemoryEntrySchema: Schema<IMemoryEntry> = new Schema({
  userId: {
    type: Schema.Types.ObjectId,
    ref: 'User',
    index: true,
    required: true,
  },
  key: {
    type: String,
    required: true,
    validate: {
      validator: (v: string) => /^[a-z_]+$/.test(v),
      message: 'Key must only contain lowercase letters and underscores',
    },
  },
  value: {
    type: String,
    required: true,
  },
  tokenCount: {
    type: Number,
    default: 0,
  },
  updated_at: {
    type: Date,
    default: Date.now,
  },
});

export default MemoryEntrySchema;
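The `key` validator enforces bare snake_case, matching the lowercase-and-underscore convention the memory agent writes. A few concrete cases against that regex:

```ts
const isValidKey = (v: string) => /^[a-z_]+$/.test(v);

isValidKey('preferred_name'); // true
isValidKey('preferredName');  // false (uppercase rejected)
isValidKey('likes-dogs');     // false (hyphens rejected)
isValidKey('goal2025');       // false (digits rejected)
```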
@@ -13,6 +13,13 @@ const rolePermissionsSchema = new Schema(
    [Permissions.USE]: { type: Boolean, default: true },
    [Permissions.CREATE]: { type: Boolean, default: true },
  },
  [PermissionTypes.MEMORIES]: {
    [Permissions.USE]: { type: Boolean, default: true },
    [Permissions.CREATE]: { type: Boolean, default: true },
    [Permissions.UPDATE]: { type: Boolean, default: true },
    [Permissions.READ]: { type: Boolean, default: true },
    [Permissions.OPT_OUT]: { type: Boolean, default: true },
  },
  [PermissionTypes.AGENTS]: {
    [Permissions.SHARED_GLOBAL]: { type: Boolean, default: false },
    [Permissions.USE]: { type: Boolean, default: true },
@@ -45,6 +52,12 @@ const roleSchema: Schema<IRole> = new Schema({
    [Permissions.USE]: true,
    [Permissions.CREATE]: true,
  },
  [PermissionTypes.MEMORIES]: {
    [Permissions.USE]: true,
    [Permissions.CREATE]: true,
    [Permissions.UPDATE]: true,
    [Permissions.READ]: true,
  },
  [PermissionTypes.AGENTS]: {
    [Permissions.SHARED_GLOBAL]: false,
    [Permissions.USE]: true,
@@ -129,6 +129,15 @@ const userSchema = new Schema<IUser>(
      type: Boolean,
      default: false,
    },
    personalization: {
      type: {
        memories: {
          type: Boolean,
          default: true,
        },
      },
      default: {},
    },
  },
  { timestamps: true },
);
@@ -1,3 +1,6 @@
import type { Types } from 'mongoose';

export type ObjectId = Types.ObjectId;
export * from './user';
export * from './token';
export * from './convo';

@@ -10,3 +13,5 @@ export * from './role';
export * from './action';
export * from './assistant';
export * from './file';
/* Memories */
export * from './memory';
packages/data-schemas/src/types/memory.ts (new file, +48)

@@ -0,0 +1,48 @@
import type { Types, Document } from 'mongoose';

// Base memory interfaces
export interface IMemoryEntry extends Document {
  userId: Types.ObjectId;
  key: string;
  value: string;
  tokenCount?: number;
  updated_at?: Date;
}

export interface IMemoryEntryLean {
  _id: Types.ObjectId;
  userId: Types.ObjectId;
  key: string;
  value: string;
  tokenCount?: number;
  updated_at?: Date;
  __v?: number;
}

// Method parameter interfaces
export interface SetMemoryParams {
  userId: string | Types.ObjectId;
  key: string;
  value: string;
  tokenCount?: number;
}

export interface DeleteMemoryParams {
  userId: string | Types.ObjectId;
  key: string;
}

export interface GetFormattedMemoriesParams {
  userId: string | Types.ObjectId;
}

// Result interfaces
export interface MemoryResult {
  ok: boolean;
}

export interface FormattedMemoriesResult {
  withKeys: string;
  withoutKeys: string;
  totalTokens?: number;
}
@@ -12,6 +12,12 @@ export interface IRole extends Document {
    [Permissions.USE]?: boolean;
    [Permissions.CREATE]?: boolean;
  };
  [PermissionTypes.MEMORIES]?: {
    [Permissions.USE]?: boolean;
    [Permissions.CREATE]?: boolean;
    [Permissions.UPDATE]?: boolean;
    [Permissions.READ]?: boolean;
  };
  [PermissionTypes.AGENTS]?: {
    [Permissions.SHARED_GLOBAL]?: boolean;
    [Permissions.USE]?: boolean;
@@ -30,6 +30,9 @@ export interface IUser extends Document {
  }>;
  expiresAt?: Date;
  termsAccepted?: boolean;
  personalization?: {
    memories?: boolean;
  };
  createdAt?: Date;
  updatedAt?: Date;
}
@@ -1,231 +0,0 @@
import express from 'express';
import { EventSource } from 'eventsource';
import { MCPConnection } from '../connection';
import type { MCPOptions } from '../types/mcp';

// Set up EventSource for Node environment
global.EventSource = EventSource;

const app = express();
app.use(express.json());

let mcp: MCPConnection;

const initializeMCP = async () => {
  console.log('Initializing MCP with SSE transport...');

  const mcpOptions: MCPOptions = {
    type: 'sse' as const,
    url: 'http://localhost:3001/sse',
    // type: 'stdio' as const,
    // 'command': 'npx',
    // 'args': [
    //   '-y',
    //   '@modelcontextprotocol/server-everything',
    // ],
  };

  try {
    await MCPConnection.destroyInstance();
    mcp = MCPConnection.getInstance('everything', mcpOptions);

    mcp.on('connectionChange', (state) => {
      console.log(`MCP connection state changed to: ${state}`);
    });

    mcp.on('error', (error) => {
      console.error('MCP error:', error);
    });

    console.log('Connecting to MCP server...');
    await mcp.connectClient();
    console.log('Connected to MCP server');

    // Test the connection
    try {
      const resources = await mcp.fetchResources();
      console.log('Available resources:', resources);
    } catch (error) {
      console.error('Error fetching resources:', error);
    }
  } catch (error) {
    console.error('Failed to connect to MCP server:', error);
  }
};

// API Endpoints
app.get('/status', (req, res) => {
  res.json({
    connected: mcp.isConnected(),
    state: mcp.getConnectionState(),
    error: mcp.getLastError()?.message,
  });
});

// Resources endpoint
app.get('/resources', async (req, res) => {
  try {
    const resources = await mcp.fetchResources();
    res.json({ resources });
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Tools endpoint with all tool operations
app.get('/tools', async (req, res) => {
  try {
    const tools = await mcp.fetchTools();
    res.json({ tools });
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Echo tool endpoint
app.post('/tools/echo', async (req, res) => {
  try {
    const { message } = req.body;
    const result = await mcp.client.callTool({
      name: 'echo',
      arguments: { message },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Add tool endpoint
app.post('/tools/add', async (req, res) => {
  try {
    const { a, b } = req.body;
    const result = await mcp.client.callTool({
      name: 'add',
      arguments: { a, b },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Long running operation endpoint
app.post('/tools/long-operation', async (req, res) => {
  try {
    const { duration, steps } = req.body;
    const result = await mcp.client.callTool({
      name: 'longRunningOperation',
      arguments: { duration, steps },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Sample LLM endpoint
app.post('/tools/sample', async (req, res) => {
  try {
    const { prompt, maxTokens } = req.body;
    const result = await mcp.client.callTool({
      name: 'sampleLLM',
      arguments: { prompt, maxTokens },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Get tiny image endpoint
app.get('/tools/tiny-image', async (req, res) => {
  try {
    const result = await mcp.client.callTool({
      name: 'getTinyImage',
      arguments: {},
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Prompts endpoints
app.get('/prompts', async (req, res) => {
  try {
    const prompts = await mcp.fetchPrompts();
    res.json({ prompts });
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

app.post('/prompts/simple', async (req, res) => {
  try {
    const result = await mcp.client.getPrompt({
      name: 'simple_prompt',
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

app.post('/prompts/complex', async (req, res) => {
  try {
    const { temperature, style } = req.body;
    const result = await mcp.client.getPrompt({
      name: 'complex_prompt',
      arguments: { temperature, style },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Resource subscription endpoints
app.post('/resources/subscribe', async (req, res) => {
  try {
    const { uri } = req.body;
    await mcp.client.subscribeResource({ uri });
    res.json({ success: true });
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

app.post('/resources/unsubscribe', async (req, res) => {
  try {
    const { uri } = req.body;
    await mcp.client.unsubscribeResource({ uri });
    res.json({ success: true });
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Error handling
// eslint-disable-next-line @typescript-eslint/no-unused-vars
app.use((err: Error, req: express.Request, res: express.Response, next: express.NextFunction) => {
  console.error('Unhandled error:', err);
  res.status(500).json({
    error: 'Internal server error',
    message: err.message,
  });
});

// Cleanup on shutdown
process.on('SIGINT', async () => {
  console.log('Shutting down...');
  await MCPConnection.destroyInstance();
  process.exit(0);
});

// Start server
const PORT = process.env.MCP_PORT ?? 3000;
app.listen(PORT, () => {
  console.log(`Server running on http://localhost:${PORT}`);
  initializeMCP();
});
@@ -1,211 +0,0 @@
import express from 'express';
import { EventSource } from 'eventsource';
import { MCPConnection } from '../connection';
import type { MCPOptions } from '../types/mcp';

// Set up EventSource for Node environment
global.EventSource = EventSource;

const app = express();
app.use(express.json());

let mcp: MCPConnection;

const initializeMCP = async () => {
  console.log('Initializing MCP with SSE transport...');

  const mcpOptions: MCPOptions = {
    // type: 'sse' as const,
    // url: 'http://localhost:3001/sse',
    type: 'stdio' as const,
    command: 'npx',
    args: ['-y', '@modelcontextprotocol/server-filesystem', '/home/danny/LibreChat/'],
  };

  try {
    // Clean up any existing instance
    await MCPConnection.destroyInstance();

    // Get singleton instance
    mcp = MCPConnection.getInstance('filesystem', mcpOptions);

    // Add event listeners
    mcp.on('connectionChange', (state) => {
      console.log(`MCP connection state changed to: ${state}`);
    });

    mcp.on('error', (error) => {
      console.error('MCP error:', error);
    });

    // Connect to server
    console.log('Connecting to MCP server...');
    await mcp.connectClient();
    console.log('Connected to MCP server');
  } catch (error) {
    console.error('Failed to connect to MCP server:', error);
  }
};

// Initialize MCP connection
initializeMCP();

// API Endpoints
app.get('/status', (req, res) => {
  res.json({
    connected: mcp.isConnected(),
    state: mcp.getConnectionState(),
    error: mcp.getLastError()?.message,
  });
});

app.get('/resources', async (req, res) => {
  try {
    const resources = await mcp.fetchResources();
    res.json({ resources });
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

app.get('/tools', async (req, res) => {
  try {
    const tools = await mcp.fetchTools();
    res.json({ tools });
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// File operations
// @ts-ignore
app.get('/files/read', async (req, res) => {
  const filePath = req.query.path as string;
  if (!filePath) {
    return res.status(400).json({ error: 'Path parameter is required' });
  }

  try {
    const result = await mcp.client.callTool({
      name: 'read_file',
      arguments: { path: filePath },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// @ts-ignore
app.post('/files/write', async (req, res) => {
  const { path, content } = req.body;
  if (!path || content === undefined) {
    return res.status(400).json({ error: 'Path and content are required' });
  }

  try {
    const result = await mcp.client.callTool({
      name: 'write_file',
      arguments: { path, content },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// @ts-ignore
app.post('/files/edit', async (req, res) => {
  const { path, edits, dryRun = false } = req.body;
  if (!path || !edits) {
    return res.status(400).json({ error: 'Path and edits are required' });
  }

  try {
    const result = await mcp.client.callTool({
      name: 'edit_file',
      arguments: { path, edits, dryRun },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Directory operations
// @ts-ignore
app.get('/directory/list', async (req, res) => {
  const dirPath = req.query.path as string;
  if (!dirPath) {
    return res.status(400).json({ error: 'Path parameter is required' });
  }

  try {
    const result = await mcp.client.callTool({
      name: 'list_directory',
      arguments: { path: dirPath },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// @ts-ignore
app.post('/directory/create', async (req, res) => {
  const { path } = req.body;
  if (!path) {
    return res.status(400).json({ error: 'Path is required' });
  }

  try {
    const result = await mcp.client.callTool({
      name: 'create_directory',
      arguments: { path },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Search endpoint
// @ts-ignore
app.get('/search', async (req, res) => {
  const { path, pattern } = req.query;
  if (!path || !pattern) {
    return res.status(400).json({ error: 'Path and pattern parameters are required' });
  }

  try {
    const result = await mcp.client.callTool({
      name: 'search_files',
      arguments: { path, pattern },
    });
    res.json(result);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
});

// Error handling
app.use((err: Error, req: express.Request, res: express.Response, next: express.NextFunction) => {
  console.error('Unhandled error:', err);
  res.status(500).json({
    error: 'Internal server error',
    message: err.message,
  });
});

// Cleanup on shutdown
process.on('SIGINT', async () => {
  console.log('Shutting down...');
  await MCPConnection.destroyInstance();
  process.exit(0);
});

// Start server
const PORT = process.env.MCP_PORT || 3000;
app.listen(PORT, () => {
  console.log(`Server running on http://localhost:${PORT}`);
});
@@ -1,226 +0,0 @@
// server.ts
import express from 'express';
import { EventSource } from 'eventsource';
import { MCPManager } from '../manager';
import { MCPConnection } from '../connection';
import type * as t from '../types/mcp';

// Set up EventSource for Node environment
global.EventSource = EventSource;

const app = express();
app.use(express.json());

const mcpManager = MCPManager.getInstance();

const mcpServers: t.MCPServers = {
  everything: {
    type: 'sse' as const,
    url: 'http://localhost:3001/sse',
  },
  filesystem: {
    type: 'stdio' as const,
    command: 'npx',
    args: ['-y', '@modelcontextprotocol/server-filesystem', '/home/danny/LibreChat/'],
  },
};

// Generic helper to get connection and handle errors
const withConnection = async (
  serverName: string,
  res: express.Response,
  callback: (connection: MCPConnection) => Promise<void>,
) => {
  const connection = mcpManager.getConnection(serverName);
  if (!connection) {
    return res.status(404).json({ error: `Server "${serverName}" not found` });
  }
  try {
    await callback(connection);
  } catch (error) {
    res.status(500).json({ error: String(error) });
  }
};

// Common endpoints for all servers
// @ts-ignore
app.get('/status/:server', (req, res) => {
  const connection = mcpManager.getConnection(req.params.server);
  if (!connection) {
    return res.status(404).json({ error: 'Server not found' });
  }

  res.json({
    connected: connection.isConnected(),
    state: connection.getConnectionState(),
    error: connection.getLastError()?.message,
  });
});

app.get('/resources/:server', async (req, res) => {
  await withConnection(req.params.server, res, async (connection) => {
    const resources = await connection.fetchResources();
    res.json({ resources });
  });
});

app.get('/tools/:server', async (req, res) => {
  await withConnection(req.params.server, res, async (connection) => {
    const tools = await connection.fetchTools();
    res.json({ tools });
  });
});

// "Everything" server specific endpoints
app.post('/everything/tools/echo', async (req, res) => {
  await withConnection('everything', res, async (connection) => {
    const { message } = req.body;
    const result = await connection.client.callTool({
      name: 'echo',
      arguments: { message },
    });
    res.json(result);
  });
});

app.post('/everything/tools/add', async (req, res) => {
  await withConnection('everything', res, async (connection) => {
    const { a, b } = req.body;
    const result = await connection.client.callTool({
      name: 'add',
      arguments: { a, b },
    });
    res.json(result);
  });
});

app.post('/everything/tools/long-operation', async (req, res) => {
  await withConnection('everything', res, async (connection) => {
    const { duration, steps } = req.body;
    const result = await connection.client.callTool({
      name: 'longRunningOperation',
      arguments: { duration, steps },
    });
    res.json(result);
  });
});

// Filesystem server specific endpoints
// @ts-ignore
app.get('/filesystem/files/read', async (req, res) => {
  const filePath = req.query.path as string;
  if (!filePath) {
    return res.status(400).json({ error: 'Path parameter is required' });
  }

  await withConnection('filesystem', res, async (connection) => {
    const result = await connection.client.callTool({
      name: 'read_file',
      arguments: { path: filePath },
    });
    res.json(result);
  });
});

// @ts-ignore
app.post('/filesystem/files/write', async (req, res) => {
  const { path, content } = req.body;
  if (!path || content === undefined) {
    return res.status(400).json({ error: 'Path and content are required' });
  }

  await withConnection('filesystem', res, async (connection) => {
    const result = await connection.client.callTool({
      name: 'write_file',
      arguments: { path, content },
    });
    res.json(result);
  });
});

// @ts-ignore
app.post('/filesystem/files/edit', async (req, res) => {
  const { path, edits, dryRun = false } = req.body;
  if (!path || !edits) {
    return res.status(400).json({ error: 'Path and edits are required' });
  }

  await withConnection('filesystem', res, async (connection) => {
    const result = await connection.client.callTool({
      name: 'edit_file',
      arguments: { path, edits, dryRun },
    });
    res.json(result);
  });
});

// @ts-ignore
app.get('/filesystem/directory/list', async (req, res) => {
  const dirPath = req.query.path as string;
  if (!dirPath) {
    return res.status(400).json({ error: 'Path parameter is required' });
  }

  await withConnection('filesystem', res, async (connection) => {
    const result = await connection.client.callTool({
      name: 'list_directory',
      arguments: { path: dirPath },
    });
    res.json(result);
  });
});

// @ts-ignore
app.post('/filesystem/directory/create', async (req, res) => {
  const { path } = req.body;
  if (!path) {
    return res.status(400).json({ error: 'Path is required' });
  }

  await withConnection('filesystem', res, async (connection) => {
    const result = await connection.client.callTool({
      name: 'create_directory',
      arguments: { path },
    });
    res.json(result);
  });
});

// @ts-ignore
app.get('/filesystem/search', async (req, res) => {
  const { path, pattern } = req.query;
  if (!path || !pattern) {
    return res.status(400).json({ error: 'Path and pattern parameters are required' });
  }

  await withConnection('filesystem', res, async (connection) => {
    const result = await connection.client.callTool({
      name: 'search_files',
      arguments: { path, pattern },
    });
    res.json(result);
  });
});

// Error handling
app.use((err: Error, req: express.Request, res: express.Response, next: express.NextFunction) => {
  console.error('Unhandled error:', err);
  res.status(500).json({
    error: 'Internal server error',
    message: err.message,
  });
});

// Cleanup on shutdown
process.on('SIGINT', async () => {
  console.log('Shutting down...');
  await MCPManager.destroyInstance();
  process.exit(0);
});

// Start server
const PORT = process.env.MCP_PORT ?? 3000;
app.listen(PORT, async () => {
  console.log(`Server running on http://localhost:${PORT}`);
  await mcpManager.initializeMCP(mcpServers);
});
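For reference, a minimal sketch of exercising a gateway like the one above over HTTP, assuming the default port 3000 and a Node 18+ ESM context (global fetch, top-level await):

const base = 'http://localhost:3000';

// Connection status for the "everything" server.
const status = await fetch(`${base}/status/everything`).then((r) => r.json());
console.log(status);

// Invoke the echo tool through the gateway.
const echoed = await fetch(`${base}/everything/tools/echo`, {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ message: 'hello' }),
}).then((r) => r.json());
console.log(echoed);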
File diff suppressed because one or more lines are too long
@@ -1,23 +0,0 @@
#!/usr/bin/env node

import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { createServer } from './everything';

async function main() {
  const transport = new StdioServerTransport();
  const { server, cleanup } = createServer();

  await server.connect(transport);

  // Cleanup on exit
  process.on('SIGINT', async () => {
    await cleanup();
    await server.close();
    process.exit(0);
  });
}

main().catch((error) => {
  console.error('Server error:', error);
  process.exit(1);
});
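A minimal sketch of driving this stdio entry point from the client half of the same SDK; the build output path is hypothetical:

import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';

// Spawns the entry point above as a child process and speaks MCP over its stdio.
const transport = new StdioClientTransport({
  command: 'node',
  args: ['./dist/stdio.js'], // hypothetical build output path
});

const client = new Client({ name: 'smoke-test', version: '0.1.0' }, { capabilities: {} });
await client.connect(transport);
console.log(await client.listTools());
await client.close();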
@@ -1,24 +0,0 @@
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import express from 'express';
import { createServer } from './everything.js';

const app = express();
const { server, cleanup } = createServer();

let transport: SSEServerTransport;

app.get('/sse', async (req, res) => {
  console.log('Received connection');
  transport = new SSEServerTransport('/message', res);
  await server.connect(transport);
  server.onclose = async () => {
    await cleanup();
    await server.close();
    process.exit(0);
  };
});

app.post('/message', async (req, res) => {
  console.log('Received message');
  await transport.handlePostMessage(req, res);
});

const PORT = process.env.SSE_PORT ?? 3001;
app.listen(PORT, () => {
  console.log(`Server is running on port ${PORT}`);
});
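Worth noting: the module-level transport variable means only the most recent SSE client gets its /message posts routed, which is acceptable for a single-client test harness but not for concurrent use. A minimal sketch of the matching client side, assuming the SDK's SSE client transport:

import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';

// Opens the SSE stream at /sse; requests are posted back to /message.
const transport = new SSEClientTransport(new URL('http://localhost:3001/sse'));
const client = new Client({ name: 'sse-smoke-test', version: '0.1.0' }, { capabilities: {} });
await client.connect(transport);
console.log(await client.listTools());
await client.close();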
@@ -1,700 +0,0 @@
#!/usr/bin/env node
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-nocheck
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js';
import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js';
import {
  JSONRPCMessage,
  CallToolRequestSchema,
  ListToolsRequestSchema,
  InitializeRequestSchema,
  ToolSchema,
} from '@modelcontextprotocol/sdk/types.js';
import fs from 'fs/promises';
import path from 'path';
import os from 'os';
import { z } from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { diffLines, createTwoFilesPatch } from 'diff';
import { IncomingMessage, ServerResponse } from 'node:http';
import { minimatch } from 'minimatch';
import express from 'express';

function normalizePath(p: string): string {
  return path.normalize(p).toLowerCase();
}

function expandHome(filepath: string): string {
  if (filepath.startsWith('~/') || filepath === '~') {
    return path.join(os.homedir(), filepath.slice(1));
  }
  return filepath;
}

// Command line argument parsing
const args = process.argv.slice(2);

// Parse command line arguments for transport type
const transportArg = args.find((arg) => arg.startsWith('--transport='));
const portArg = args.find((arg) => arg.startsWith('--port='));
const directories = args.filter((arg) => !arg.startsWith('--'));

if (directories.length === 0) {
  console.error(
    'Usage: mcp-server-filesystem [--transport=stdio|sse] [--port=3000] <allowed-directory> [additional-directories...]',
  );
  process.exit(1);
}

// Extract transport type and port from arguments
const transport = transportArg ? (transportArg.split('=')[1] as 'stdio' | 'sse') : 'stdio';

const port = portArg ? parseInt(portArg.split('=')[1], 10) : undefined;

// Store allowed directories in normalized form
const allowedDirectories = directories.map((dir) => normalizePath(path.resolve(expandHome(dir))));

// Validate that all directories exist and are accessible
/** @ts-ignore */
await Promise.all(
  directories.map(async (dir) => {
    try {
      const stats = await fs.stat(dir);
      if (!stats.isDirectory()) {
        console.error(`Error: ${dir} is not a directory`);
        process.exit(1);
      }
    } catch (error) {
      console.error(`Error accessing directory ${dir}:`, error);
      process.exit(1);
    }
  }),
);

// Security utilities
async function validatePath(requestedPath: string): Promise<string> {
  const expandedPath = expandHome(requestedPath);
  const absolute = path.isAbsolute(expandedPath)
    ? path.resolve(expandedPath)
    : path.resolve(process.cwd(), expandedPath);

  const normalizedRequested = normalizePath(absolute);

  // Check if path is within allowed directories
  const isAllowed = allowedDirectories.some((dir) => normalizedRequested.startsWith(dir));
  if (!isAllowed) {
    throw new Error(
      `Access denied - path outside allowed directories: ${absolute} not in ${allowedDirectories.join(
        ', ',
      )}`,
    );
  }

  // Handle symlinks by checking their real path
  try {
    const realPath = await fs.realpath(absolute);
    const normalizedReal = normalizePath(realPath);
    const isRealPathAllowed = allowedDirectories.some((dir) => normalizedReal.startsWith(dir));
    if (!isRealPathAllowed) {
      throw new Error('Access denied - symlink target outside allowed directories');
    }
    return realPath;
  } catch (error) {
    // For new files that don't exist yet, verify parent directory
    const parentDir = path.dirname(absolute);
    try {
      const realParentPath = await fs.realpath(parentDir);
      const normalizedParent = normalizePath(realParentPath);
      const isParentAllowed = allowedDirectories.some((dir) => normalizedParent.startsWith(dir));
      if (!isParentAllowed) {
        throw new Error('Access denied - parent directory outside allowed directories');
      }
      return absolute;
    } catch {
      throw new Error(`Parent directory does not exist: ${parentDir}`);
    }
  }
}
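One caveat in the startsWith containment check above: allowed directories are stored without a trailing separator, so a sibling such as /data/allowed-evil would pass a check against /data/allowed. A stricter alternative (a sketch, not what this file shipped) uses path.relative:

import path from 'path';

// True only when `child` is `parent` itself or strictly nested inside it.
function isInside(parent: string, child: string): boolean {
  const rel = path.relative(parent, child);
  return rel === '' || (rel !== '..' && !rel.startsWith(`..${path.sep}`) && !path.isAbsolute(rel));
}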
// Schema definitions
const ReadFileArgsSchema = z.object({
  path: z.string(),
});

const ReadMultipleFilesArgsSchema = z.object({
  paths: z.array(z.string()),
});

const WriteFileArgsSchema = z.object({
  path: z.string(),
  content: z.string(),
});

const EditOperation = z.object({
  oldText: z.string().describe('Text to search for - must match exactly'),
  newText: z.string().describe('Text to replace with'),
});

const EditFileArgsSchema = z.object({
  path: z.string(),
  edits: z.array(EditOperation),
  dryRun: z.boolean().default(false).describe('Preview changes using git-style diff format'),
});

const CreateDirectoryArgsSchema = z.object({
  path: z.string(),
});

const ListDirectoryArgsSchema = z.object({
  path: z.string(),
});

const MoveFileArgsSchema = z.object({
  source: z.string(),
  destination: z.string(),
});

const SearchFilesArgsSchema = z.object({
  path: z.string(),
  pattern: z.string(),
  excludePatterns: z.array(z.string()).optional().default([]),
});

const GetFileInfoArgsSchema = z.object({
  path: z.string(),
});

const ToolInputSchema = ToolSchema.shape.inputSchema;
type ToolInput = z.infer<typeof ToolInputSchema>;

interface FileInfo {
  size: number;
  created: Date;
  modified: Date;
  accessed: Date;
  isDirectory: boolean;
  isFile: boolean;
  permissions: string;
}

// Server setup
const server = new Server(
  {
    name: 'secure-filesystem-server',
    version: '0.2.0',
  },
  {
    capabilities: {
      tools: {},
    },
  },
);

// Tool implementations
async function getFileStats(filePath: string): Promise<FileInfo> {
  const stats = await fs.stat(filePath);
  return {
    size: stats.size,
    created: stats.birthtime,
    modified: stats.mtime,
    accessed: stats.atime,
    isDirectory: stats.isDirectory(),
    isFile: stats.isFile(),
    permissions: stats.mode.toString(8).slice(-3),
  };
}

async function searchFiles(
  rootPath: string,
  pattern: string,
  excludePatterns: string[] = [],
): Promise<string[]> {
  const results: string[] = [];

  async function search(currentPath: string) {
    const entries = await fs.readdir(currentPath, { withFileTypes: true });

    for (const entry of entries) {
      const fullPath = path.join(currentPath, entry.name);

      try {
        // Validate each path before processing
        await validatePath(fullPath);

        // Check if path matches any exclude pattern
        const relativePath = path.relative(rootPath, fullPath);
        const shouldExclude = excludePatterns.some((pattern) => {
          const globPattern = pattern.includes('*') ? pattern : `**/${pattern}/**`;
          return minimatch(relativePath, globPattern, { dot: true });
        });

        if (shouldExclude) {
          continue;
        }

        if (entry.name.toLowerCase().includes(pattern.toLowerCase())) {
          results.push(fullPath);
        }

        if (entry.isDirectory()) {
          await search(fullPath);
        }
      } catch (error) {
        // Skip invalid paths during search
        continue;
      }
    }
  }

  await search(rootPath);
  return results;
}

// File editing and diffing utilities
function normalizeLineEndings(text: string): string {
  return text.replace(/\r\n/g, '\n');
}

function createUnifiedDiff(originalContent: string, newContent: string, filepath = 'file'): string {
  // Ensure consistent line endings for diff
  const normalizedOriginal = normalizeLineEndings(originalContent);
  const normalizedNew = normalizeLineEndings(newContent);

  return createTwoFilesPatch(
    filepath,
    filepath,
    normalizedOriginal,
    normalizedNew,
    'original',
    'modified',
  );
}

async function applyFileEdits(
  filePath: string,
  edits: Array<{ oldText: string; newText: string }>,
  dryRun = false,
): Promise<string> {
  // Read file content and normalize line endings
  const content = normalizeLineEndings(await fs.readFile(filePath, 'utf-8'));

  // Apply edits sequentially
  let modifiedContent = content;
  for (const edit of edits) {
    const normalizedOld = normalizeLineEndings(edit.oldText);
    const normalizedNew = normalizeLineEndings(edit.newText);

    // If exact match exists, use it
    if (modifiedContent.includes(normalizedOld)) {
      modifiedContent = modifiedContent.replace(normalizedOld, normalizedNew);
      continue;
    }

    // Otherwise, try line-by-line matching with flexibility for whitespace
    const oldLines = normalizedOld.split('\n');
    const contentLines = modifiedContent.split('\n');
    let matchFound = false;

    for (let i = 0; i <= contentLines.length - oldLines.length; i++) {
      const potentialMatch = contentLines.slice(i, i + oldLines.length);

      // Compare lines with normalized whitespace
      const isMatch = oldLines.every((oldLine, j) => {
        const contentLine = potentialMatch[j];
        return oldLine.trim() === contentLine.trim();
      });

      if (isMatch) {
        // Preserve original indentation of first line
        const originalIndent = contentLines[i].match(/^\s*/)?.[0] || '';
        const newLines = normalizedNew.split('\n').map((line, j) => {
          if (j === 0) {
            return originalIndent + line.trimStart();
          }
          // For subsequent lines, try to preserve relative indentation
          const oldIndent = oldLines[j]?.match(/^\s*/)?.[0] || '';
          const newIndent = line.match(/^\s*/)?.[0] || '';
          if (oldIndent && newIndent) {
            const relativeIndent = newIndent.length - oldIndent.length;
            return originalIndent + ' '.repeat(Math.max(0, relativeIndent)) + line.trimStart();
          }
          return line;
        });

        contentLines.splice(i, oldLines.length, ...newLines);
        modifiedContent = contentLines.join('\n');
        matchFound = true;
        break;
      }
    }

    if (!matchFound) {
      throw new Error(`Could not find exact match for edit:\n${edit.oldText}`);
    }
  }

  // Create unified diff
  const diff = createUnifiedDiff(content, modifiedContent, filePath);

  // Format diff with appropriate number of backticks
  let numBackticks = 3;
  while (diff.includes('`'.repeat(numBackticks))) {
    numBackticks++;
  }
  const formattedDiff = `${'`'.repeat(numBackticks)}diff\n${diff}${'`'.repeat(numBackticks)}\n\n`;

  if (!dryRun) {
    await fs.writeFile(filePath, modifiedContent, 'utf-8');
  }

  return formattedDiff;
}
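For a sense of the contract: applyFileEdits returns the fenced diff either way and only writes when dryRun is false. A sketch of a preview call, with an illustrative path and edit:

const preview = await applyFileEdits(
  '/home/danny/LibreChat/README.md', // illustrative target
  [{ oldText: 'Hello world', newText: 'Hello, world!' }],
  true, // dryRun: render the diff without touching the file
);
console.log(preview); // a ```diff fenced unified patch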
// Tool handlers
server.setRequestHandler(ListToolsRequestSchema, async () => {
  return {
    tools: [
      {
        name: 'read_file',
        description:
          'Read the complete contents of a file from the file system. ' +
          'Handles various text encodings and provides detailed error messages ' +
          'if the file cannot be read. Use this tool when you need to examine ' +
          'the contents of a single file. Only works within allowed directories.',
        inputSchema: zodToJsonSchema(ReadFileArgsSchema) as ToolInput,
      },
      {
        name: 'read_multiple_files',
        description:
          'Read the contents of multiple files simultaneously. This is more ' +
          'efficient than reading files one by one when you need to analyze ' +
          'or compare multiple files. Each file\'s content is returned with its ' +
          'path as a reference. Failed reads for individual files won\'t stop ' +
          'the entire operation. Only works within allowed directories.',
        inputSchema: zodToJsonSchema(ReadMultipleFilesArgsSchema) as ToolInput,
      },
      {
        name: 'write_file',
        description:
          'Create a new file or completely overwrite an existing file with new content. ' +
          'Use with caution as it will overwrite existing files without warning. ' +
          'Handles text content with proper encoding. Only works within allowed directories.',
        inputSchema: zodToJsonSchema(WriteFileArgsSchema) as ToolInput,
      },
      {
        name: 'edit_file',
        description:
          'Make line-based edits to a text file. Each edit replaces exact line sequences ' +
          'with new content. Returns a git-style diff showing the changes made. ' +
          'Only works within allowed directories.',
        inputSchema: zodToJsonSchema(EditFileArgsSchema) as ToolInput,
      },
      {
        name: 'create_directory',
        description:
          'Create a new directory or ensure a directory exists. Can create multiple ' +
          'nested directories in one operation. If the directory already exists, ' +
          'this operation will succeed silently. Perfect for setting up directory ' +
          'structures for projects or ensuring required paths exist. Only works within allowed directories.',
        inputSchema: zodToJsonSchema(CreateDirectoryArgsSchema) as ToolInput,
      },
      {
        name: 'list_directory',
        description:
          'Get a detailed listing of all files and directories in a specified path. ' +
          'Results clearly distinguish between files and directories with [FILE] and [DIR] ' +
          'prefixes. This tool is essential for understanding directory structure and ' +
          'finding specific files within a directory. Only works within allowed directories.',
        inputSchema: zodToJsonSchema(ListDirectoryArgsSchema) as ToolInput,
      },
      {
        name: 'move_file',
        description:
          'Move or rename files and directories. Can move files between directories ' +
          'and rename them in a single operation. If the destination exists, the ' +
          'operation will fail. Works across different directories and can be used ' +
          'for simple renaming within the same directory. Both source and destination must be within allowed directories.',
        inputSchema: zodToJsonSchema(MoveFileArgsSchema) as ToolInput,
      },
      {
        name: 'search_files',
        description:
          'Recursively search for files and directories matching a pattern. ' +
          'Searches through all subdirectories from the starting path. The search ' +
          'is case-insensitive and matches partial names. Returns full paths to all ' +
          'matching items. Great for finding files when you don\'t know their exact location. ' +
          'Only searches within allowed directories.',
        inputSchema: zodToJsonSchema(SearchFilesArgsSchema) as ToolInput,
      },
      {
        name: 'get_file_info',
        description:
          'Retrieve detailed metadata about a file or directory. Returns comprehensive ' +
          'information including size, creation time, last modified time, permissions, ' +
          'and type. This tool is perfect for understanding file characteristics ' +
          'without reading the actual content. Only works within allowed directories.',
        inputSchema: zodToJsonSchema(GetFileInfoArgsSchema) as ToolInput,
      },
      {
        name: 'list_allowed_directories',
        description:
          'Returns the list of directories that this server is allowed to access. ' +
          'Use this to understand which directories are available before trying to access files.',
        inputSchema: {
          type: 'object',
          properties: {},
          required: [],
        },
      },
    ],
  };
});

server.setRequestHandler(CallToolRequestSchema, async (request) => {
  try {
    const { name, arguments: args } = request.params;

    switch (name) {
      case 'read_file': {
        const parsed = ReadFileArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for read_file: ${parsed.error}`);
        }
        const validPath = await validatePath(parsed.data.path);
        const content = await fs.readFile(validPath, 'utf-8');
        return {
          content: [{ type: 'text', text: content }],
        };
      }

      case 'read_multiple_files': {
        const parsed = ReadMultipleFilesArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for read_multiple_files: ${parsed.error}`);
        }
        const results = await Promise.all(
          parsed.data.paths.map(async (filePath: string) => {
            try {
              const validPath = await validatePath(filePath);
              const content = await fs.readFile(validPath, 'utf-8');
              return `${filePath}:\n${content}\n`;
            } catch (error) {
              const errorMessage = error instanceof Error ? error.message : String(error);
              return `${filePath}: Error - ${errorMessage}`;
            }
          }),
        );
        return {
          content: [{ type: 'text', text: results.join('\n---\n') }],
        };
      }

      case 'write_file': {
        const parsed = WriteFileArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for write_file: ${parsed.error}`);
        }
        const validPath = await validatePath(parsed.data.path);
        await fs.writeFile(validPath, parsed.data.content, 'utf-8');
        return {
          content: [{ type: 'text', text: `Successfully wrote to ${parsed.data.path}` }],
        };
      }

      case 'edit_file': {
        const parsed = EditFileArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for edit_file: ${parsed.error}`);
        }
        const validPath = await validatePath(parsed.data.path);
        const result = await applyFileEdits(validPath, parsed.data.edits, parsed.data.dryRun);
        return {
          content: [{ type: 'text', text: result }],
        };
      }

      case 'create_directory': {
        const parsed = CreateDirectoryArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for create_directory: ${parsed.error}`);
        }
        const validPath = await validatePath(parsed.data.path);
        await fs.mkdir(validPath, { recursive: true });
        return {
          content: [{ type: 'text', text: `Successfully created directory ${parsed.data.path}` }],
        };
      }

      case 'list_directory': {
        const parsed = ListDirectoryArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for list_directory: ${parsed.error}`);
        }
        const validPath = await validatePath(parsed.data.path);
        const entries = await fs.readdir(validPath, { withFileTypes: true });
        const formatted = entries
          .map((entry) => `${entry.isDirectory() ? '[DIR]' : '[FILE]'} ${entry.name}`)
          .join('\n');
        return {
          content: [{ type: 'text', text: formatted }],
        };
      }

      case 'move_file': {
        const parsed = MoveFileArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for move_file: ${parsed.error}`);
        }
        const validSourcePath = await validatePath(parsed.data.source);
        const validDestPath = await validatePath(parsed.data.destination);
        await fs.rename(validSourcePath, validDestPath);
        return {
          content: [
            {
              type: 'text',
              text: `Successfully moved ${parsed.data.source} to ${parsed.data.destination}`,
            },
          ],
        };
      }

      case 'search_files': {
        const parsed = SearchFilesArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for search_files: ${parsed.error}`);
        }
        const validPath = await validatePath(parsed.data.path);
        const results = await searchFiles(
          validPath,
          parsed.data.pattern,
          parsed.data.excludePatterns,
        );
        return {
          content: [
            { type: 'text', text: results.length > 0 ? results.join('\n') : 'No matches found' },
          ],
        };
      }

      case 'get_file_info': {
        const parsed = GetFileInfoArgsSchema.safeParse(args);
        if (!parsed.success) {
          throw new Error(`Invalid arguments for get_file_info: ${parsed.error}`);
        }
        const validPath = await validatePath(parsed.data.path);
        const info = await getFileStats(validPath);
        return {
          content: [
            {
              type: 'text',
              text: Object.entries(info)
                .map(([key, value]) => `${key}: ${value}`)
                .join('\n'),
            },
          ],
        };
      }

      case 'list_allowed_directories': {
        return {
          content: [
            {
              type: 'text',
              text: `Allowed directories:\n${allowedDirectories.join('\n')}`,
            },
          ],
        };
      }

      default:
        throw new Error(`Unknown tool: ${name}`);
    }
  } catch (error) {
    const errorMessage = error instanceof Error ? error.message : String(error);
    return {
      content: [{ type: 'text', text: `Error: ${errorMessage}` }],
      isError: true,
    };
  }
});

// Start server
// async function runServer() {
//   const transport = new StdioServerTransport();
//   await server.connect(transport);
//   console.error('Secure MCP Filesystem Server running on stdio');
//   console.error('Allowed directories:', allowedDirectories);
// }

// runServer().catch((error) => {
//   console.error('Fatal error running server:', error);
//   process.exit(1);
// });

async function runServer(transport: 'stdio' | 'sse', port?: number) {
  if (transport === 'stdio') {
    const stdioTransport = new StdioServerTransport();
    await server.connect(stdioTransport);
    console.error('Secure MCP Filesystem Server running on stdio');
    console.error('Allowed directories:', allowedDirectories);
  } else {
    const app = express();
    app.use(express.json());

    // Set up CORS
    app.use((req, res, next) => {
      res.header('Access-Control-Allow-Origin', '*');
      res.header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS');
      res.header('Access-Control-Allow-Headers', 'Content-Type');
      if (req.method === 'OPTIONS') {
        return res.sendStatus(200);
      }
      next();
    });

    let transport: SSEServerTransport;

    // SSE endpoint
    app.get('/sse', async (req, res) => {
      console.log('New SSE connection');
      transport = new SSEServerTransport('/message', res);
      await server.connect(transport);

      // Cleanup on close
      res.on('close', async () => {
        console.log('SSE connection closed');
        await server.close();
      });
    });

    // Message endpoint
    app.post('/message', async (req, res) => {
      if (!transport) {
        return res.status(503).send('SSE connection not established');
      }
      await transport.handlePostMessage(req, res);
    });

    const serverPort = port || 3001;
    app.listen(serverPort, () => {
      console.log(
        `Secure MCP Filesystem Server running on SSE at http://localhost:${serverPort}/sse`,
      );
      console.log('Allowed directories:', allowedDirectories);
    });
  }
}

if (directories.length === 0) {
  console.error(
    'Usage: mcp-server-filesystem [--transport=stdio|sse] [--port=3000] <allowed-directory> [additional-directories...]',
  );
  process.exit(1);
}

// Start the server with the specified transport
runServer(transport, port).catch((error) => {
  console.error('Fatal error running server:', error);
  process.exit(1);
});
@@ -1,9 +0,0 @@
/* MCP */
export * from './manager';
/* Utilities */
export * from './utils';
/* Flow */
export * from './flow/manager';
/* types */
export type * from './types/mcp';
export type * from './flow/types';